This commit is contained in:
Zhengyi Chen 2024-03-03 19:40:22 +00:00
parent a9dd8dee04
commit ab15419d2f
5 changed files with 63 additions and 19 deletions

View file

@ -69,13 +69,13 @@ class STNet(nn.Module):
theta = self.fc_loc(xs)
theta = theta.view(-1, 2, 3) # -> (2, 3)
grid = F.affine_grid(theta, x.size())
x = F.grid_sample(x, grid)
grid = F.affine_grid(theta, x.size(), align_corners=False)
x = F.grid_sample(x, grid, align_corners=False)
# Do the same transformation to t sans training
with torch.no_grad():
t = t.view(t.size(0), 1, t.size(1), t.size(2))
t = F.grid_sample(t, grid)
t = F.grid_sample(t, grid, align_corners=False)
t = t.squeeze(1)
return x, t