diff --git a/model/reverse_perspective.py b/model/reverse_perspective.py
index 4a95b2e..d807c89 100644
--- a/model/reverse_perspective.py
+++ b/model/reverse_perspective.py
@@ -37,9 +37,8 @@ class PerspectiveEstimator(nn.Module):
     After all, it is reasonable to say that you see more when you look at
     faraway places.
 
-    This do imply that **we need to obtain a reasonably good feature extractor
-    from general images before training this submodule**. Hence, for now, we
-    prob. should work on transformer first.
+    The paper utilizes an unsupervised loss -- "row feature density" refers to
+    the density of features computed from ?
 
     :param input_shape: (N, C, H, W)
     :param conv_kernel_shape: Oriented as (H, W)
@@ -63,6 +62,8 @@ class PerspectiveEstimator(nn.Module):
         (_, _, height, width) = input_shape
 
         # Sanity checking
+        # [TODO] Maybe this is unnecessary; we could automatically suggest new
+        # params instead, but right now let's just do this...
        (_conv_height, _conv_width) = (
            np.floor(
                (height + 2 * conv_padding - conv_dilation * (conv_kernel_shape[0] - 1) - 1)
@@ -112,4 +113,5 @@ class PerspectiveEstimator(nn.Module):
 
         return out
 
+
+    # def unsupervised_loss(predictions, targets):
diff --git a/model/transcrowd_gap.py b/model/transcrowd_gap.py
new file mode 100644
index 0000000..48fc520
--- /dev/null
+++ b/model/transcrowd_gap.py
@@ -0,0 +1,76 @@
+r"""Transformer-encoder-regressor, adapted for the reverse-perspective network.
+
+This model is identical to the *TransCrowd* [#]_ model, which we use for the
+subproblem of actually counting the heads in each *transformed* raw image.
+
+.. [#] Liang, D., Chen, X., Xu, W., Zhou, Y., & Bai, X. (2022).
+   TransCrowd: weakly-supervised crowd counting with transformers.
+   Science China Information Sciences, 65(6), 160104.
+"""
+from collections import OrderedDict
+from functools import partial
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# The original paper uses timm to create and import/export custom models,
+# so we follow suit.
+from timm.models.vision_transformer import VisionTransformer, _cfg
+from timm.models.registry import register_model
+from timm.models.layers import trunc_normal_
+
+
+class VisionTransformer_GAP(VisionTransformer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        num_patches = self.patch_embed.num_patches
+
+        # The {p_1, p_2, ..., p_N} position embedding; the extra slot is for
+        # the class token prepended in forward_features.
+        self.pos_embed = nn.Parameter(
+            torch.zeros(1, num_patches + 1, self.embed_dim)
+        )
+        # Fill self.pos_embed from N(0, 0.02^2), truncated to [-2, 2].
+        trunc_normal_(self.pos_embed, std=.02)
+
+        # The regression head: maps the pooled patch features to a single count.
+        self.output1 = nn.Sequential(OrderedDict([
+            ("relu0", nn.ReLU()),
+            ("linear0", nn.Linear(in_features=6912 * 4, out_features=128)),
+            ("relu1", nn.ReLU()),
+            ("dropout0", nn.Dropout(p=0.5)),
+            ("linear1", nn.Linear(in_features=128, out_features=1)),
+        ]))
+        self.output1.apply(self._init_weights)
+
+    def forward_features(self, x):
+        B = x.shape[0]
+        x = self.patch_embed(x)
+
+        # [XXX] Why do we need the class token here? (ref. prev. papers)
+        cls_tokens = self.cls_token.expand(B, -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)  # Concatenate along the token dimension
+
+        # 3.2 Patch embedding
+        x = x + self.pos_embed
+        x = self.pos_drop(x)  # Elementwise dropout on the embeddings; it does not drop whole patches
+
+        # 3.3 Transformer-encoder
+        for block in self.blocks:
+            x = block(x)
+
+        # Final LayerNorm over the encoder output
+        x = self.norm(x)
+
+        x = x[:, 1:]  # Drop the class token; keep only the patch tokens
+
+        return x
+
+    def forward(self, x):
+        x = self.forward_features(x)  # (B, num_patches, embed_dim)
+        x = F.adaptive_avg_pool1d(x, 48)  # Pool each token's channels down to 48
+        x = x.view(x.shape[0], -1)  # Flatten to (B, num_patches * 48)
+        x = self.output1(x)  # Regression head -> (B, 1)
+        return x
+
+
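For reference, the following is a minimal smoke-test sketch (not part of the diff) of how VisionTransformer_GAP might be instantiated and run. The ViT-Base-style hyperparameters and the model.transcrowd_gap import path are assumptions; they are chosen so the shapes line up with output1: a 384x384 input split into 16x16 patches gives 576 patch tokens, and pooling each 768-dim token down to 48 channels yields 576 * 48 = 27648 = 6912 * 4 features.

from functools import partial

import torch
import torch.nn as nn

from model.transcrowd_gap import VisionTransformer_GAP  # assumed import path

# Assumed ViT-Base-like settings: 384 // 16 = 24, so 24 * 24 = 576 patch tokens.
model = VisionTransformer_GAP(
    img_size=384,
    patch_size=16,
    embed_dim=768,
    depth=12,
    num_heads=12,
    mlp_ratio=4,
    qkv_bias=True,
    norm_layer=partial(nn.LayerNorm, eps=1e-6),
)

dummy = torch.randn(2, 3, 384, 384)  # (N, C, H, W)
counts = model(dummy)                # shape (2, 1): one scalar crowd count per image
print(counts.shape)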