Added dummy input for one-shot size-checking

Zhengyi Chen 2024-02-29 18:44:42 +00:00
parent dcc3f57596
commit ab633da4a5
4 changed files with 21 additions and 7 deletions

checkpoint.py Normal file

@@ -0,0 +1,8 @@
+import torch
+import shutil
+
+def save_checkpoint(state, is_best: bool, task_id, fname="checkpoint.pth.tar"):
+    fdir = "./"+str(task_id)+"/"
+    torch.save(state, fdir + fname)
+    if is_best:
+        shutil.copyfile(fdir + fname, fdir + "best.pth.tar")
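
For reference, a call site for the new helper might look like the sketch below. This is a hypothetical usage example, not part of the commit: the model, the contents of state, and task_id=0 are made up, and the target directory ./0/ must already exist because save_checkpoint does not create it.

import torch.nn as nn
from checkpoint import save_checkpoint

# Hypothetical caller (illustrative only): save a minimal state dict for
# task 0; is_best=True also copies the checkpoint to ./0/best.pth.tar.
model = nn.Linear(4, 2)
state = {"epoch": 1, "state_dict": model.state_dict()}
save_checkpoint(state, is_best=True, task_id=0)  # writes ./0/checkpoint.pth.tar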

@@ -18,8 +18,13 @@ class STNet(nn.Module):
         assert 5 > len(input_size) > 2 # single or batch ([N, ]C, H, W)
         if len(input_size) == 3:
             channels, height, width = input_size
+            _dummy_size_ = torch.Size([1]) + input_size
         else:
             channels, height, width = input_size[1:]
+            _dummy_size_ = input_size
+        # shape checking
+        _dummy_x_ = torch.zeros(_dummy_size_)
         # (3.1) Spatial transformer localization-network
         self.localization_net = nn.Sequential( # (N, C, H, W)
@@ -30,11 +35,8 @@ class STNet(nn.Module):
             nn.MaxPool2d(2, stride=2), # (N, 10, (((H-6)/2)-4)/2, (((H-6)/2)-4)/2)
             nn.ReLU(True)
         ) # -> (N, 10, (((H-6)/2)-4)/2, (((H-6)/2)-4)/2)
-        self._loc_net_out_shape = (
-            10,
-            (((height - 6) // 2) - 4) // 2,
-            (((width - 6) // 2) - 4) // 2
-        ) # TODO: PLEASE let me know if there are better ways of doing this...
+        _dummy_x_ = self.localization_net(_dummy_x_)
+        self._loc_net_out_shape = _dummy_x_.shape[1:]
         # (3.2) Regressor for the 3 * 2 affine matrix
         self.fc_loc = nn.Sequential(
@@ -53,6 +55,10 @@ class STNet(nn.Module):
                 dtype=torch.float
             )
         )
+        _dummy_x_ = self.fc_loc(
+            _dummy_x_.view(-1, np.prod(self._loc_net_out_shape))
+        )
         return
     # Spatial transformer network forward function
@@ -205,8 +211,8 @@ if __name__ == "__main__":
     axarr[1].imshow(out_grid)
     axarr[1].set_title('Transformed Images')
-    for epoch in range(1, 20 + 1):
-        train(epoch)
+    for epoch in range(10):
+        train(epoch + 1)
     valid()
     # Visualize the STN transformation on some input batch
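
The idea behind the diff above can be reproduced in isolation: instead of hand-deriving the localization network's output shape from kernel and pooling sizes, build the layers, push a zero tensor of the expected input size through them once at construction time, and record the resulting shape. The sketch below is a minimal standalone illustration of that pattern; the LocNetShapeDemo name, channel counts, and the 28x28 default input are assumptions for the example, not code from this repository.

import torch
import torch.nn as nn

class LocNetShapeDemo(nn.Module):
    # Minimal sketch of the one-shot "dummy input" size check;
    # layer sizes here are illustrative only.
    def __init__(self, input_size=(1, 28, 28)):
        super().__init__()
        self.localization_net = nn.Sequential(
            nn.Conv2d(input_size[0], 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
        )
        # Run a throwaway batch of zeros through the layers once and read
        # off the per-sample output shape instead of computing it by hand.
        _dummy_x_ = torch.zeros((1,) + tuple(input_size))
        self._loc_net_out_shape = self.localization_net(_dummy_x_).shape[1:]

demo = LocNetShapeDemo()
print(demo._loc_net_out_shape)  # torch.Size([10, 3, 3]) for 28x28 inputs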
