More working than not
Not sure if validation works, call it a day
This commit is contained in:
parent
4a03211c83
commit
12aabb0d3f
10 changed files with 116 additions and 105 deletions
|
|
@@ -25,16 +25,8 @@ from .stn import STNet
|
|||
from .glue import SquareCropTransformLayer
|
||||
|
||||
class VisionTransformerGAP(VisionTransformer):
|
||||
# [XXX] It might be a bad idea to use vision transformer for small datasets.
|
||||
# ref: ViT paper -- "transformers lack some of the inductive biases inherent
|
||||
# to CNNs, such as translation equivariance and locality".
|
||||
# convolution is specifically equivariant in translation (linear and
|
||||
# shift-equivariant).
|
||||
# tl;dr: CNNs might perform better for small datasets AND should perform
|
||||
# better for embedded systems.
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
def __init__(self, img_size: int, *args, **kwargs):
|
||||
super().__init__(img_size=img_size, *args, **kwargs)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
|
||||
# That {p_1, p_2, ..., p_N} pos embedding
|
||||
|
|
@@ -45,17 +37,17 @@ class VisionTransformerGAP(VisionTransformer):
|
|||
trunc_normal_(self.pos_embed, std=.02)
|
||||
|
||||
# The "regression head"
|
||||
self.output1 = nn.ModuleDict({
|
||||
"output1.relu0": nn.ReLU(),
|
||||
"output1.linear0": nn.Linear(in_features=6912 * 4, out_features=128),
|
||||
"output1.relu1": nn.ReLU(),
|
||||
"output1.dropout0": nn.Dropout(p=0.5),
|
||||
"output1.linear1": nn.Linear(in_features=128, out_features=1),
|
||||
})
|
||||
self.output1 = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Linear(in_features=6912 * 4, out_features=128),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(p=0.5),
|
||||
nn.Linear(in_features=128, out_features=1),
|
||||
)
|
||||
self.output1.apply(self._init_weights)
|
||||
|
||||
# Attention map, which we use to train
|
||||
self.attention_map = torch.Tensor(np.zeros((1152, 768))) # (3, 2) resized imgs
|
||||
# glue layer -- since we delay image cropping here
|
||||
self.glue = SquareCropTransformLayer(img_size)
|
||||
|
||||
def forward_features(self, x):
|
||||
B = x.shape[0]
|
||||
|
|
@@ -90,25 +82,26 @@ class VisionTransformerGAP(VisionTransformer):
|
|||
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, t):
|
||||
with torch.no_grad():
|
||||
x, t = self.glue(x, t)
|
||||
print(f"Glue: {x.shape} | {t.shape}")
|
||||
x = self.forward_features(x) # Compute encoding
|
||||
x = F.adaptive_avg_pool1d(x, (48))
|
||||
x = x.view(x.shape[0], -1) # Move data for regression head
|
||||
# x must now be (B, 6912 * 4) to match the first Linear layer of output1
|
||||
x = self.output1(x) # Regression head
|
||||
return x
|
||||
return x, t
|
||||
|
||||
|
||||
class STNet_VisionTransformerGAP(VisionTransformerGAP):
|
||||
def __init__(self, img_shape: torch.Size, *args, **kwargs):
|
||||
super(STNet_VisionTransformerGAP, self).__init__(*args, **kwargs)
|
||||
def __init__(self, img_shape: torch.Size, img_size: int, *args, **kwargs):
|
||||
super(STNet_VisionTransformerGAP, self).__init__(img_size, *args, **kwargs)
|
||||
self.stnet = STNet(img_shape)
|
||||
self.glue = SquareCropTransformLayer(img_size)
|
||||
|
||||
def forward(self, x, t):
|
||||
x, t = self.stnet(x, t)
|
||||
x, t = self.glue(x, t)
|
||||
return super(STNet_VisionTransformerGAP, self).forward(x), t
|
||||
return super(STNet_VisionTransformerGAP, self).forward(x, t)
|
||||
|
||||
|
||||
@register_model
|
||||
|
|
@@ -131,7 +124,7 @@ def base_patch16_384_gap(pth_tar: Optional[str] = None, **kwargs):
|
|||
@register_model
|
||||
def stn_patch16_384_gap(pth_tar: Optional[str] = None, **kwargs):
|
||||
model = STNet_VisionTransformerGAP(
|
||||
img_shape=torch.Size((3, 384, 384)),
|
||||
img_shape=torch.Size((3, 1152, 768)),
|
||||
img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
|
||||
mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||
**kwargs
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue