Loss revamp & Renamed model to network

This commit is contained in:
parent 0d35d607fe
commit 9d2a30a226

7 changed files with 44 additions and 32 deletions

@@ -1,138 +0,0 @@
r"""Transformer-encoder-regressor, adapted for reverse-perspective network.

This model is identical to the *TransCrowd* [#]_ model, which we use for the
subproblem of actually counting the heads in each *transformed* raw image.

.. [#] Liang, D., Chen, X., Xu, W., Zhou, Y., & Bai, X. (2022).
   TransCrowd: weakly-supervised crowd counting with transformers.
   Science China Information Sciences, 65(6), 160104.
"""

from typing import Optional
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

# The original paper uses timm to create and import/export custom models,
# so we follow suit.
from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_

from .stn import STNet
from .glue import SquareCropTransformLayer


class VisionTransformerGAP(VisionTransformer):
    def __init__(self, img_size: int, *args, **kwargs):
        super().__init__(img_size=img_size, *args, **kwargs)
        num_patches = self.patch_embed.num_patches

        # The learnable {p_1, p_2, ..., p_N} positional embedding, with one
        # extra slot for the classification token.
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, self.embed_dim)
        )
        # Fill self.pos_embed from N(0, 0.02^2); timm's trunc_normal_ redraws
        # values that fall outside [-2, 2].
        trunc_normal_(self.pos_embed, std=.02)

        # The "regression head"
        self.output1 = nn.Sequential(
            nn.ReLU(),
            nn.Linear(in_features=6912 * 4, out_features=128),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=128, out_features=1),
        )
        self.output1.apply(self._init_weights)

        # Glue layer -- since we delay image cropping until here.
        self.glue = SquareCropTransformLayer(img_size)
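
    # NOTE: in_features=6912 * 4 == 27648 is not arbitrary. Assuming the
    # default config registered below (img_size=384, patch_size=16,
    # embed_dim=768), forward() pools each of the (384 // 16) ** 2 == 576
    # patch tokens down to 48 channels, and 576 * 48 == 27648.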

    def forward_features(self, x):
        B = x.shape[0]

        # 3.2 Patch embed
        x = self.patch_embed(x)

        # ViT classification token; the idea originated from BERT. Because we
        # encode without decoding, the output dimensionality cannot be fixed
        # -- which the classification problem absolutely needs. Instead, the
        # classification token serves as a dedicated input slot that the
        # transformer learns to encode its summary of the input into.
        # Source: https://datascience.stackexchange.com/a/110637
        # That said, I don't think this is useful for GAP...
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)  # [[cls_token, x_i, ...], ...]

        x = x + self.pos_embed
        x = self.pos_drop(x)  # [XXX] Drop some patches out -- or not?

        # 3.3 Transformer-encoder
        for block in self.blocks:
            x = block(x)

        # Normalize
        x = self.norm(x)

        # Remove the classification token
        x = x[:, 1:]

        return x
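
    # Shape sketch for forward_features, assuming the default config
    # (img_size=384, patch_size=16, embed_dim=768):
    #   input x:  (B, 3, 384, 384)
    #   patches:  (B, 576, 768)  after patch_embed
    #   tokens:   (B, 577, 768)  after prepending the classification token
    #   output:   (B, 576, 768)  after the encoder, minus the cls token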

    def forward(self, x, t):
        with torch.no_grad():
            x, t = self.glue(x, t)
        x = self.forward_features(x)      # Compute encoding: (B, 576, 768)
        x = F.adaptive_avg_pool1d(x, 48)  # GAP each token down to 48 channels
        x = x.view(x.shape[0], -1)        # Flatten for the regression head:
                                          # (B, 576 * 48) == (B, 27648)
        x = self.output1(x)               # Regression head: (B, 1)
        return x, t


class STNet_VisionTransformerGAP(VisionTransformerGAP):
    def __init__(self, img_shape: torch.Size, img_size: int, *args, **kwargs):
        super().__init__(img_size, *args, **kwargs)
        self.stnet = STNet(img_shape)

    def forward(self, x, t):
        x, t = self.stnet(x, t)
        return super().forward(x, t)
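

# Note: stn_patch16_384_gap below instantiates STNet with
# img_shape == (3, 1152, 768), so the STN presumably reverses the perspective
# of the full-size raw image before the glue layer crops it down to 384x384.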


@register_model
def base_patch16_384_gap(pth_tar: Optional[str] = None, **kwargs):
    model = VisionTransformerGAP(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
        mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    model.default_cfg = _cfg()

    if pth_tar is not None:
        checkpoint = torch.load(pth_tar)
        model.load_state_dict(checkpoint["model"], strict=False)
        print("Loaded pre-trained pth.tar from '{}'".format(pth_tar))

    return model


@register_model
def stn_patch16_384_gap(pth_tar: Optional[str] = None, **kwargs):
    model = STNet_VisionTransformerGAP(
        img_shape=torch.Size((3, 1152, 768)),
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
        mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    model.default_cfg = _cfg()

    if pth_tar is not None:
        checkpoint = torch.load(pth_tar)
        model.load_state_dict(checkpoint["model"], strict=False)
        print("Loaded pre-trained pth.tar from '{}'".format(pth_tar))

    return model
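

# A minimal usage sketch, assuming square 384x384 inputs and a per-image
# count target `t`; SquareCropTransformLayer's actual contract for (x, t)
# lives in .glue and may differ.
if __name__ == "__main__":
    model = base_patch16_384_gap(pth_tar=None)
    model.eval()

    x = torch.randn(2, 3, 384, 384)  # assumed: already-square raw images
    t = torch.randn(2, 1)            # assumed: per-image head-count targets

    with torch.no_grad():
        y, t = model(x, t)
    print(y.shape)  # expected: torch.Size([2, 1])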