Isolated STNet from test pipeline
This commit is contained in:
parent
b8f7f922f1
commit
c2a1bb46ef
1 changed file with 35 additions and 22 deletions
57
model/stn.py
57
model/stn.py
|
|
@ -8,9 +8,9 @@ from torchvision import datasets, transforms
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
class StnNet(nn.Module):
|
class STNet(nn.Module):
|
||||||
def __init__(self, input_size: torch.Size):
|
def __init__(self, input_size: torch.Size):
|
||||||
super(StnNet, self).__init__()
|
super(STNet, self).__init__()
|
||||||
|
|
||||||
# Sanity check
|
# Sanity check
|
||||||
assert 5 > len(input_size) > 2 # single or batch ([N, ]C, H, W)
|
assert 5 > len(input_size) > 2 # single or batch ([N, ]C, H, W)
|
||||||
|
|
@ -19,13 +19,7 @@ class StnNet(nn.Module):
|
||||||
else:
|
else:
|
||||||
channels, height, width = input_size[1:]
|
channels, height, width = input_size[1:]
|
||||||
|
|
||||||
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
|
# (3.1) Spatial transformer localization-network
|
||||||
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
|
|
||||||
self.conv2_drop = nn.Dropout2d()
|
|
||||||
self.fc1 = nn.Linear(320, 50)
|
|
||||||
self.fc2 = nn.Linear(50, 10)
|
|
||||||
|
|
||||||
# Spatial transformer localization-network
|
|
||||||
self.localization_net = nn.Sequential( # (N, C, H, W)
|
self.localization_net = nn.Sequential( # (N, C, H, W)
|
||||||
nn.Conv2d(channels, 8, kernel_size=7), # (N, 8, H-6, W-6)
|
nn.Conv2d(channels, 8, kernel_size=7), # (N, 8, H-6, W-6)
|
||||||
nn.MaxPool2d(2, stride=2), # (N, 8, (H-6)/2, (W-6)/2)
|
nn.MaxPool2d(2, stride=2), # (N, 8, (H-6)/2, (W-6)/2)
|
||||||
|
|
@ -40,8 +34,9 @@ class StnNet(nn.Module):
|
||||||
(((width - 6) // 2) - 4) // 2
|
(((width - 6) // 2) - 4) // 2
|
||||||
) # TODO: PLEASE let me know if there are better ways of doing this...
|
) # TODO: PLEASE let me know if there are better ways of doing this...
|
||||||
|
|
||||||
# Regressor for the 3 * 2 affine matrix
|
# (3.2) Regressor for the 3 * 2 affine matrix
|
||||||
self.fc_loc = nn.Sequential(
|
self.fc_loc = nn.Sequential(
|
||||||
|
# XXX: Should
|
||||||
nn.Linear(np.prod(self._loc_net_out_shape), 32),
|
nn.Linear(np.prod(self._loc_net_out_shape), 32),
|
||||||
nn.ReLU(True),
|
nn.ReLU(True),
|
||||||
nn.Linear(32, 3 * 2)
|
nn.Linear(32, 3 * 2)
|
||||||
|
|
@ -57,6 +52,7 @@ class StnNet(nn.Module):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Spatial transformer network forward function
|
# Spatial transformer network forward function
|
||||||
def stn(self, x):
|
def stn(self, x):
|
||||||
xs = self.localization_net(x)
|
xs = self.localization_net(x)
|
||||||
|
|
@ -69,20 +65,37 @@ class StnNet(nn.Module):
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
# transform the input
|
|
||||||
x = self.stn(x)
|
|
||||||
|
|
||||||
# Perform the usual forward pass
|
def forward(self, x):
|
||||||
x = F.relu(F.max_pool2d(self.conv1(x), 2))
|
# transform the input, do nothing else
|
||||||
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
|
return self.stn(x)
|
||||||
x = x.view(-1, 320)
|
|
||||||
x = F.relu(self.fc1(x))
|
|
||||||
x = F.dropout(x, training=self.training)
|
|
||||||
x = self.fc2(x)
|
|
||||||
return F.log_softmax(x, dim=1)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
class STNetDebug(STNet):
|
||||||
|
def __init__(self, input_size: torch.Size):
|
||||||
|
super(STNetDebug, self).__init__(input_size)
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
|
||||||
|
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
|
||||||
|
self.conv2_drop = nn.Dropout2d()
|
||||||
|
self.fc1 = nn.Linear(320, 50)
|
||||||
|
self.fc2 = nn.Linear(50, 10)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# Transform the input
|
||||||
|
x = self.stn(x)
|
||||||
|
|
||||||
|
# Perform usual forward pass for MNIST
|
||||||
|
x = F.relu(F.max_pool2d(self.conv1(x), 2))
|
||||||
|
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
|
||||||
|
x = x.view(-1, 320)
|
||||||
|
x = F.relu(self.fc1(x))
|
||||||
|
x = F.dropout(x, training=self.training)
|
||||||
|
x = self.fc2(x)
|
||||||
|
return F.log_softmax(x, dim=1)
|
||||||
|
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
train_loader = torch.utils.data.DataLoader(
|
train_loader = torch.utils.data.DataLoader(
|
||||||
dataset=datasets.MNIST(
|
dataset=datasets.MNIST(
|
||||||
|
|
@ -112,7 +125,7 @@ if __name__ == "__main__":
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
)
|
)
|
||||||
shape_of_input = next(iter(train_loader))[0].shape
|
shape_of_input = next(iter(train_loader))[0].shape
|
||||||
model = StnNet(shape_of_input).to(device)
|
model = STNetDebug(shape_of_input).to(device)
|
||||||
optimizer = optim.SGD(model.parameters(), lr=.01)
|
optimizer = optim.SGD(model.parameters(), lr=.01)
|
||||||
def train(epoch):
|
def train(epoch):
|
||||||
model.train()
|
model.train()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue