r"""Transformer-encoder-regressor, adapted for reverse-perspective network. This model is identical to the *TransCrowd* [#]_ model, which we use for the subproblem of actually counting the heads in each *transformed* raw image. .. [#] Liang, D., Chen, X., Xu, W., Zhou, Y., & Bai, X. (2022). Transcrowd: weakly-supervised crowd counting with transformers. Science China Information Sciences, 65(6), 160104. """ from typing import Optional from functools import partial import numpy as np import torch import torch.nn as nn import torch.nn.functional as F # The original paper uses timm to create and import/export custom models, # so we follow suit from timm.models.vision_transformer import VisionTransformer, _cfg from timm.models.registry import register_model from timm.models.layers import trunc_normal_ from .stn import STNet from .glue import SquareCropTransformLayer class VisionTransformerGAP(VisionTransformer): # [XXX] It might be a bad idea to use vision transformer for small datasets. # ref: ViT paper -- "transformers lack some of the inductive biases inherent # to CNNs, such as translation equivariance and locality". # convolution is specifically equivariant in translation (linear and # shift-equivariant), specifically. # tl;dr: CNNs might perform better for small datasets AND should perform # better for embedded systems. def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) num_patches = self.patch_embed.num_patches # That {p_1, p_2, ..., p_N} pos embedding self.pos_embed = nn.Parameter( torch.zeros(1, num_patches + 1, self.embed_dim) ) # Fill self.pos_embed with N(0, 1) truncated to |0.2 * std|. trunc_normal_(self.pos_embed, std=.02) # The "regression head" self.output1 = nn.ModuleDict({ "output1.relu0": nn.ReLU(), "output1.linear0": nn.Linear(in_features=6912 * 4, out_features=128), "output1.relu1": nn.ReLU(), "output1.dropout0": nn.Dropout(p=0.5), "output1.linear1": nn.Linear(in_features=128, out_features=1), }) self.output1.apply(self._init_weights) # Attention map, which we use to train self.attention_map = torch.Tensor(np.zeros((1152, 768))) # (3, 2) resized imgs def forward_features(self, x): B = x.shape[0] # 3.2 Patch embed x = self.patch_embed(x) # ViT: Classification token # This idea originated from BERT. # Essentially, because we are performing encoding without decoding, we # cannot fix the output dimensionality -- which the classification # problem absolutely needs. Instead, we use the classification token as # the sole input which the transformer would need to learn to encode # whatever it learnt from input into that token. # Source: https://datascience.stackexchange.com/a/110637 # That said, I don't think this is useful for GAP... cls_tokens = self.cls_token.expand(B, -1, -1) x = torch.cat((cls_tokens, x), dim=1) # [[cls_token, x_i, ...]...] x = x + self.pos_embed x = self.pos_drop(x) # [XXX] Drop some patches out -- or not? # 3.3 Transformer-encoder for block in self.blocks: x = block(x) # Normalize x = self.norm(x) # Remove the classification token x = x[:, 1:] return x def forward(self, x): x = self.forward_features(x) # Compute encoding x = F.adaptive_avg_pool1d(x, (48)) x = x.view(x.shape[0], -1) # Move data for regression head # Resized to ??? 
        x = self.output1(x)               # Regression head -> (B, 1)
        return x


class STNet_VisionTransformerGAP(VisionTransformerGAP):
    def __init__(self, img_shape: torch.Size, *args, **kwargs):
        super(STNet_VisionTransformerGAP, self).__init__(*args, **kwargs)
        self.stnet = STNet(img_shape)
        # Side length of the square ViT input (the img_size kwarg, e.g. 384);
        # fall back to the last dimension of img_shape, which is (C, H, W).
        self.glue = SquareCropTransformLayer(
            kwargs.get("img_size", img_shape[-1])
        )

    def forward(self, x, t):
        x, t = self.stnet(x, t)
        x, t = self.glue(x, t)
        return super(STNet_VisionTransformerGAP, self).forward(x), t


@register_model
def base_patch16_384_gap(pth_tar: Optional[str] = None, **kwargs):
    model = VisionTransformerGAP(
        img_size=384,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    model.default_cfg = _cfg()
    if pth_tar is not None:
        checkpoint = torch.load(pth_tar)
        model.load_state_dict(checkpoint["model"], strict=False)
        print("Loaded pre-trained pth.tar from '{}'".format(pth_tar))
    return model


@register_model
def stn_patch16_384_gap(pth_tar: Optional[str] = None, **kwargs):
    model = STNet_VisionTransformerGAP(
        img_shape=torch.Size((3, 384, 384)),
        img_size=384,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    model.default_cfg = _cfg()
    if pth_tar is not None:
        checkpoint = torch.load(pth_tar)
        model.load_state_dict(checkpoint["model"], strict=False)
        print("Loaded pre-trained pth.tar from '{}'".format(pth_tar))
    return model
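

# A minimal smoke-test sketch for the plain GAP model: build it without
# pre-trained weights and check that a dummy 384x384 batch regresses to one
# count per image.  Because of the relative .stn/.glue imports above, this
# file must be run as a module, e.g. ``python -m <your_package>.<this_module>``
# (the package/module names here are placeholders).  The STN variant is not
# exercised, since it additionally needs the second input ``t`` consumed by
# STNet and SquareCropTransformLayer.
if __name__ == "__main__":
    model = base_patch16_384_gap()       # no checkpoint, random init
    model.eval()
    dummy = torch.randn(2, 3, 384, 384)  # (B, C, H, W)
    with torch.no_grad():
        counts = model(dummy)
    print(counts.shape)                  # expected: torch.Size([2, 1])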