r"""Transformer-encoder-regressor, adapted for reverse-perspective network.
|
|
|
|
This model is identical to the *TransCrowd* [#]_ model, which we use for the
|
|
subproblem of actually counting the heads in each *transformed* raw image.
|
|
|
|
.. [#] Liang, D., Chen, X., Xu, W., Zhou, Y., & Bai, X. (2022).
|
|
Transcrowd: weakly-supervised crowd counting with transformers.
|
|
Science China Information Sciences, 65(6), 160104.
|
|
"""

from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# The original paper uses timm to create and import/export custom models,
# so we follow suit.
from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_


class VisionTransformerGAPwithFeatureMap(VisionTransformer):
    # [XXX] It might be a bad idea to use a vision transformer for small
    # datasets.
    # ref: ViT paper -- "transformers lack some of the inductive biases
    # inherent to CNNs, such as translation equivariance and locality";
    # convolution, in particular, is linear and shift-equivariant.
    # tl;dr: CNNs might perform better for small datasets AND should perform
    # better for embedded systems.

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        num_patches = self.patch_embed.num_patches

        # That {p_1, p_2, ..., p_N} pos embedding
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, self.embed_dim)
        )
        # Fill self.pos_embed with a truncated normal, std=0.02 (timm's
        # trunc_normal_ truncates to its default bounds, [-2, 2]).
        trunc_normal_(self.pos_embed, std=.02)

        # The "regression head", applied in forward() to the flattened,
        # pooled patch tokens.
        self.output1 = nn.Sequential(
            nn.ReLU(),
            nn.Linear(in_features=6912 * 4, out_features=128),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=128, out_features=1),
        )
        self.output1.apply(self._init_weights)

        # Attention map, which we use to train
        self.attention_map = torch.Tensor(np.zeros((1152, 768)))  # (3, 2) resized imgs

    def forward_features(self, x):
        B = x.shape[0]

        # 3.2 Patch embed
        x = self.patch_embed(x)

        # ViT: classification token.
        # This idea originated from BERT.  Essentially, because we encode
        # without decoding, the output dimensionality is not fixed -- which a
        # classification problem absolutely needs.  Instead, a classification
        # token gives a single fixed slot: the transformer has to learn to
        # encode whatever it learnt from the input into that token.
        # Source: https://datascience.stackexchange.com/a/110637
        # That said, I don't think this is useful in this case...
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)  # [[cls_token, x_i, ...] ...]

        x = x + self.pos_embed
        x = self.pos_drop(x)  # [XXX] Dropout on the embeddings -- or not?

        # 3.3 Transformer-encoder
        for block in self.blocks:
            x = block(x)

        # Normalize
        x = self.norm(x)

        # Remove the classification token
        x = x[:, 1:]

        return x

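    # Note on dimensions: adaptive_avg_pool1d below squeezes each patch
    # token's embed_dim features down to 48, and the flattened result feeds
    # nn.Linear(in_features=6912 * 4, ...).  The sizes only line up when the
    # encoder emits 6912 * 4 / 48 = 576 patch tokens (e.g. a 384x384 input
    # with 16x16 patches).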
    def forward(self, x):
        x = self.forward_features(x)  # Compute encoding: (B, num_patches, embed_dim)
        x = F.adaptive_avg_pool1d(x, 48)  # Pool each token down to 48 features
        x = x.view(x.shape[0], -1)  # Flatten to (B, num_patches * 48) for the regression head
        x = self.output1(x)  # Regression head -> (B, 1) predicted count
        return x
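

# ---------------------------------------------------------------------------
# Illustrative factory and smoke test (a sketch, not part of the original
# model definition).  The hyperparameters are assumptions borrowed from the
# ViT-Base/16 @ 384x384 configuration that TransCrowd builds on; they are
# chosen so the encoder emits 576 patch tokens, matching the regression head
# above (576 * 48 == 6912 * 4).  The factory name is hypothetical.
# ---------------------------------------------------------------------------
@register_model
def vit_gap_fm_base_patch16_384(pretrained=False, **kwargs):
    # NOTE: `pretrained` is accepted for timm-factory compatibility but is
    # ignored in this sketch (no checkpoint is loaded here).
    model = VisionTransformerGAPwithFeatureMap(
        img_size=384,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    model.default_cfg = _cfg()
    return model


if __name__ == "__main__":
    # Smoke test: one 3-channel 384x384 image in, one scalar head count out.
    m = vit_gap_fm_base_patch16_384()
    print(m(torch.randn(1, 3, 384, 384)).shape)  # torch.Size([1, 1])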