# stole from pytorch tutorial
# "Great artists steal" -- they say,
# but thieves also steal so you know :P
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np


class STNet(nn.Module):
    def __init__(self, input_size: torch.Size):
        super(STNet, self).__init__()
        # Sanity check
        assert 5 > len(input_size) > 2  # single or batch ([N, ]C, H, W)
        if len(input_size) == 3:
            channels, height, width = input_size
            _dummy_size_ = torch.Size([1]) + input_size
        else:
            channels, height, width = input_size[1:]
            _dummy_size_ = input_size
        # Dummy input for shape checking
        _dummy_x_ = torch.zeros(_dummy_size_)

        # (3.1) Spatial transformer localization-network
        self.localization_net = nn.Sequential(
            # (N, C, H, W)
            nn.Conv2d(channels, 8, kernel_size=7),
            # (N, 8, H-6, W-6)
            nn.MaxPool2d(2, stride=2),
            # (N, 8, (H-6)/2, (W-6)/2)
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            # (N, 10, ((H-6)/2)-4, ((W-6)/2)-4)
            nn.MaxPool2d(2, stride=2),
            # (N, 10, (((H-6)/2)-4)/2, (((W-6)/2)-4)/2)
            nn.ReLU(True)
        )  # -> (N, 10, (((H-6)/2)-4)/2, (((W-6)/2)-4)/2)
        # e.g. for 28x28 MNIST: 28 -> 22 -> 11 -> 7 -> 3, i.e. (N, 10, 3, 3)
        _dummy_x_ = self.localization_net(_dummy_x_)
        self._loc_net_out_shape = _dummy_x_.shape[1:]

        # (3.2) Regressor for the 3 * 2 affine matrix
        self.fc_loc = nn.Sequential(
            # XXX: Dimensionality reduction across channels or not?
            nn.Linear(int(np.prod(self._loc_net_out_shape)), 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )  # -> (6,)
        # Initialize the weights/bias with the identity transformation
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(
            torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)
        )
        # Shape-check fc_loc on the dummy activations
        _dummy_x_ = self.fc_loc(
            _dummy_x_.view(-1, int(np.prod(self._loc_net_out_shape)))
        )

    # Spatial transformer network forward function
    def stn(self, x, t=None):
        xs = self.localization_net(x)
        xs = xs.view(-1, int(np.prod(self._loc_net_out_shape)))  # -> (N, features)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 2, 3)  # -> (N, 2, 3)

        grid = F.affine_grid(theta, x.size(), align_corners=False)
        x = F.grid_sample(x, grid, align_corners=False)
        if t is None:
            return x
        # Apply the same transformation to t, without tracking gradients.
        # Note: grid_sample interpolates bilinearly by default; pass
        # mode="nearest" instead if t holds hard (discrete) labels.
        with torch.no_grad():
            t = t.view(t.size(0), 1, t.size(1), t.size(2))  # (N, H, W) -> (N, 1, H, W)
            t = F.grid_sample(t, grid, align_corners=False)
            t = t.squeeze(1)
        return x, t

    def forward(self, x, t=None):
        # print("STN: {} | {}".format(x.shape, t.shape))
        # transform the input, do nothing else
        return self.stn(x, t)
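# Illustrative usage sketch (an addition, not part of the original script):
# STNet warps a batch of images together with a same-resolution per-pixel
# target map `t` (e.g. a segmentation mask), applying the identical sampling
# grid to both. The shapes and random tensors below are assumptions chosen
# for demonstration only.
def _demo_paired_warp():
    net = STNet(torch.Size([1, 28, 28]))
    x = torch.randn(4, 1, 28, 28)   # batch of single-channel images
    t = torch.rand(4, 28, 28)       # per-pixel targets, shape (N, H, W)
    x_warped, t_warped = net(x, t)  # both receive the same affine warp
    return x_warped, t_warped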
if __name__ == "__main__":
    class STNetDebug(STNet):
        def __init__(self, input_size: torch.Size):
            super(STNetDebug, self).__init__(input_size)
            self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
            self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
            self.conv2_drop = nn.Dropout2d()
            self.fc1 = nn.Linear(320, 50)
            self.fc2 = nn.Linear(50, 10)

        def forward(self, x):
            # Transform the input
            x = self.stn(x)
            # Perform the usual forward pass for MNIST
            x = F.relu(F.max_pool2d(self.conv1(x), 2))
            x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
            x = x.view(-1, 320)
            x = F.relu(self.fc1(x))
            x = F.dropout(x, training=self.training)
            x = self.fc2(x)
            return F.log_softmax(x, dim=1)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loader = torch.utils.data.DataLoader(
        dataset=datasets.MNIST(
            root="./synchronous/",
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((.1307,), (.3081,))
            ]),
        ),
        batch_size=64,
        shuffle=True,
        num_workers=4,
    )
    valid_loader = torch.utils.data.DataLoader(
        dataset=datasets.MNIST(
            root="./synchronous/",
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((.1307,), (.3081,))
            ]),
        ),
        batch_size=64,
        shuffle=True,
        num_workers=4,
    )

    shape_of_input = next(iter(train_loader))[0].shape
    model = STNetDebug(shape_of_input).to(device)
    optimizer = optim.SGD(model.parameters(), lr=.01)

    def train(epoch):
        model.train()
        for i, (x_, t_) in enumerate(train_loader):
            # XXX: x_.shape == (N, C, H, W)
            # Inference
            x_, t_ = x_.to(device), t_.to(device)
            optimizer.zero_grad()
            y_ = model(x_)
            # Backprop
            l_ = F.nll_loss(y_, t_)
            l_.backward()
            optimizer.step()
            if i % 500 == 0:
                print("Epoch {}: [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch, i * len(x_), len(train_loader.dataset),
                    100. * i / len(train_loader), l_.item()
                ))

    def valid():
        with torch.no_grad():
            model.eval()
            valid_loss = 0
            correct = 0
            for x_, t_ in valid_loader:
                x_, t_ = x_.to(device), t_.to(device)
                y_ = model(x_)
                # Sum batch loss
                valid_loss += F.nll_loss(y_, t_, reduction="sum").item()
                pred = y_.max(1, keepdim=True)[1]
                correct += pred.eq(t_.view_as(pred)).sum().item()
            valid_loss /= len(valid_loader.dataset)
            print(
                "\nValid set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n"
                .format(
                    valid_loss, correct, len(valid_loader.dataset),
                    100. * correct / len(valid_loader.dataset)
                )
            )

    def convert_image_np(inp):
        """Convert a Tensor to a numpy image."""
        inp = inp.numpy().transpose((1, 2, 0))
        # Undo the MNIST normalization applied by the data loaders
        # (make_grid repeats the single channel three times)
        mean = np.array([.1307, .1307, .1307])
        std = np.array([.3081, .3081, .3081])
        inp = std * inp + mean
        inp = np.clip(inp, 0, 1)
        return inp

    # We want to visualize the output of the spatial transformer layer:
    # after training, plot a batch of input images next to the
    # corresponding batch transformed by the STN.
    def visualize_stn():
        with torch.no_grad():
            # Get a batch of validation data
            data = next(iter(valid_loader))[0].to(device)
            input_tensor = data.cpu()
            transformed_input_tensor = model.stn(data).cpu()
            in_grid = convert_image_np(
                torchvision.utils.make_grid(input_tensor))
            out_grid = convert_image_np(
                torchvision.utils.make_grid(transformed_input_tensor))
            # Plot the results side-by-side
            f, axarr = plt.subplots(1, 2)
            axarr[0].imshow(in_grid)
            axarr[0].set_title('Dataset Images')
            axarr[1].imshow(out_grid)
            axarr[1].set_title('Transformed Images')

    for epoch in range(10):
        train(epoch + 1)
        valid()

    # Visualize the STN transformation on some input batch
    visualize_stn()
    plt.ioff()
    plt.show()
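    # Optional sanity check (an addition, not in the original tutorial): an
    # identity theta run through affine_grid + grid_sample should reproduce
    # the input up to floating-point error, confirming the sampling
    # convention used in STNet.stn (align_corners=False throughout).
    x_check = torch.randn(2, 1, 28, 28)
    theta_id = torch.tensor([[1., 0., 0.],
                             [0., 1., 0.]]).unsqueeze(0).repeat(2, 1, 1)
    grid_id = F.affine_grid(theta_id, x_check.size(), align_corners=False)
    x_id = F.grid_sample(x_check, grid_id, align_corners=False)
    assert torch.allclose(x_check, x_id, atol=1e-5)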