Added STNet-prepended TransCrowd GAP

This commit is contained in:
Zhengyi Chen 2024-02-27 18:57:13 +00:00
parent c2a1bb46ef
commit 49a913a328
2 changed files with 16 additions and 2 deletions

View file

@ -20,7 +20,9 @@ from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_
class VisionTransformerGAPwithFeatureMap(VisionTransformer):
from .stn import STNet
class VisionTransformerGAP(VisionTransformer):
# [XXX] It might be a bad idea to use vision transformer for small datasets.
# ref: ViT paper -- "transformers lack some of the inductive biases inherent
# to CNNs, such as translation equivariance and locality".
@ -93,3 +95,13 @@ class VisionTransformerGAPwithFeatureMap(VisionTransformer):
# Resized to ???
x = self.output1(x) # Regression head
return x
class STNet_VisionTransformerGAP(VisionTransformerGAP):
def __init__(self, img_shape: torch.Size, *args, **kwargs):
super(STNet_VisionTransformerGAP, self).__init__(*args, **kwargs)
self.stnet = STNet(img_shape)
def forward(self, x):
x = self.stnet(x)
return super(STNet_VisionTransformerGAP, self).forward(x)