Loss revamp & Renamed model to network

Zhengyi Chen 2024-03-06 20:44:37 +00:00
parent 0d35d607fe
commit 9d2a30a226
7 changed files with 44 additions and 32 deletions

@@ -1,138 +0,0 @@
r"""Transformer-encoder-regressor, adapted for reverse-perspective network.
This model is identical to the *TransCrowd* [#]_ model, which we use for the
subproblem of actually counting the heads in each *transformed* raw image.
.. [#] Liang, D., Chen, X., Xu, W., Zhou, Y., & Bai, X. (2022).
Transcrowd: weakly-supervised crowd counting with transformers.
Science China Information Sciences, 65(6), 160104.
"""
from typing import Optional
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# The original paper uses timm to create and import/export custom models,
# so we follow suit.
from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_

from .stn import STNet
from .glue import SquareCropTransformLayer


class VisionTransformerGAP(VisionTransformer):
    def __init__(self, img_size: int, *args, **kwargs):
        super().__init__(*args, img_size=img_size, **kwargs)
        num_patches = self.patch_embed.num_patches
        # That {p_1, p_2, ..., p_N} positional embedding, plus one slot for
        # the classification token.
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, self.embed_dim)
        )
        # Fill self.pos_embed from N(0, std^2) with std=0.02, truncated to
        # [-2, 2] (timm's default bounds).
        trunc_normal_(self.pos_embed, std=.02)
        # The "regression head"
        self.output1 = nn.Sequential(
            nn.ReLU(),
            nn.Linear(in_features=6912 * 4, out_features=128),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=128, out_features=1),
        )
        self.output1.apply(self._init_weights)
        # Glue layer -- since we delay image cropping until here
        self.glue = SquareCropTransformLayer(img_size)
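
    # Where in_features=6912 * 4 comes from (for the 384/16 configs
    # registered below): (384 // 16) ** 2 = 576 patches, and forward()
    # pools embed_dim 768 down to 48 per patch, so the flattened input to
    # output1 has 576 * 48 = 27648 = 6912 * 4 features per sample.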

    def forward_features(self, x):
        B = x.shape[0]
        # 3.2 Patch embed
        x = self.patch_embed(x)
        # ViT: Classification token
        # This idea originated from BERT.
        # Essentially, because we encode without decoding, there is otherwise
        # no fixed output dimensionality -- which a classification head
        # absolutely needs. The classification token serves as a fixed-size
        # slot into which the transformer learns to encode whatever it
        # extracted from the input.
        # Source: https://datascience.stackexchange.com/a/110637
        # That said, I don't think this is useful for GAP...
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)  # [[cls_token, x_i, ...]...]
        x = x + self.pos_embed
        # [XXX] Note pos_drop is elementwise Dropout on the embeddings,
        # not patch dropout.
        x = self.pos_drop(x)
        # 3.3 Transformer-encoder
        for block in self.blocks:
            x = block(x)
        # Normalize
        x = self.norm(x)
        # Remove the classification token
        x = x[:, 1:]
        return x
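
    # Shape trace through forward_features() for the 384/16 configs below
    # (B = batch size):
    #   x: (B, 3, 384, 384) -> patch_embed: (B, 576, 768)
    #   -> +cls token: (B, 577, 768) -> blocks/norm: (B, 577, 768)
    #   -> cls token stripped: (B, 576, 768)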

    def forward(self, x, t):
        with torch.no_grad():
            x, t = self.glue(x, t)
        x = self.forward_features(x)  # Compute encoding: (B, 576, 768)
        # Pool each patch's embedding from 768 down to 48 features
        x = F.adaptive_avg_pool1d(x, output_size=48)
        # Flatten for the regression head; with the configs below this is
        # (B, 576 * 48) = (B, 27648)
        x = x.view(x.shape[0], -1)
        x = self.output1(x)  # Regression head: (B, 1) head-count estimate
        return x, t


class STNet_VisionTransformerGAP(VisionTransformerGAP):
    def __init__(self, img_shape: torch.Size, img_size: int, *args, **kwargs):
        super().__init__(img_size, *args, **kwargs)
        self.stnet = STNet(img_shape)

    def forward(self, x, t):
        # Rectify perspective with the spatial transformer first, then count.
        x, t = self.stnet(x, t)
        return super().forward(x, t)
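
# Full pipeline of STNet_VisionTransformerGAP, per the methods above:
#   raw frame --STNet--> rectified frame --glue--> img_size x img_size crop
#   --ViT encoder--> (B, 576, 768) --pool--> (B, 576, 48) --MLP--> count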


@register_model
def base_patch16_384_gap(pth_tar: Optional[str] = None, **kwargs):
    model = VisionTransformerGAP(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
        mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    model.default_cfg = _cfg()
    if pth_tar is not None:
        checkpoint = torch.load(pth_tar)
        model.load_state_dict(checkpoint["model"], strict=False)
        print(f"Loaded pre-trained pth.tar from '{pth_tar}'")
    return model


@register_model
def stn_patch16_384_gap(pth_tar: Optional[str] = None, **kwargs):
    model = STNet_VisionTransformerGAP(
        img_shape=torch.Size((3, 1152, 768)),
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
        mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    model.default_cfg = _cfg()
    if pth_tar is not None:
        checkpoint = torch.load(pth_tar)
        model.load_state_dict(checkpoint["model"], strict=False)
        print(f"Loaded pre-trained pth.tar from '{pth_tar}'")
    return model
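
For reference, a minimal usage sketch of the factories above (not part of the commit). The import path, the dummy shapes, and the form of t are assumptions here, since stn.py and glue.py are not shown in this diff:

    import torch

    # Hypothetical module path for the file shown above.
    from networks.vit_gap import stn_patch16_384_gap

    # @register_model returns the factory unchanged, so it can also be
    # called directly; pth_tar=None skips checkpoint loading.
    net = stn_patch16_384_gap(pth_tar=None)
    net.eval()

    # img_shape is (3, 1152, 768): STNet sees the full frame, and the glue
    # layer produces the 384x384 input the ViT consumes.
    x = torch.randn(1, 3, 1152, 768)
    t = torch.zeros(1, 1)  # placeholder for whatever glue/STNet thread through
    with torch.no_grad():
        count, t = net(x, t)
    print(count.shape)  # torch.Size([1, 1]) -- one head-count estimate per image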