Debug
This commit is contained in:
parent
a9dd8dee04
commit
ab15419d2f
5 changed files with 63 additions and 19 deletions
|
|
@ -69,13 +69,13 @@ class STNet(nn.Module):
|
|||
theta = self.fc_loc(xs)
|
||||
theta = theta.view(-1, 2, 3) # -> (2, 3)
|
||||
|
||||
grid = F.affine_grid(theta, x.size())
|
||||
x = F.grid_sample(x, grid)
|
||||
grid = F.affine_grid(theta, x.size(), align_corners=False)
|
||||
x = F.grid_sample(x, grid, align_corners=False)
|
||||
|
||||
# Do the same transformation to t sans training
|
||||
with torch.no_grad():
|
||||
t = t.view(t.size(0), 1, t.size(1), t.size(2))
|
||||
t = F.grid_sample(t, grid)
|
||||
t = F.grid_sample(t, grid, align_corners=False)
|
||||
t = t.squeeze(1)
|
||||
|
||||
return x, t
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue