Added transcrowd_gap
This commit is contained in:
parent
bcff06f9c2
commit
322d7f9ea5
2 changed files with 81 additions and 3 deletions
@@ -37,9 +37,8 @@ class PerspectiveEstimator(nn.Module):
     After all, it is reasonable to say that you see more when you look at
     faraway places.
 
-    This do imply that **we need to obtain a reasonably good feature extractor
-    from general images before training this submodule**. Hence, for now, we
-    prob. should work on transformer first.
+    The paper utilizes an unsupervised loss -- "row feature density" refers to
+    the density of features computed from ?
 
     :param input_shape: (N, C, H, W)
     :param conv_kernel_shape: Oriented as (H, W)
@@ -63,6 +62,8 @@ class PerspectiveEstimator(nn.Module):
         (_, _, height, width) = input_shape
 
         # Sanity checking
+        # [TODO] Maybe this is unnecessary, maybe we can automatically suggest new params,
+        # but right now let's just do this...
         (_conv_height, _conv_width) = (
             np.floor(
                 (height + 2 * conv_padding - conv_dilation * (conv_kernel_shape[0] - 1) - 1)
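The expression in this hunk is cut off at the hunk boundary; what is visible matches the standard Conv2d output-size formula, out = floor((size + 2*padding - dilation*(kernel - 1) - 1) / stride + 1). The following is a minimal sketch of such a sanity check, not code from the commit: the `conv_stride` parameter and the completing division are assumptions, since the diff does not show them.

# Sketch only -- not part of the commit. Completes the truncated expression above
# using the standard Conv2d output-size formula; `conv_stride` is assumed.
import numpy as np

def expected_conv_output(height, width, conv_kernel_shape,
                         conv_padding=0, conv_dilation=1, conv_stride=1):
    """Return the (H, W) that a Conv2d with these settings would produce."""
    out_h = np.floor(
        (height + 2 * conv_padding - conv_dilation * (conv_kernel_shape[0] - 1) - 1)
        / conv_stride + 1
    )
    out_w = np.floor(
        (width + 2 * conv_padding - conv_dilation * (conv_kernel_shape[1] - 1) - 1)
        / conv_stride + 1
    )
    return int(out_h), int(out_w)

# Example: a 224x224 input with a 3x3 kernel, padding 1, stride 1 keeps its size:
# expected_conv_output(224, 224, (3, 3), conv_padding=1) == (224, 224)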
@@ -112,4 +113,5 @@ class PerspectiveEstimator(nn.Module):
         return out
 
+# def unsupervised_loss(predictions, targets):
 
 
model/transcrowd_gap.py (new file, 76 additions)

@@ -0,0 +1,76 @@
r"""Transformer-encoder-regressor, adapted for reverse-perspective network.

This model is identical to the *TransCrowd* [#]_ model, which we use for the
subproblem of actually counting the heads in each *transformed* raw image.

.. [#] Liang, D., Chen, X., Xu, W., Zhou, Y., & Bai, X. (2022).
   TransCrowd: weakly-supervised crowd counting with transformers.
   Science China Information Sciences, 65(6), 160104.
"""
from collections import OrderedDict
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# The original paper uses timm to create and import/export custom models,
# so we follow suit
from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_


class VisionTransformer_GAP(VisionTransformer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        num_patches = self.patch_embed.num_patches

        # That {p_1, p_2, ..., p_N} pos embedding
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, self.embed_dim)
        )
        # Initialise self.pos_embed from a truncated normal with std = 0.02.
        trunc_normal_(self.pos_embed, std=.02)

        # The "regression head" (I think? [XXX])
        # Named nn.Sequential so the head can be called directly in forward().
        self.output1 = nn.Sequential(OrderedDict([
            ("relu0", nn.ReLU()),
            ("linear0", nn.Linear(in_features=6912 * 4, out_features=128)),
            ("relu1", nn.ReLU()),
            ("dropout0", nn.Dropout(p=0.5)),
            ("linear1", nn.Linear(in_features=128, out_features=1)),
        ]))
        self.output1.apply(self._init_weights)

    def forward_features(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)

        # [XXX] Why do we need class token here? (ref. prev papers)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)  # Concatenate along the token dimension

        # 3.2 Patch embedding
        x = x + self.pos_embed
        x = self.pos_drop(x)  # [XXX] Drop some patches out -- or not?

        # 3.3 Transformer-encoder
        for block in self.blocks:
            x = block(x)

        # [TODO] Interpret
        x = self.norm(x)

        # Drop the class token, keeping only the patch tokens
        x = x[:, 1:]

        return x

    def forward(self, x):
        x = self.forward_features(x)
        # Pool each patch token's embedding down to 48 features, then flatten
        x = F.adaptive_avg_pool1d(x, 48)
        x = x.view(x.shape[0], -1)
        x = self.output1(x)
        return x
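The file imports `partial`, `register_model`, and `_cfg` without using them in the lines shown, which suggests a timm-style model factory is intended, as in the original TransCrowd code. Below is a minimal sketch of how such a factory and a forward pass might look; the function name `base_patch16_384_gap`, the hyperparameters, and the 384x384 / 16x16-patch setting are assumptions taken from the TransCrowd-GAP configuration, not part of this commit.

# Sketch only -- not part of the commit. Assumes the TransCrowd-GAP configuration
# (384x384 input, 16x16 patches, ViT-Base encoder); names and values are illustrative.
from functools import partial

import torch
import torch.nn as nn
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg

from model.transcrowd_gap import VisionTransformer_GAP


@register_model
def base_patch16_384_gap(pretrained=False, **kwargs):
    # (Loading pretrained ViT weights is omitted in this sketch.)
    model = VisionTransformer_GAP(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
        mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    return model


if __name__ == "__main__":
    model = base_patch16_384_gap()
    x = torch.randn(2, 3, 384, 384)   # (N, C, H, W)
    # forward_features: (2, 576, 768), where 576 = (384 / 16) ** 2 patch tokens.
    # adaptive_avg_pool1d to 48 gives (2, 576, 48), flattened to (2, 27648) = (2, 6912 * 4),
    # which is exactly the in_features of the first Linear in the regression head.
    count = model(x)                  # (2, 1) predicted head count per image
    print(count.shape)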