Added sth more

Zhengyi Chen 2024-02-27 21:27:02 +00:00
parent 49a913a328
commit 99266d9c92
2 changed files with 77 additions and 30 deletions

model/transcrowd_gap.py

@@ -7,6 +7,7 @@ subproblem of actually counting the heads in each *transformed* raw image.
 Transcrowd: weakly-supervised crowd counting with transformers.
 Science China Information Sciences, 65(6), 160104.
 """
+from typing import Optional
 from functools import partial
 import numpy as np
@@ -69,7 +70,7 @@ class VisionTransformerGAP(VisionTransformer):
         # the sole input which the transformer would need to learn to encode
         # whatever it learnt from input into that token.
         # Source: https://datascience.stackexchange.com/a/110637
-        # That said, I don't think this is useful in this case...
+        # That said, I don't think this is useful for GAP...
         cls_tokens = self.cls_token.expand(B, -1, -1)
         x = torch.cat((cls_tokens, x), dim=1)  # [[cls_token, x_i, ...]...]
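
A quick illustration (not part of the commit) of what the comment is getting at: a CLS-token readout forces the encoder to compress the whole image into one learned token, whereas a GAP head averages over all patch tokens, so the extra token adds little here. Tensor names below are illustrative.

import torch

# Shapes for a 384x384 input with 16x16 patches: (384/16)^2 = 576 tokens.
B, N, D = 2, 576, 768
tokens = torch.randn(B, N + 1, D)  # encoder output: [CLS] + patch tokens

cls_readout = tokens[:, 0]               # summary squeezed into one token
gap_readout = tokens[:, 1:].mean(dim=1)  # global average pool over patches

assert cls_readout.shape == gap_readout.shape == (B, D)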
@@ -105,3 +106,38 @@ class STNet_VisionTransformerGAP(VisionTransformerGAP):
     def forward(self, x):
         x = self.stnet(x)
         return super(STNet_VisionTransformerGAP, self).forward(x)
+
+
+@register_model
+def base_patch16_384_gap(pth_tar: Optional[str] = None, **kwargs):
+    model = VisionTransformerGAP(
+        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
+        mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        **kwargs
+    )
+    model.default_cfg = _cfg()
+    if pth_tar is not None:
+        checkpoint = torch.load(pth_tar)
+        model.load_state_dict(checkpoint["model"], strict=False)
+        print("Loaded pre-trained pth.tar from '{}'".format(pth_tar))
+    return model
+
+
+@register_model
+def stn_patch16_384_gap(pth_tar: Optional[str] = None, **kwargs):
+    model = STNet_VisionTransformerGAP(
+        img_shape=torch.Size((3, 384, 384)),
+        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
+        mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        **kwargs
+    )
+    model.default_cfg = _cfg()
+    if pth_tar is not None:
+        checkpoint = torch.load(pth_tar)
+        model.load_state_dict(checkpoint["model"], strict=False)
+        print("Loaded pre-trained pth.tar from '{}'".format(pth_tar))
+    return model
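
For reference, a sketch (not part of the commit) of how the new entrypoints could be exercised. Since timm's @register_model returns the decorated function unchanged, calling it directly sidesteps any assumptions about create_model's keyword handling; the output shape in the comment is an assumption based on TransCrowd-GAP regressing one count per image.

import torch
from model.transcrowd_gap import base_patch16_384_gap

model = base_patch16_384_gap()  # pth_tar=None -> randomly initialized
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 384, 384))  # one 384x384 RGB crop
print(out.shape)  # assumed: one count per image, e.g. torch.Size([1, 1])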

View file

@@ -1,46 +1,63 @@
+import os
+import random
+from typing import Optional
 from argparse import Namespace
 import timm
 import torch
+import torch.nn as nn
 import torch.multiprocessing as torch_mp
 from torch.utils.data import DataLoader
 import nni
 import logging
-from model.csrnet import CSRNet
-from model.reverse_perspective import PerspectiveEstimator
+import numpy as np
+from model.transcrowd_gap import VisionTransformerGAP
 from arguments import args, ret_args
-logger = logging.getLogger("train-revpers")
+logger = logging.getLogger("train")
+
+
+def setup_process_group(
+    rank: int,
+    world_size: int,
+    master_addr: str = "localhost",
+    master_port: Optional[int] = None
+):
+    os.environ["MASTER_ADDR"] = master_addr
+    # NB: all ranks must agree on the port; a random port only works if it
+    # is drawn once and shared, not re-drawn independently in each process.
+    os.environ["MASTER_PORT"] = (
+        str(random.randint(40000, 65535))
+        if master_port is None
+        else str(master_port)
+    )
+    # Join point: blocks until all `world_size` processes have called in.
+    torch.distributed.init_process_group(
+        backend="nccl", rank=rank, world_size=world_size
+    )
+
+
+# TODO:
+# Each batch in TransCrowd has shape [3, 384, 384] because images are cropped
+# before training. To preserve the semantics of the full image layout, we want
+# to apply the cropping inside the inference/training pipeline itself, i.e. on
+# the way into the encoder. This should be fine as long as all of our
+# transformations are deterministic (still to be verified).
-# We use 2 separate networks as opposed to 1 whole network --
-# this is more flexible, as we only train one of them...
-def gen_csrnet(pth_tar: str = None) -> CSRNet:
-    if pth_tar is not None:
-        model = CSRNet(load_weights=True)
-        checkpoint = torch.load(pth_tar)
-        model.load_state_dict(checkpoint["state_dict"], strict=False)
-    else:
-        model = CSRNet(load_weights=False)
-    return model
-
-
-def gen_revpers(pth_tar: str = None, **kwargs) -> PerspectiveEstimator:
-    model = PerspectiveEstimator(**kwargs)
-    if pth_tar is not None:
-        checkpoint = torch.load(pth_tar)
-        model.load_state_dict(checkpoint["state_dict"], strict=False)
-    return model
 def build_train_loader():
     pass
 
 
 def build_valid_loader():
     pass
 
 
 def train_one_epoch(
     train_loader: DataLoader,
-    revpers_net: PerspectiveEstimator,
-    csr_net: CSRNet,
+    model: VisionTransformerGAP,
     criterion,
     optimizer,
     scheduler,
@@ -59,16 +76,10 @@ def train_one_epoch(
     # In one epoch, for each training sample
     for i, (fname, img, gt_count) in enumerate(train_loader):
-        # move stuff to device
         # fpass (revpers)
         img = img.cuda()
-        out_revpers = revpers_net(img)
-        # We need to perform image transformation here...
-        img = img.cpu()
-        # fpass (csrnet -- do not train)
-        img = img.cuda()
-        out_csrnet = csr_net(img)
         # loss wrt revpers
         loss = criterion()
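
As a usage note (not part of the commit): setup_process_group expects one process per rank, so it would typically be driven through torch_mp.spawn with an explicit master_port so that all ranks rendezvous on the same address. main_worker and WORLD_SIZE below are illustrative names, not from this repository, and the sketch assumes it runs in the same module that defines setup_process_group.

import torch
import torch.multiprocessing as torch_mp

WORLD_SIZE = torch.cuda.device_count()

def main_worker(rank: int, world_size: int):
    # spawn() passes the process index as the first argument; use it as rank
    setup_process_group(rank, world_size, master_port=29500)
    torch.cuda.set_device(rank)  # pin this process to its own GPU
    # ... build loaders/model, wrap in DistributedDataParallel, train ...
    torch.distributed.destroy_process_group()

if __name__ == "__main__":
    torch_mp.spawn(main_worker, args=(WORLD_SIZE,), nprocs=WORLD_SIZE)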