Extra changes, also not sure if works

This commit is contained in:
Zhengyi Chen 2024-03-02 23:32:19 +00:00
parent f059924b75
commit c88b938680
2 changed files with 16 additions and 7 deletions

View file

@ -62,7 +62,7 @@ class STNet(nn.Module):
# Spatial transformer network forward function # Spatial transformer network forward function
def stn(self, x): def stn(self, x, t):
xs = self.localization_net(x) xs = self.localization_net(x)
xs = xs.view(-1, np.prod(self._loc_net_out_shape)) # -> (N, whatever) xs = xs.view(-1, np.prod(self._loc_net_out_shape)) # -> (N, whatever)
theta = self.fc_loc(xs) theta = self.fc_loc(xs)
@ -71,12 +71,18 @@ class STNet(nn.Module):
grid = F.affine_grid(theta, x.size()) grid = F.affine_grid(theta, x.size())
x = F.grid_sample(x, grid) x = F.grid_sample(x, grid)
return x # Do the same transformation to t sans training
with torch.no_grad():
t = t.view(t.size(0), 1, t.size(1), t.size(2))
t = F.grid_sample(t, grid)
t = t.squeeze(1)
return x, t
def forward(self, x): def forward(self, x, t):
# transform the input, do nothing else # transform the input, do nothing else
return self.stn(x) return self.stn(x, t)
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -22,6 +22,7 @@ from timm.models.registry import register_model
from timm.models.layers import trunc_normal_ from timm.models.layers import trunc_normal_
from .stn import STNet from .stn import STNet
from .glue import SquareCropTransformLayer
class VisionTransformerGAP(VisionTransformer): class VisionTransformerGAP(VisionTransformer):
# [XXX] It might be a bad idea to use vision transformer for small datasets. # [XXX] It might be a bad idea to use vision transformer for small datasets.
@ -102,10 +103,12 @@ class STNet_VisionTransformerGAP(VisionTransformerGAP):
def __init__(self, img_shape: torch.Size, *args, **kwargs): def __init__(self, img_shape: torch.Size, *args, **kwargs):
super(STNet_VisionTransformerGAP, self).__init__(*args, **kwargs) super(STNet_VisionTransformerGAP, self).__init__(*args, **kwargs)
self.stnet = STNet(img_shape) self.stnet = STNet(img_shape)
self.glue = SquareCropTransformLayer(img_size)
def forward(self, x): def forward(self, x, t):
x = self.stnet(x) x, t = self.stnet(x, t)
return super(STNet_VisionTransformerGAP, self).forward(x) x, t = self.glue(x, t)
return super(STNet_VisionTransformerGAP, self).forward(x), t
@register_model @register_model