Added dummy input for one-shot size-checking
This commit is contained in:
parent
dcc3f57596
commit
ab633da4a5
4 changed files with 21 additions and 7 deletions
8
checkpoint.py
Normal file
8
checkpoint.py
Normal file
|
|
@ -0,0 +1,8 @@
|
||||||
|
import torch
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
def save_checkpoint(state, is_best: bool, task_id, fname="checkpoint.pth.tar"):
|
||||||
|
fdir = "./"+str(task_id)+"/"
|
||||||
|
torch.save(state, fdir + fname)
|
||||||
|
if is_best:
|
||||||
|
shutil.copyfile(fdir + fname, fdir + "best.pth.tar")
|
||||||
20
model/stn.py
20
model/stn.py
|
|
@ -18,8 +18,13 @@ class STNet(nn.Module):
|
||||||
assert 5 > len(input_size) > 2 # single or batch ([N, ]C, H, W)
|
assert 5 > len(input_size) > 2 # single or batch ([N, ]C, H, W)
|
||||||
if len(input_size) == 3:
|
if len(input_size) == 3:
|
||||||
channels, height, width = input_size
|
channels, height, width = input_size
|
||||||
|
_dummy_size_ = torch.Size([1]) + input_size
|
||||||
else:
|
else:
|
||||||
channels, height, width = input_size[1:]
|
channels, height, width = input_size[1:]
|
||||||
|
_dummy_size_ = input_size
|
||||||
|
|
||||||
|
# shape checking
|
||||||
|
_dummy_x_ = torch.zeros(_dummy_size_)
|
||||||
|
|
||||||
# (3.1) Spatial transformer localization-network
|
# (3.1) Spatial transformer localization-network
|
||||||
self.localization_net = nn.Sequential( # (N, C, H, W)
|
self.localization_net = nn.Sequential( # (N, C, H, W)
|
||||||
|
|
@ -30,11 +35,8 @@ class STNet(nn.Module):
|
||||||
nn.MaxPool2d(2, stride=2), # (N, 10, (((H-6)/2)-4)/2, (((H-6)/2)-4)/2)
|
nn.MaxPool2d(2, stride=2), # (N, 10, (((H-6)/2)-4)/2, (((H-6)/2)-4)/2)
|
||||||
nn.ReLU(True)
|
nn.ReLU(True)
|
||||||
) # -> (N, 10, (((H-6)/2)-4)/2, (((H-6)/2)-4)/2)
|
) # -> (N, 10, (((H-6)/2)-4)/2, (((H-6)/2)-4)/2)
|
||||||
self._loc_net_out_shape = (
|
_dummy_x_ = self.localization_net(_dummy_x_)
|
||||||
10,
|
self._loc_net_out_shape = _dummy_x_.shape[1:]
|
||||||
(((height - 6) // 2) - 4) // 2,
|
|
||||||
(((width - 6) // 2) - 4) // 2
|
|
||||||
) # TODO: PLEASE let me know if there are better ways of doing this...
|
|
||||||
|
|
||||||
# (3.2) Regressor for the 3 * 2 affine matrix
|
# (3.2) Regressor for the 3 * 2 affine matrix
|
||||||
self.fc_loc = nn.Sequential(
|
self.fc_loc = nn.Sequential(
|
||||||
|
|
@ -53,6 +55,10 @@ class STNet(nn.Module):
|
||||||
dtype=torch.float
|
dtype=torch.float
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
_dummy_x_ = self.fc_loc(
|
||||||
|
_dummy_x_.view(-1, np.prod(self._loc_net_out_shape))
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
# Spatial transformer network forward function
|
# Spatial transformer network forward function
|
||||||
|
|
@ -205,8 +211,8 @@ if __name__ == "__main__":
|
||||||
axarr[1].imshow(out_grid)
|
axarr[1].imshow(out_grid)
|
||||||
axarr[1].set_title('Transformed Images')
|
axarr[1].set_title('Transformed Images')
|
||||||
|
|
||||||
for epoch in range(1, 20 + 1):
|
for epoch in range(10):
|
||||||
train(epoch)
|
train(epoch + 1)
|
||||||
valid()
|
valid()
|
||||||
|
|
||||||
# Visualize the STN transformation on some input batch
|
# Visualize the STN transformation on some input batch
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue