diff --git a/model/stn.py b/model/stn.py
index 9cadd05..552b70b 100644
--- a/model/stn.py
+++ b/model/stn.py
@@ -1,4 +1,6 @@
 # stole from pytorch tutorial
+# "Great artists steal" -- they say,
+# but thieves also steal so you know :P
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -36,7 +38,7 @@ class STNet(nn.Module):
 
         # (3.2) Regressor for the 3 * 2 affine matrix
         self.fc_loc = nn.Sequential(
-            # XXX: Should
+            # XXX: Dimensionality reduction across channels or not?
            nn.Linear(np.prod(self._loc_net_out_shape), 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
diff --git a/model/transcrowd_gap.py b/model/transcrowd_gap.py
index c9ae9c0..c6158e6 100644
--- a/model/transcrowd_gap.py
+++ b/model/transcrowd_gap.py
@@ -20,7 +20,9 @@ from timm.models.vision_transformer import VisionTransformer, _cfg
 from timm.models.registry import register_model
 from timm.models.layers import trunc_normal_
 
-class VisionTransformerGAPwithFeatureMap(VisionTransformer):
+from .stn import STNet
+
+class VisionTransformerGAP(VisionTransformer):
     # [XXX] It might be a bad idea to use vision transformer for small datasets.
     # ref: ViT paper -- "transformers lack some of the inductive biases inherent
     # to CNNs, such as translation equivariance and locality".
@@ -93,3 +95,13 @@ class VisionTransformerGAPwithFeatureMap(VisionTransformer):
         # Resized to ???
         x = self.output1(x) # Regression head
         return x
+
+
+class STNet_VisionTransformerGAP(VisionTransformerGAP):
+    def __init__(self, img_shape: torch.Size, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.stnet = STNet(img_shape)
+
+    def forward(self, x):
+        x = self.stnet(x)
+        return super().forward(x)
\ No newline at end of file
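
For context on the `fc_loc` head above: it regresses the six parameters of a 2x3 affine matrix, which `F.affine_grid` / `F.grid_sample` then use to warp the input, as in the PyTorch STN tutorial the header comment credits. Below is a minimal sketch of that pattern; only `fc_loc` and `_loc_net_out_shape` appear in the diff, so the localization net and all layer sizes here are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinySTN(nn.Module):
    """Illustrative sketch only; mirrors the fc_loc head in the diff."""
    def __init__(self):
        super().__init__()
        # Hypothetical localization net; STNet's actual layers are not shown.
        self.loc_net = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=7), nn.MaxPool2d(2), nn.ReLU(True),
        )
        self.fc_loc = nn.Sequential(
            nn.LazyLinear(32),   # stands in for np.prod(self._loc_net_out_shape)
            nn.ReLU(True),
            nn.Linear(32, 3 * 2),
        )
        # The tutorial initializes the regressor to the identity transform.
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(
            torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    def forward(self, x):
        # Regress a (N, 2, 3) affine matrix per sample, then warp the input.
        theta = self.fc_loc(self.loc_net(x).flatten(1)).view(-1, 2, 3)
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        return F.grid_sample(x, grid, align_corners=False)

warped = TinySTN()(torch.randn(2, 3, 64, 64))  # output shape: (2, 3, 64, 64)
```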
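And a hedged usage sketch of the new wrapper. The `VisionTransformer` kwargs below are standard timm arguments, but the concrete values, the expected `img_shape` layout (assumed CHW here), and the import path are all assumptions; none of them are pinned down by the diff.

```python
import torch
from model.transcrowd_gap import STNet_VisionTransformerGAP

model = STNet_VisionTransformerGAP(
    img_shape=torch.Size([3, 384, 384]),  # assumed CHW shape consumed by STNet
    img_size=384, patch_size=16,          # standard timm ViT kwargs, assumed values
    embed_dim=768, depth=12, num_heads=12,
)
x = torch.randn(1, 3, 384, 384)  # one RGB image
count = model(x)                 # STN warp -> ViT + GAP -> regression head
```

Design-wise, warping with the STN before patch embedding gives the model an explicit spatial transformation step up front, which speaks to the `[XXX]` concern above that plain ViTs lack the translation-related inductive biases of CNNs.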