More working than not

Not sure if validation works, call it a day
This commit is contained in:
Zhengyi Chen 2024-03-03 03:16:54 +00:00
parent 4a03211c83
commit 12aabb0d3f
10 changed files with 116 additions and 105 deletions

View file

@ -46,7 +46,7 @@ class SquareCropTransformLayer(nn.Module):
torch.tensor_split(
torch.cat(
torch.tensor_split(
t_,
kpoints_,
h_split_count,
dim=1
)

View file

@ -24,6 +24,7 @@ class STNet(nn.Module):
_dummy_size_ = input_size
# shape checking
print("STN: dummy_size {}".format(_dummy_size_))
_dummy_x_ = torch.zeros(_dummy_size_)
# (3.1) Spatial transformer localization-network
@ -81,6 +82,7 @@ class STNet(nn.Module):
def forward(self, x, t):
# print("STN: {} | {}".format(x.shape, t.shape))
# transform the input, do nothing else
return self.stn(x, t)

View file

@ -25,16 +25,8 @@ from .stn import STNet
from .glue import SquareCropTransformLayer
class VisionTransformerGAP(VisionTransformer):
# [XXX] It might be a bad idea to use vision transformer for small datasets.
# ref: ViT paper -- "transformers lack some of the inductive biases inherent
# to CNNs, such as translation equivariance and locality".
# convolution is specifically equivariant in translation (linear and
# shift-equivariant), specifically.
# tl;dr: CNNs might perform better for small datasets AND should perform
# better for embedded systems.
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, img_size: int, *args, **kwargs):
super().__init__(img_size=img_size, *args, **kwargs)
num_patches = self.patch_embed.num_patches
# That {p_1, p_2, ..., p_N} pos embedding
@ -45,17 +37,17 @@ class VisionTransformerGAP(VisionTransformer):
trunc_normal_(self.pos_embed, std=.02)
# The "regression head"
self.output1 = nn.ModuleDict({
"output1.relu0": nn.ReLU(),
"output1.linear0": nn.Linear(in_features=6912 * 4, out_features=128),
"output1.relu1": nn.ReLU(),
"output1.dropout0": nn.Dropout(p=0.5),
"output1.linear1": nn.Linear(in_features=128, out_features=1),
})
self.output1 = nn.Sequential(
nn.ReLU(),
nn.Linear(in_features=6912 * 4, out_features=128),
nn.ReLU(),
nn.Dropout(p=0.5),
nn.Linear(in_features=128, out_features=1),
)
self.output1.apply(self._init_weights)
# Attention map, which we use to train
self.attention_map = torch.Tensor(np.zeros((1152, 768))) # (3, 2) resized imgs
# glue layer -- since we delay image cropping here
self.glue = SquareCropTransformLayer(img_size)
def forward_features(self, x):
B = x.shape[0]
@ -90,25 +82,26 @@ class VisionTransformerGAP(VisionTransformer):
return x
def forward(self, x):
def forward(self, x, t):
with torch.no_grad():
x, t = self.glue(x, t)
print(f"Glue: {x.shape} | {t.shape}")
x = self.forward_features(x) # Compute encoding
x = F.adaptive_avg_pool1d(x, (48))
x = x.view(x.shape[0], -1) # Move data for regression head
# Resized to ???
x = self.output1(x) # Regression head
return x
return x, t
class STNet_VisionTransformerGAP(VisionTransformerGAP):
def __init__(self, img_shape: torch.Size, *args, **kwargs):
super(STNet_VisionTransformerGAP, self).__init__(*args, **kwargs)
def __init__(self, img_shape: torch.Size, img_size: int, *args, **kwargs):
super(STNet_VisionTransformerGAP, self).__init__(img_size, *args, **kwargs)
self.stnet = STNet(img_shape)
self.glue = SquareCropTransformLayer(img_size)
def forward(self, x, t):
x, t = self.stnet(x, t)
x, t = self.glue(x, t)
return super(STNet_VisionTransformerGAP, self).forward(x), t
return super(STNet_VisionTransformerGAP, self).forward(x, t)
@register_model
@ -131,7 +124,7 @@ def base_patch16_384_gap(pth_tar: Optional[str] = None, **kwargs):
@register_model
def stn_patch16_384_gap(pth_tar: Optional[str] = None, **kwargs):
model = STNet_VisionTransformerGAP(
img_shape=torch.Size((3, 384, 384)),
img_shape=torch.Size((3, 1152, 768)),
img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
**kwargs