From 99266d9c929b0fc74a885622712ddf08b85b7dfb Mon Sep 17 00:00:00 2001
From: Zhengyi Chen
Date: Tue, 27 Feb 2024 21:27:02 +0000
Subject: [PATCH] Register GAP ViT models and add distributed training setup

---
 model/transcrowd_gap.py      | 40 +++++++++++++++++++--
 train-revpers.py => train.py | 67 +++++++++++++++++++++---------------
 2 files changed, 77 insertions(+), 30 deletions(-)
 rename train-revpers.py => train.py (60%)

diff --git a/model/transcrowd_gap.py b/model/transcrowd_gap.py
index c6158e6..ebd807f 100644
--- a/model/transcrowd_gap.py
+++ b/model/transcrowd_gap.py
@@ -7,6 +7,7 @@ subproblem of actually counting the heads in each *transformed* raw image.
 Transcrowd: weakly-supervised crowd counting with transformers.
 Science China Information Sciences, 65(6), 160104.
 """
+from typing import Optional
 from functools import partial
 
 import numpy as np
@@ -69,7 +70,7 @@ class VisionTransformerGAP(VisionTransformer):
         # the sole input which the transformer would need to learn to encode
         # whatever it learnt from input into that token.
         # Source: https://datascience.stackexchange.com/a/110637
-        # That said, I don't think this is useful in this case...
+        # That said, I don't think this is useful for GAP...
         cls_tokens = self.cls_token.expand(B, -1, -1)
 
         x = torch.cat((cls_tokens, x), dim=1)  # [[cls_token, x_i, ...]...]
@@ -104,4 +105,39 @@ class STNet_VisionTransformerGAP(VisionTransformerGAP):
 
     def forward(self, x):
         x = self.stnet(x)
-        return super(STNet_VisionTransformerGAP, self).forward(x)
\ No newline at end of file
+        return super(STNet_VisionTransformerGAP, self).forward(x)
+
+
+@register_model
+def base_patch16_384_gap(pth_tar: Optional[str] = None, **kwargs):
+    model = VisionTransformerGAP(
+        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
+        mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        **kwargs
+    )
+    model.default_cfg = _cfg()
+
+    if pth_tar is not None:
+        checkpoint = torch.load(pth_tar)
+        model.load_state_dict(checkpoint["model"], strict=False)
+        print("Loaded pre-trained pth.tar from '{}'".format(pth_tar))
+
+    return model
+
+
+@register_model
+def stn_patch16_384_gap(pth_tar: Optional[str] = None, **kwargs):
+    model = STNet_VisionTransformerGAP(
+        img_shape=torch.Size((3, 384, 384)),
+        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
+        mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        **kwargs
+    )
+    model.default_cfg = _cfg()
+
+    if pth_tar is not None:
+        checkpoint = torch.load(pth_tar)
+        model.load_state_dict(checkpoint["model"], strict=False)
+        print("Loaded pre-trained pth.tar from '{}'".format(pth_tar))
+
+    return model
\ No newline at end of file
diff --git a/train-revpers.py b/train.py
similarity index 60%
rename from train-revpers.py
rename to train.py
index b9a29b4..9beb2ec 100644
--- a/train-revpers.py
+++ b/train.py
@@ -1,46 +1,63 @@
+import os
+import random
+from typing import Optional
 from argparse import Namespace
 
 import timm
 import torch
+import torch.nn as nn
 import torch.multiprocessing as torch_mp
 from torch.utils.data import DataLoader
 import nni
 import logging
+import numpy as np
 
-from model.csrnet import CSRNet
-from model.reverse_perspective import PerspectiveEstimator
+from model.transcrowd_gap import VisionTransformerGAP
 from arguments import args, ret_args
 
-logger = logging.getLogger("train-revpers")
+logger = logging.getLogger("train")
+
+
+def setup_process_group(
+    rank: int,
+    world_size: int,
+    master_addr: str = "localhost",
+    master_port: Optional[int] = None
+):
+    os.environ["MASTER_ADDR"] = master_addr
"localhost" + os.environ["MASTER_PORT"] = ( + str(random.randint(40000, 65545)) + if master_port is None + else str(master_port) + ) + + # join point! + torch.distributed.init_process_group( + backend="nccl", rank=rank, world_size=world_size + ) + +# TODO: +# The shape for each batch in transcrowd is [3, 384, 384], +# this is due to images being cropped before training. +# To preserve image semantics wrt the entire layout, we want to apply cropping +# i.e., as encoder input during the inference/training pipeline. +# This should be okay since our transformations are all deterministic? +# not sure... + -# We use 2 separate networks as opposed to 1 whole network -- -# this is more flexible, as we only train one of them... -def gen_csrnet(pth_tar: str = None) -> CSRNet: - if pth_tar is not None: - model = CSRNet(load_weights=True) - checkpoint = torch.load(pth_tar) - model.load_state_dict(checkpoint["state_dict"], strict=False) - else: - model = CSRNet(load_weights=False) - return model -def gen_revpers(pth_tar: str = None, **kwargs) -> PerspectiveEstimator: - model = PerspectiveEstimator(**kwargs) - if pth_tar is not None: - checkpoint = torch.load(pth_tar) - model.load_state_dict(checkpoint["state_dict"], strict=False) - return model def build_train_loader(): pass + def build_valid_loader(): pass + def train_one_epoch( train_loader: DataLoader, - revpers_net: PerspectiveEstimator, - csr_net: CSRNet, + model: VisionTransformerGAP, criterion, optimizer, scheduler, @@ -59,16 +76,10 @@ def train_one_epoch( # In one epoch, for each training sample for i, (fname, img, gt_count) in enumerate(train_loader): + # move stuff to device # fpass (revpers) img = img.cuda() - out_revpers = revpers_net(img) - # We need to perform image transformation here... - - img = img.cpu() - # fpass (csrnet -- do not train) - img = img.cuda() - out_csrnet = csr_net(img) # loss wrt revpers loss = criterion()