Isolated STNet from test pipeline

This commit is contained in:
Zhengyi Chen 2024-02-27 17:55:57 +00:00
parent b8f7f922f1
commit c2a1bb46ef

View file

@ -8,9 +8,9 @@ from torchvision import datasets, transforms
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
class StnNet(nn.Module): class STNet(nn.Module):
def __init__(self, input_size: torch.Size): def __init__(self, input_size: torch.Size):
super(StnNet, self).__init__() super(STNet, self).__init__()
# Sanity check # Sanity check
assert 5 > len(input_size) > 2 # single or batch ([N, ]C, H, W) assert 5 > len(input_size) > 2 # single or batch ([N, ]C, H, W)
@ -19,13 +19,7 @@ class StnNet(nn.Module):
else: else:
channels, height, width = input_size[1:] channels, height, width = input_size[1:]
self.conv1 = nn.Conv2d(1, 10, kernel_size=5) # (3.1) Spatial transformer localization-network
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
# Spatial transformer localization-network
self.localization_net = nn.Sequential( # (N, C, H, W) self.localization_net = nn.Sequential( # (N, C, H, W)
nn.Conv2d(channels, 8, kernel_size=7), # (N, 8, H-6, W-6) nn.Conv2d(channels, 8, kernel_size=7), # (N, 8, H-6, W-6)
nn.MaxPool2d(2, stride=2), # (N, 8, (H-6)/2, (W-6)/2) nn.MaxPool2d(2, stride=2), # (N, 8, (H-6)/2, (W-6)/2)
@ -40,8 +34,9 @@ class StnNet(nn.Module):
(((width - 6) // 2) - 4) // 2 (((width - 6) // 2) - 4) // 2
) # TODO: PLEASE let me know if there are better ways of doing this... ) # TODO: PLEASE let me know if there are better ways of doing this...
# Regressor for the 3 * 2 affine matrix # (3.2) Regressor for the 3 * 2 affine matrix
self.fc_loc = nn.Sequential( self.fc_loc = nn.Sequential(
# XXX: Should
nn.Linear(np.prod(self._loc_net_out_shape), 32), nn.Linear(np.prod(self._loc_net_out_shape), 32),
nn.ReLU(True), nn.ReLU(True),
nn.Linear(32, 3 * 2) nn.Linear(32, 3 * 2)
@ -57,6 +52,7 @@ class StnNet(nn.Module):
) )
) )
# Spatial transformer network forward function # Spatial transformer network forward function
def stn(self, x): def stn(self, x):
xs = self.localization_net(x) xs = self.localization_net(x)
@ -69,11 +65,28 @@ class StnNet(nn.Module):
return x return x
def forward(self, x): def forward(self, x):
# transform the input # transform the input, do nothing else
return self.stn(x)
if __name__ == "__main__":
class STNetDebug(STNet):
def __init__(self, input_size: torch.Size):
super(STNetDebug, self).__init__(input_size)
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
def forward(self, x):
# Transform the input
x = self.stn(x) x = self.stn(x)
# Perform the usual forward pass # Perform usual forward pass for MNIST
x = F.relu(F.max_pool2d(self.conv1(x), 2)) x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320) x = x.view(-1, 320)
@ -82,7 +95,7 @@ class StnNet(nn.Module):
x = self.fc2(x) x = self.fc2(x)
return F.log_softmax(x, dim=1) return F.log_softmax(x, dim=1)
if __name__ == "__main__":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader = torch.utils.data.DataLoader( train_loader = torch.utils.data.DataLoader(
dataset=datasets.MNIST( dataset=datasets.MNIST(
@ -112,7 +125,7 @@ if __name__ == "__main__":
num_workers=4, num_workers=4,
) )
shape_of_input = next(iter(train_loader))[0].shape shape_of_input = next(iter(train_loader))[0].shape
model = StnNet(shape_of_input).to(device) model = STNetDebug(shape_of_input).to(device)
optimizer = optim.SGD(model.parameters(), lr=.01) optimizer = optim.SGD(model.parameters(), lr=.01)
def train(epoch): def train(epoch):
model.train() model.train()