r"""Transformer-encoder-regressor, adapted for reverse-perspective network. This model is identical to the *TransCrowd* [#]_ model, which we use for the subproblem of actually counting the heads in each *transformed* raw image. .. [#] Liang, D., Chen, X., Xu, W., Zhou, Y., & Bai, X. (2022). Transcrowd: weakly-supervised crowd counting with transformers. Science China Information Sciences, 65(6), 160104. """ from functools import partial import numpy as np import torch import torch.nn as nn import torch.nn.functional as F # The original paper uses timm to create and import/export custom models, # so we follow suit from timm.models.vision_transformer import VisionTransformer, _cfg from timm.models.registry import register_model from timm.models.layers import trunc_normal_ class VisionTransformer_GAP(VisionTransformer): # [XXX] It might be a bad idea to use vision transformer for small datasets. # ref: ViT paper -- "transformers lack some of the inductive biases inherent # to CNNs, such as translation equivariance and locality". # convolution is specifically equivariant in translation (linear and # shift-equivariant), specifically. # tl;dr: CNNs might perform better for small datasets. Not sure abt performance. def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) num_patches = self.patch_embed.num_patches # That {p_1, p_2, ..., p_N} pos embedding self.pos_embed = nn.Parameter( torch.zeros(1, num_patches + 1, self.embed_dim) ) # Fill self.pos_embed with N(0, 1) truncated to |0.2 * std|. trunc_normal_(self.pos_embed, std=.02) # The "regression head" (I think? [XXX]) self.output1 = nn.ModuleDict({ "output1.relu0": nn.ReLU(), "output1.linear0": nn.Linear(in_features=6912 * 4, out_features=128), "output1.relu1": nn.ReLU(), "output1.dropout0": nn.Dropout(p=0.5), "output1.linear1": nn.Linear(in_features=128, out_features=1), }) self.output1.apply(self._init_weights) def forward_features(self, x): B = x.shape[0] x = self.patch_embed(x) # [XXX] Why do we need class token here? (ref. prev papers) cls_tokens = self.cls_token.expand(B, -1, -1) x = torch.cat((cls_tokens, x), dim=1) # Concatenate along j # 3.2 Patch embedding x = x + self.pos_embed x = self.pos_drop(x) # [XXX] Drop some patches out -- or not? # 3.3 Transformer-encoder for block in self.blocks: x = block(x) # [TODO] Interpret x = self.norm(x) x = x[:, 1:] return x def forward(self, x): x = self.forward_features(x) x = F.adaptive_avg_pool1d(x, (48)) x = x.view(x.shape[0], -1) x = self.output1(x) return x