Added something more

This commit is contained in:
Zhengyi Chen 2024-02-27 21:27:02 +00:00
parent 49a913a328
commit 99266d9c92
2 changed files with 77 additions and 30 deletions

View file

@ -7,6 +7,7 @@ subproblem of actually counting the heads in each *transformed* raw image.
Transcrowd: weakly-supervised crowd counting with transformers.
Science China Information Sciences, 65(6), 160104.
"""
from typing import Optional
from functools import partial
import numpy as np
@ -69,7 +70,7 @@ class VisionTransformerGAP(VisionTransformer):
# the sole input which the transformer would need to learn to encode
# whatever it learnt from input into that token.
# Source: https://datascience.stackexchange.com/a/110637
# That said, I don't think this is useful in this case...
# That said, I don't think this is useful for GAP...
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1) # [[cls_token, x_i, ...]...]
@ -104,4 +105,39 @@ class STNet_VisionTransformerGAP(VisionTransformerGAP):
def forward(self, x):
    """Warp the input with the spatial transformer, then run the ViT-GAP forward pass.

    Args:
        x: input image batch (presumably shaped (B, 3, 384, 384) to match
           ``img_shape`` used at construction — TODO confirm against caller).

    Returns:
        The output of ``VisionTransformerGAP.forward`` on the transformed batch.
    """
    x = self.stnet(x)
    # NOTE: the original had this return statement duplicated (a dead,
    # unreachable second copy, likely a merge artifact); only one is kept.
    return super(STNet_VisionTransformerGAP, self).forward(x)
@register_model
def base_patch16_384_gap(pth_tar: Optional[str] = None, **kwargs):
    """Build a ViT-Base/16 (384px) model with GAP head, optionally loading weights.

    Args:
        pth_tar: optional path to a ``.pth.tar`` checkpoint whose ``"model"``
            entry holds a state dict; loaded non-strictly so head-shape
            mismatches are tolerated.
        **kwargs: forwarded to ``VisionTransformerGAP``.

    Returns:
        The constructed (and possibly pre-loaded) ``VisionTransformerGAP``.
    """
    model = VisionTransformerGAP(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
        mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    model.default_cfg = _cfg()
    if pth_tar is not None:
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts;
        # the model can be moved to a device afterwards.
        checkpoint = torch.load(pth_tar, map_location="cpu")
        # NOTE(review): strict=False silently drops mismatched/missing keys —
        # deliberate here to allow partial pre-trained weights.
        model.load_state_dict(checkpoint["model"], strict=False)
        print(f"Loaded pre-trained pth.tar from '{pth_tar}'")
    return model
@register_model
def stn_patch16_384_gap(pth_tar: Optional[str] = None, **kwargs):
    """Build an STN + ViT-Base/16 (384px) GAP model, optionally loading weights.

    Args:
        pth_tar: optional path to a ``.pth.tar`` checkpoint whose ``"model"``
            entry holds a state dict; loaded non-strictly so head-shape
            mismatches are tolerated.
        **kwargs: forwarded to ``STNet_VisionTransformerGAP``.

    Returns:
        The constructed (and possibly pre-loaded) ``STNet_VisionTransformerGAP``.
    """
    model = STNet_VisionTransformerGAP(
        img_shape=torch.Size((3, 384, 384)),
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
        mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    model.default_cfg = _cfg()
    if pth_tar is not None:
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts;
        # the model can be moved to a device afterwards.
        checkpoint = torch.load(pth_tar, map_location="cpu")
        # NOTE(review): strict=False silently drops mismatched/missing keys —
        # deliberate here to allow partial pre-trained weights.
        model.load_state_dict(checkpoint["model"], strict=False)
        print(f"Loaded pre-trained pth.tar from '{pth_tar}'")
    return model