Added dummy input for one-shot size-checking
This commit is contained in:
parent
dcc3f57596
commit
ab633da4a5
4 changed files with 21 additions and 7 deletions
8
checkpoint.py
Normal file
8
checkpoint.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
import torch
|
||||
import shutil
|
||||
|
||||
def save_checkpoint(state, is_best: bool, task_id, fname="checkpoint.pth.tar"):
|
||||
fdir = "./"+str(task_id)+"/"
|
||||
torch.save(state, fdir + fname)
|
||||
if is_best:
|
||||
shutil.copyfile(fdir + fname, fdir + "best.pth.tar")
|
||||
20
model/stn.py
20
model/stn.py
|
|
@ -18,8 +18,13 @@ class STNet(nn.Module):
|
|||
assert 5 > len(input_size) > 2 # single or batch ([N, ]C, H, W)
|
||||
if len(input_size) == 3:
|
||||
channels, height, width = input_size
|
||||
_dummy_size_ = torch.Size([1]) + input_size
|
||||
else:
|
||||
channels, height, width = input_size[1:]
|
||||
_dummy_size_ = input_size
|
||||
|
||||
# shape checking
|
||||
_dummy_x_ = torch.zeros(_dummy_size_)
|
||||
|
||||
# (3.1) Spatial transformer localization-network
|
||||
self.localization_net = nn.Sequential( # (N, C, H, W)
|
||||
|
|
@ -30,11 +35,8 @@ class STNet(nn.Module):
|
|||
nn.MaxPool2d(2, stride=2), # (N, 10, (((H-6)/2)-4)/2, (((H-6)/2)-4)/2)
|
||||
nn.ReLU(True)
|
||||
) # -> (N, 10, (((H-6)/2)-4)/2, (((H-6)/2)-4)/2)
|
||||
self._loc_net_out_shape = (
|
||||
10,
|
||||
(((height - 6) // 2) - 4) // 2,
|
||||
(((width - 6) // 2) - 4) // 2
|
||||
) # TODO: PLEASE let me know if there are better ways of doing this...
|
||||
_dummy_x_ = self.localization_net(_dummy_x_)
|
||||
self._loc_net_out_shape = _dummy_x_.shape[1:]
|
||||
|
||||
# (3.2) Regressor for the 3 * 2 affine matrix
|
||||
self.fc_loc = nn.Sequential(
|
||||
|
|
@ -53,6 +55,10 @@ class STNet(nn.Module):
|
|||
dtype=torch.float
|
||||
)
|
||||
)
|
||||
_dummy_x_ = self.fc_loc(
|
||||
_dummy_x_.view(-1, np.prod(self._loc_net_out_shape))
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
# Spatial transformer network forward function
|
||||
|
|
@ -205,8 +211,8 @@ if __name__ == "__main__":
|
|||
axarr[1].imshow(out_grid)
|
||||
axarr[1].set_title('Transformed Images')
|
||||
|
||||
for epoch in range(1, 20 + 1):
|
||||
train(epoch)
|
||||
for epoch in range(10):
|
||||
train(epoch + 1)
|
||||
valid()
|
||||
|
||||
# Visualize the STN transformation on some input batch
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue