From c88b93868070559888f02bfcc5a335123942bce9 Mon Sep 17 00:00:00 2001 From: Zhengyi Chen Date: Sat, 2 Mar 2024 23:32:19 +0000 Subject: [PATCH] Extra changes, also not sure if works --- model/stn.py | 14 ++++++++++---- model/transcrowd_gap.py | 9 ++++++--- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/model/stn.py b/model/stn.py index e9cd307..53fa8ff 100644 --- a/model/stn.py +++ b/model/stn.py @@ -62,7 +62,7 @@ class STNet(nn.Module): # Spatial transformer network forward function - def stn(self, x): + def stn(self, x, t): xs = self.localization_net(x) xs = xs.view(-1, np.prod(self._loc_net_out_shape)) # -> (N, whatever) theta = self.fc_loc(xs) @@ -71,12 +71,18 @@ class STNet(nn.Module): grid = F.affine_grid(theta, x.size()) x = F.grid_sample(x, grid) - return x + # Do the same transformation to t sans training + with torch.no_grad(): + t = t.view(t.size(0), 1, t.size(1), t.size(2)) + t = F.grid_sample(t, grid) + t = t.squeeze(1) + + return x, t - def forward(self, x): + def forward(self, x, t): # transform the input, do nothing else - return self.stn(x) + return self.stn(x, t) if __name__ == "__main__": diff --git a/model/transcrowd_gap.py b/model/transcrowd_gap.py index ebd807f..a733b79 100644 --- a/model/transcrowd_gap.py +++ b/model/transcrowd_gap.py @@ -22,6 +22,7 @@ from timm.models.registry import register_model from timm.models.layers import trunc_normal_ from .stn import STNet +from .glue import SquareCropTransformLayer class VisionTransformerGAP(VisionTransformer): # [XXX] It might be a bad idea to use vision transformer for small datasets. 
@@ -102,10 +103,12 @@ class STNet_VisionTransformerGAP(VisionTransformerGAP): def __init__(self, img_shape: torch.Size, *args, **kwargs): super(STNet_VisionTransformerGAP, self).__init__(*args, **kwargs) self.stnet = STNet(img_shape) + self.glue = SquareCropTransformLayer(img_shape) - def forward(self, x): - x = self.stnet(x) - return super(STNet_VisionTransformerGAP, self).forward(x) + def forward(self, x, t): + x, t = self.stnet(x, t) + x, t = self.glue(x, t) + return super(STNet_VisionTransformerGAP, self).forward(x), t @register_model