# Stolen from the PyTorch spatial transformer tutorial.
# "Great artists steal" -- they say,
# but thieves also steal, so you know :P
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
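

# A spatial transformer network (STN, Jaderberg et al. 2015) has three parts:
# a localization network that regresses transformation parameters from the
# input image, a grid generator that turns those parameters into sampling
# coordinates, and a sampler that interpolates the input at those coordinates.
# Here the transformation is a 2x3 affine matrix, so the module learns to
# crop/rotate/scale/translate its input before the downstream task sees it.

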
class STNet(nn.Module):

    def __init__(self, input_size: torch.Size):
        super(STNet, self).__init__()

        # Sanity check
        assert 5 > len(input_size) > 2  # single or batch ([N, ]C, H, W)
        if len(input_size) == 3:
            channels, height, width = input_size
            _dummy_size_ = torch.Size([1]) + input_size
        else:
            channels, height, width = input_size[1:]
            _dummy_size_ = input_size

        # shape checking
        _dummy_x_ = torch.zeros(_dummy_size_)

        # (3.1) Spatial transformer localization network
        self.localization_net = nn.Sequential(      # (N, C, H, W)
            nn.Conv2d(channels, 8, kernel_size=7),  # (N, 8, H-6, W-6)
            nn.MaxPool2d(2, stride=2),              # (N, 8, (H-6)/2, (W-6)/2)
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),        # (N, 10, ((H-6)/2)-4, ((W-6)/2)-4)
            nn.MaxPool2d(2, stride=2),              # (N, 10, (((H-6)/2)-4)/2, (((W-6)/2)-4)/2)
            nn.ReLU(True)
        )  # -> (N, 10, (((H-6)/2)-4)/2, (((W-6)/2)-4)/2)
        _dummy_x_ = self.localization_net(_dummy_x_)
        self._loc_net_out_shape = _dummy_x_.shape[1:]
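        # NOTE: the upstream tutorial hard-codes this flattened size as
        # 10 * 3 * 3, which only holds for 28x28 MNIST; probing it with a
        # dummy forward pass makes the module work for arbitrary input sizes.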

        # (3.2) Regressor for the 3 * 2 affine matrix
        self.fc_loc = nn.Sequential(
            # XXX: Dimensionality reduction across channels or not?
            nn.Linear(int(np.prod(self._loc_net_out_shape)), 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )  # -> (N, 6)

        # Initialize the weights/bias with the identity transformation so
        # that training starts from "no warp" and only deviates as needed
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(
            torch.tensor(
                [1, 0, 0,
                 0, 1, 0],
                dtype=torch.float
            )
        )

        # Shape check: the regressor must accept the flattened features
        _dummy_x_ = self.fc_loc(
            _dummy_x_.view(-1, int(np.prod(self._loc_net_out_shape)))
        )

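    # F.affine_grid maps each (2, 3) theta to an (H, W, 2) grid of normalized
    # sampling coordinates in [-1, 1], and F.grid_sample then bilinearly
    # interpolates the input at those coordinates. Both ops are
    # differentiable, which is what lets the localization network train
    # end-to-end.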
    # Spatial transformer network forward function
    def stn(self, x, t=None):
        xs = self.localization_net(x)
        xs = xs.view(-1, int(np.prod(self._loc_net_out_shape)))  # -> (N, whatever)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 2, 3)  # -> (N, 2, 3)

        grid = F.affine_grid(theta, x.size(), align_corners=False)
        x = F.grid_sample(x, grid, align_corners=False)

        if t is None:
            return x

        # Do the same transformation to t (assumed (N, H, W) and float)
        # sans training
        with torch.no_grad():
            t = t.view(t.size(0), 1, t.size(1), t.size(2))
            t = F.grid_sample(t, grid, align_corners=False)
            t = t.squeeze(1)

        return x, t

    def forward(self, x, t):
        # print("STN: {} | {}".format(x.shape, t.shape))
        # Transform the input, do nothing else
        return self.stn(x, t)
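

# Shapes-only usage sketch (comments, not executed; MNIST-sized input assumed):
#   net = STNet(torch.Size([1, 28, 28]))
#   x = torch.randn(4, 1, 28, 28)  # batch of images
#   t = torch.randn(4, 28, 28)     # spatial target, warped alongside x
#   x2, t2 = net(x, t)             # x2: (4, 1, 28, 28), t2: (4, 28, 28)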


if __name__ == "__main__":

    class STNetDebug(STNet):

        def __init__(self, input_size: torch.Size):
            super(STNetDebug, self).__init__(input_size)

            # Classic small MNIST convnet stacked on top of the STN front end
            self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
            self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
            self.conv2_drop = nn.Dropout2d()
            self.fc1 = nn.Linear(320, 50)  # 320 = 20 * 4 * 4 for 28x28 inputs
            self.fc2 = nn.Linear(50, 10)

        def forward(self, x):
            # Transform the input (stn returns just x when no target is given)
            x = self.stn(x)

            # Perform the usual forward pass for MNIST
            x = F.relu(F.max_pool2d(self.conv1(x), 2))
            x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
            x = x.view(-1, 320)
            x = F.relu(self.fc1(x))
            x = F.dropout(x, training=self.training)
            x = self.fc2(x)
            return F.log_softmax(x, dim=1)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader = torch.utils.data.DataLoader(
        dataset=datasets.MNIST(
            root="./synchronous/",
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ]),
        ),
        batch_size=64,
        shuffle=True,
        num_workers=4,
    )
    valid_loader = torch.utils.data.DataLoader(
        dataset=datasets.MNIST(
            root="./synchronous/",
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ]),
        ),
        batch_size=64,
        shuffle=True,
        num_workers=4,
    )
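
    # (0.1307,) and (0.3081,) are the global mean/std of the MNIST training
    # set; convert_image_np below must invert exactly this normalization.
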
    shape_of_input = next(iter(train_loader))[0].shape
    model = STNetDebug(shape_of_input).to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.01)

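    # shape_of_input is torch.Size([64, 1, 28, 28]) here, so the STN sizes
    # its localization regressor from real data instead of magic numbers.
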
    def train(epoch):
        model.train()
        for i, (x_, t_) in enumerate(train_loader):
            # XXX: x_.shape == (N, C, H, W)
            # Inference
            x_, t_ = x_.to(device), t_.to(device)
            optimizer.zero_grad()
            y_ = model(x_)
            # Backprop
            l_ = F.nll_loss(y_, t_)
            l_.backward()
            optimizer.step()
            if i % 500 == 0:
                print("Epoch {}: [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    i * len(x_),
                    len(train_loader.dataset),
                    100. * i / len(train_loader),
                    l_.item()
                ))

    def valid():
        with torch.no_grad():
            model.eval()
            valid_loss = 0
            correct = 0
            for x_, t_ in valid_loader:
                x_, t_ = x_.to(device), t_.to(device)
                y_ = model(x_)
                # Sum batch loss (averaged over the whole dataset below)
                valid_loss += F.nll_loss(y_, t_, reduction="sum").item()
                pred = y_.max(1, keepdim=True)[1]
                correct += pred.eq(t_.view_as(pred)).sum().item()

            valid_loss /= len(valid_loader.dataset)
            print("\nValid set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n"
                  .format(
                      valid_loss, correct, len(valid_loader.dataset),
                      100. * correct / len(valid_loader.dataset)
                  ))

    def convert_image_np(inp):
        """Convert a Tensor to a numpy image."""
        inp = inp.numpy().transpose((1, 2, 0))
        # Undo the MNIST normalization applied by the data loaders (the
        # original code used the ImageNet statistics here, which don't match)
        mean = np.array([0.1307])
        std = np.array([0.3081])
        inp = std * inp + mean
        inp = np.clip(inp, 0, 1)
        return inp

    # We want to visualize the output of the spatial transformer layer:
    # after training, we plot a batch of input images next to the
    # corresponding batch transformed by the STN.
    def visualize_stn():
        with torch.no_grad():
            # Get a batch of validation data
            data = next(iter(valid_loader))[0].to(device)

            input_tensor = data.cpu()
            transformed_input_tensor = model.stn(data).cpu()

            in_grid = convert_image_np(
                torchvision.utils.make_grid(input_tensor))
            out_grid = convert_image_np(
                torchvision.utils.make_grid(transformed_input_tensor))

            # Plot the results side by side
            f, axarr = plt.subplots(1, 2)
            axarr[0].imshow(in_grid)
            axarr[0].set_title('Dataset Images')
            axarr[1].imshow(out_grid)
            axarr[1].set_title('Transformed Images')

    for epoch in range(10):
        train(epoch + 1)
        valid()

    # Visualize the STN transformation on some input batch
    visualize_stn()

    plt.ioff()
    plt.show()