Extra changes, also not sure if works
This commit is contained in:
parent
f059924b75
commit
c88b938680
2 changed files with 16 additions and 7 deletions
|
|
@ -22,6 +22,7 @@ from timm.models.registry import register_model
|
|||
from timm.models.layers import trunc_normal_
|
||||
|
||||
from .stn import STNet
|
||||
from .glue import SquareCropTransformLayer
|
||||
|
||||
class VisionTransformerGAP(VisionTransformer):
|
||||
# [XXX] It might be a bad idea to use vision transformer for small datasets.
|
||||
|
|
@ -102,10 +103,12 @@ class STNet_VisionTransformerGAP(VisionTransformerGAP):
|
|||
def __init__(self, img_shape: torch.Size, *args, **kwargs):
|
||||
super(STNet_VisionTransformerGAP, self).__init__(*args, **kwargs)
|
||||
self.stnet = STNet(img_shape)
|
||||
self.glue = SquareCropTransformLayer(img_size)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.stnet(x)
|
||||
return super(STNet_VisionTransformerGAP, self).forward(x)
|
||||
def forward(self, x, t):
|
||||
x, t = self.stnet(x, t)
|
||||
x, t = self.glue(x, t)
|
||||
return super(STNet_VisionTransformerGAP, self).forward(x), t
|
||||
|
||||
|
||||
@register_model
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue