diff --git a/checkpoint.py b/checkpoint.py new file mode 100644 index 0000000..0140058 --- /dev/null +++ b/checkpoint.py @@ -0,0 +1,8 @@ +import torch +import shutil + +def save_checkpoint(state, is_best: bool, task_id, fname="checkpoint.pth.tar"): + fdir = "./"+str(task_id)+"/" + torch.save(state, fdir + fname) + if is_best: + shutil.copyfile(fdir + fname, fdir + "best.pth.tar") diff --git a/eval-transcrowd.py b/eval-transcrowd.py deleted file mode 100644 index e69de29..0000000 diff --git a/model/stn.py b/model/stn.py index 552b70b..e9cd307 100644 --- a/model/stn.py +++ b/model/stn.py @@ -18,8 +18,13 @@ class STNet(nn.Module): assert 5 > len(input_size) > 2 # single or batch ([N, ]C, H, W) if len(input_size) == 3: channels, height, width = input_size + _dummy_size_ = torch.Size([1]) + input_size else: channels, height, width = input_size[1:] + _dummy_size_ = input_size + + # shape checking + _dummy_x_ = torch.zeros(_dummy_size_) # (3.1) Spatial transformer localization-network self.localization_net = nn.Sequential( # (N, C, H, W) @@ -30,11 +35,8 @@ class STNet(nn.Module): nn.MaxPool2d(2, stride=2), # (N, 10, (((H-6)/2)-4)/2, (((H-6)/2)-4)/2) nn.ReLU(True) ) # -> (N, 10, (((H-6)/2)-4)/2, (((H-6)/2)-4)/2) - self._loc_net_out_shape = ( - 10, - (((height - 6) // 2) - 4) // 2, - (((width - 6) // 2) - 4) // 2 - ) # TODO: PLEASE let me know if there are better ways of doing this... + _dummy_x_ = self.localization_net(_dummy_x_) + self._loc_net_out_shape = _dummy_x_.shape[1:] # (3.2) Regressor for the 3 * 2 affine matrix self.fc_loc = nn.Sequential( @@ -53,6 +55,10 @@ class STNet(nn.Module): dtype=torch.float ) ) + _dummy_x_ = self.fc_loc( + _dummy_x_.view(-1, np.prod(self._loc_net_out_shape)) + ) + return # Spatial transformer network forward function @@ -205,8 +211,8 @@ if __name__ == "__main__": axarr[1].imshow(out_grid) axarr[1].set_title('Transformed Images') - for epoch in range(1, 20 + 1): - train(epoch) + for epoch in range(10): + train(epoch + 1) valid() # Visualize the STN transformation on some input batch diff --git a/train-dist.py b/train-dist.py deleted file mode 100644 index e69de29..0000000