Yay, works in DP via CPU

Zhengyi Chen 2024-03-03 21:59:58 +00:00
parent ab15419d2f
commit fc941ebaf7
6 changed files with 18 additions and 9 deletions


@@ -57,7 +57,10 @@ class SquareCropTransformLayer(nn.Module):
         )
         # Sum into gt_count
-        ret_gt_count = torch.sum(split_t.view(split_t.size(0), -1), dim=1)
+        ret_gt_count = (torch
+            .sum(split_t.view(split_t.size(0), -1), dim=1)
+            .unsqueeze(1)
+        )
         return ret_x, ret_gt_count
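
Note (not part of the diff): the added .unsqueeze(1) changes the ground-truth count from shape (B,) to (B, 1). My reading of the "works in DP" commit title is that a 2-D per-replica output gathers and broadcasts more predictably under nn.DataParallel. A minimal shape sketch, with an assumed stand-in shape for split_t:

import torch

# Assumed stand-in for the layer's internal tensor: a batch of 4 crops.
split_t = torch.rand(4, 3, 224, 224)

# Old form: per-sample sum over all remaining dims -> shape (4,)
flat = torch.sum(split_t.view(split_t.size(0), -1), dim=1)

# New form: same values with an explicit trailing dim -> shape (4, 1)
count = flat.unsqueeze(1)

print(flat.shape, count.shape)  # torch.Size([4]) torch.Size([4, 1])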


@@ -85,7 +85,6 @@ class VisionTransformerGAP(VisionTransformer):
     def forward(self, x, t):
         with torch.no_grad():
             x, t = self.glue(x, t)
-            print(f"Glue: {x.shape} | {t.shape}")
         x = self.forward_features(x) # Compute encoding
         x = F.adaptive_avg_pool1d(x, (48))
         x = x.view(x.shape[0], -1) # Move data for regression head
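
Note (not part of the diff): beyond dropping the debug print, the surrounding forward pass pools the encoder output down to 48 features per token and flattens it for the regression head. A quick shape check of F.adaptive_avg_pool1d, assuming a standard ViT encoding of shape (batch, tokens, embed_dim) = (2, 197, 768):

import torch
import torch.nn.functional as F

x = torch.rand(2, 197, 768)   # assumed encoder output shape

# adaptive_avg_pool1d averages the last dimension down to 48 values per token.
x = F.adaptive_avg_pool1d(x, 48)
print(x.shape)                # torch.Size([2, 197, 48])

# Flatten everything after the batch dim for the regression head.
x = x.view(x.shape[0], -1)
print(x.shape)                # torch.Size([2, 9456])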