From c2a1bb46efd5725d4b1dee03692065cf3cd78bbc Mon Sep 17 00:00:00 2001 From: rubberhead Date: Tue, 27 Feb 2024 17:55:57 +0000 Subject: [PATCH] Isolated STNet from test pipeline --- model/stn.py | 57 ++++++++++++++++++++++++++++++++-------------------- 1 file changed, 35 insertions(+), 22 deletions(-) diff --git a/model/stn.py b/model/stn.py index 0fcf786..9cadd05 100644 --- a/model/stn.py +++ b/model/stn.py @@ -8,9 +8,9 @@ from torchvision import datasets, transforms import matplotlib.pyplot as plt import numpy as np -class StnNet(nn.Module): +class STNet(nn.Module): def __init__(self, input_size: torch.Size): - super(StnNet, self).__init__() + super(STNet, self).__init__() # Sanity check assert 5 > len(input_size) > 2 # single or batch ([N, ]C, H, W) @@ -19,13 +19,7 @@ class StnNet(nn.Module): else: channels, height, width = input_size[1:] - self.conv1 = nn.Conv2d(1, 10, kernel_size=5) - self.conv2 = nn.Conv2d(10, 20, kernel_size=5) - self.conv2_drop = nn.Dropout2d() - self.fc1 = nn.Linear(320, 50) - self.fc2 = nn.Linear(50, 10) - - # Spatial transformer localization-network + # (3.1) Spatial transformer localization-network self.localization_net = nn.Sequential( # (N, C, H, W) nn.Conv2d(channels, 8, kernel_size=7), # (N, 8, H-6, W-6) nn.MaxPool2d(2, stride=2), # (N, 8, (H-6)/2, (W-6)/2) @@ -40,8 +34,9 @@ class StnNet(nn.Module): (((width - 6) // 2) - 4) // 2 ) # TODO: PLEASE let me know if there are better ways of doing this... - # Regressor for the 3 * 2 affine matrix + # (3.2) Regressor for the 3 * 2 affine matrix self.fc_loc = nn.Sequential( + # XXX: Should nn.Linear(np.prod(self._loc_net_out_shape), 32), nn.ReLU(True), nn.Linear(32, 3 * 2) @@ -57,6 +52,7 @@ class StnNet(nn.Module): ) ) + # Spatial transformer network forward function def stn(self, x): xs = self.localization_net(x) @@ -69,20 +65,37 @@ class StnNet(nn.Module): return x - def forward(self, x): - # transform the input - x = self.stn(x) - # Perform the usual forward pass - x = F.relu(F.max_pool2d(self.conv1(x), 2)) - x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) - x = x.view(-1, 320) - x = F.relu(self.fc1(x)) - x = F.dropout(x, training=self.training) - x = self.fc2(x) - return F.log_softmax(x, dim=1) + def forward(self, x): + # transform the input, do nothing else + return self.stn(x) + if __name__ == "__main__": + class STNetDebug(STNet): + def __init__(self, input_size: torch.Size): + super(STNetDebug, self).__init__(input_size) + + self.conv1 = nn.Conv2d(1, 10, kernel_size=5) + self.conv2 = nn.Conv2d(10, 20, kernel_size=5) + self.conv2_drop = nn.Dropout2d() + self.fc1 = nn.Linear(320, 50) + self.fc2 = nn.Linear(50, 10) + + def forward(self, x): + # Transform the input + x = self.stn(x) + + # Perform usual forward pass for MNIST + x = F.relu(F.max_pool2d(self.conv1(x), 2)) + x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) + x = x.view(-1, 320) + x = F.relu(self.fc1(x)) + x = F.dropout(x, training=self.training) + x = self.fc2(x) + return F.log_softmax(x, dim=1) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") train_loader = torch.utils.data.DataLoader( dataset=datasets.MNIST( @@ -112,7 +125,7 @@ if __name__ == "__main__": num_workers=4, ) shape_of_input = next(iter(train_loader))[0].shape - model = StnNet(shape_of_input).to(device) + model = STNetDebug(shape_of_input).to(device) optimizer = optim.SGD(model.parameters(), lr=.01) def train(epoch): model.train()