Yay, works in DP via CPU

This commit is contained in:
Zhengyi Chen 2024-03-03 21:59:58 +00:00
parent ab15419d2f
commit fc941ebaf7
6 changed files with 18 additions and 9 deletions

View file

@ -57,7 +57,10 @@ class SquareCropTransformLayer(nn.Module):
)
# Sum into gt_count
ret_gt_count = torch.sum(split_t.view(split_t.size(0), -1), dim=1)
ret_gt_count = (torch
.sum(split_t.view(split_t.size(0), -1), dim=1)
.unsqueeze(1)
)
return ret_x, ret_gt_count