r"""Transformer-encoder-regressor, adapted for reverse-perspective network. This model is identical to the *TransCrowd* [#]_ model, which we use for the subproblem of actually counting the heads in each *transformed* raw image. .. [#] Liang, D., Chen, X., Xu, W., Zhou, Y., & Bai, X. (2022). Transcrowd: weakly-supervised crowd counting with transformers. Science China Information Sciences, 65(6), 160104. """ from functools import partial import numpy as np import torch import torch.nn as nn import torch.nn.functional as F # The original paper uses timm to create and import/export custom models, # so we follow suit from timm.models.vision_transformer import VisionTransformer, _cfg from timm.models.registry import register_model from timm.models.layers import trunc_normal_ from .stn import STNet class VisionTransformerGAP(VisionTransformer): # [XXX] It might be a bad idea to use vision transformer for small datasets. # ref: ViT paper -- "transformers lack some of the inductive biases inherent # to CNNs, such as translation equivariance and locality". # convolution is specifically equivariant in translation (linear and # shift-equivariant), specifically. # tl;dr: CNNs might perform better for small datasets AND should perform # better for embedded systems. def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) num_patches = self.patch_embed.num_patches # That {p_1, p_2, ..., p_N} pos embedding self.pos_embed = nn.Parameter( torch.zeros(1, num_patches + 1, self.embed_dim) ) # Fill self.pos_embed with N(0, 1) truncated to |0.2 * std|. trunc_normal_(self.pos_embed, std=.02) # The "regression head" self.output1 = nn.ModuleDict({ "output1.relu0": nn.ReLU(), "output1.linear0": nn.Linear(in_features=6912 * 4, out_features=128), "output1.relu1": nn.ReLU(), "output1.dropout0": nn.Dropout(p=0.5), "output1.linear1": nn.Linear(in_features=128, out_features=1), }) self.output1.apply(self._init_weights) # Attention map, which we use to train self.attention_map = torch.Tensor(np.zeros((1152, 768))) # (3, 2) resized imgs def forward_features(self, x): B = x.shape[0] # 3.2 Patch embed x = self.patch_embed(x) # ViT: Classification token # This idea originated from BERT. # Essentially, because we are performing encoding without decoding, we # cannot fix the output dimensionality -- which the classification # problem absolutely needs. Instead, we use the classification token as # the sole input which the transformer would need to learn to encode # whatever it learnt from input into that token. # Source: https://datascience.stackexchange.com/a/110637 # That said, I don't think this is useful in this case... cls_tokens = self.cls_token.expand(B, -1, -1) x = torch.cat((cls_tokens, x), dim=1) # [[cls_token, x_i, ...]...] x = x + self.pos_embed x = self.pos_drop(x) # [XXX] Drop some patches out -- or not? # 3.3 Transformer-encoder for block in self.blocks: x = block(x) # Normalize x = self.norm(x) # Remove the classification token x = x[:, 1:] return x def forward(self, x): x = self.forward_features(x) # Compute encoding x = F.adaptive_avg_pool1d(x, (48)) x = x.view(x.shape[0], -1) # Move data for regression head # Resized to ??? x = self.output1(x) # Regression head return x class STNet_VisionTransformerGAP(VisionTransformerGAP): def __init__(self, img_shape: torch.Size, *args, **kwargs): super(STNet_VisionTransformerGAP, self).__init__(*args, **kwargs) self.stnet = STNet(img_shape) def forward(self, x): x = self.stnet(x) return super(STNet_VisionTransformerGAP, self).forward(x)