Yay, works in DP via CPU

This commit is contained in:
Zhengyi Chen 2024-03-03 21:59:58 +00:00
parent ab15419d2f
commit fc941ebaf7
6 changed files with 18 additions and 9 deletions

View file

@ -85,7 +85,6 @@ class VisionTransformerGAP(VisionTransformer):
def forward(self, x, t):
with torch.no_grad():
x, t = self.glue(x, t)
print(f"Glue: {x.shape} | {t.shape}")
x = self.forward_features(x) # Compute encoding
x = F.adaptive_avg_pool1d(x, (48))
x = x.view(x.shape[0], -1) # Move data for regression head