Extra changes, also not sure if works

This commit is contained in:
Zhengyi Chen 2024-03-02 23:32:19 +00:00
parent f059924b75
commit c88b938680
2 changed files with 16 additions and 7 deletions

View file

@ -22,6 +22,7 @@ from timm.models.registry import register_model
from timm.models.layers import trunc_normal_
from .stn import STNet
from .glue import SquareCropTransformLayer
class VisionTransformerGAP(VisionTransformer):
# [XXX] It might be a bad idea to use vision transformer for small datasets.
@ -102,10 +103,12 @@ class STNet_VisionTransformerGAP(VisionTransformerGAP):
def __init__(self, img_shape: torch.Size, *args, **kwargs):
super(STNet_VisionTransformerGAP, self).__init__(*args, **kwargs)
self.stnet = STNet(img_shape)
self.glue = SquareCropTransformLayer(img_size)
def forward(self, x):
x = self.stnet(x)
return super(STNet_VisionTransformerGAP, self).forward(x)
def forward(self, x, t):
x, t = self.stnet(x, t)
x, t = self.glue(x, t)
return super(STNet_VisionTransformerGAP, self).forward(x), t
@register_model