Added something more
parent 49a913a328
commit 99266d9c92
2 changed files with 77 additions and 30 deletions
@@ -7,6 +7,7 @@ subproblem of actually counting the heads in each *transformed* raw image.
Transcrowd: weakly-supervised crowd counting with transformers.
Science China Information Sciences, 65(6), 160104.
"""
+from typing import Optional
from functools import partial

import numpy as np
@@ -69,7 +70,7 @@ class VisionTransformerGAP(VisionTransformer):
        # the sole input which the transformer would need to learn to encode
        # whatever it learnt from input into that token.
        # Source: https://datascience.stackexchange.com/a/110637
-        # That said, I don't think this is useful in this case...
+        # That said, I don't think this is useful for GAP...
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)  # [[cls_token, x_i, ...]...]
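# Aside (a sketch, not part of this commit): why the CLS token is arguably
# redundant under GAP. With CLS readout the head sees only the first token,
# so the transformer must learn to aggregate the whole sequence into it;
# with GAP the head just averages the patch tokens. Shapes are illustrative.
import torch

B, N, D = 2, 576, 768                  # 576 = (384 // 16) ** 2 patch tokens
tokens = torch.randn(B, N + 1, D)      # [cls_token, patch_token_0, ...]
cls_feat = tokens[:, 0]                # CLS readout -> [B, D]
gap_feat = tokens[:, 1:].mean(dim=1)   # GAP readout -> [B, D]
assert cls_feat.shape == gap_feat.shape == (B, D)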
@@ -104,4 +105,39 @@ class STNet_VisionTransformerGAP(VisionTransformerGAP):

    def forward(self, x):
        x = self.stnet(x)
-        return super(STNet_VisionTransformerGAP, self).forward(x)
+        return super(STNet_VisionTransformerGAP, self).forward(x)
+
+
+@register_model
+def base_patch16_384_gap(pth_tar: Optional[str] = None, **kwargs):
+    model = VisionTransformerGAP(
+        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
+        mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        **kwargs
+    )
+    model.default_cfg = _cfg()
+
+    if pth_tar is not None:
+        checkpoint = torch.load(pth_tar)
+        model.load_state_dict(checkpoint["model"], strict=False)
+        print("Loaded pre-trained pth.tar from '{}'".format(pth_tar))
+
+    return model
+
+
+@register_model
+def stn_patch16_384_gap(pth_tar: Optional[str] = None, **kwargs):
+    model = STNet_VisionTransformerGAP(
+        img_shape=torch.Size((3, 384, 384)),
+        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
+        mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        **kwargs
+    )
+    model.default_cfg = _cfg()
+
+    if pth_tar is not None:
+        checkpoint = torch.load(pth_tar)
+        model.load_state_dict(checkpoint["model"], strict=False)
+        print("Loaded pre-trained pth.tar from '{}'".format(pth_tar))
+
+    return model
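# Usage sketch (not from this commit): @register_model adds these factories
# to timm's registry, but they can also be called directly. The module path
# matches the import in train.py below; the checkpoint filename is a
# placeholder, and "count out" is this reviewer's reading of the GAP head.
import torch
from model.transcrowd_gap import base_patch16_384_gap

net = base_patch16_384_gap()  # random init; pass pth_tar=... to load weights
net.eval()
with torch.no_grad():
    out = net(torch.randn(1, 3, 384, 384))  # one 384x384 RGB crop in, count out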
@@ -1,46 +1,63 @@
import os
import random
from typing import Optional
from argparse import Namespace

import timm
import torch
import torch.nn as nn
import torch.multiprocessing as torch_mp
from torch.utils.data import DataLoader
import nni
import logging
import numpy as np

from model.csrnet import CSRNet
from model.reverse_perspective import PerspectiveEstimator
from model.transcrowd_gap import VisionTransformerGAP
from arguments import args, ret_args

-logger = logging.getLogger("train-revpers")
+logger = logging.getLogger("train")

def setup_process_group(
    rank: int,
    world_size: int,
    master_addr: str = "localhost",
    master_port: Optional[int] = None
):
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = (
        # valid ports end at 65535, and randint is inclusive on both ends
        str(random.randint(40000, 65535))
        if master_port is None
        else str(master_port)
    )

    # join point!
    torch.distributed.init_process_group(
        backend="nccl", rank=rank, world_size=world_size
    )
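# Usage sketch (not from this commit): each spawned worker joins the group
# via setup_process_group. Note that a fixed master_port must be passed here,
# since a per-process random port would disagree across ranks.
def _worker(rank: int, world_size: int):
    setup_process_group(rank, world_size, master_port=29500)
    torch.cuda.set_device(rank)
    # ... build model, wrap in DistributedDataParallel, run epochs ...
    torch.distributed.destroy_process_group()

if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    # spawn() forks nprocs processes and passes the rank as the first argument
    torch_mp.spawn(_worker, args=(n_gpus,), nprocs=n_gpus)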
# TODO:
# The shape of each batch in transcrowd is [3, 384, 384]; this is due to
# images being cropped before training. To preserve image semantics w.r.t.
# the entire layout, we want to apply the crop late, i.e., to the
# *transformed* image right before it becomes encoder input in the
# inference/training pipeline. This should be okay since our transformations
# are all deterministic? Not sure... (see the sketch below)
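# Sketch of "cropping late" (not from this commit), assuming torchvision is
# available and the perspective transform is deterministic; 384 matches the
# batch shape in the TODO above, and late_crop is this reviewer's name.
import torchvision.transforms.functional as TF

def late_crop(img: torch.Tensor, size: int = 384) -> torch.Tensor:
    # Crop *after* the deterministic transform, so the crop sees the image
    # in its corrected layout rather than the raw one.
    return TF.center_crop(img, [size, size])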

# We use 2 separate networks as opposed to 1 whole network --
# this is more flexible, as we only train one of them...
def gen_csrnet(pth_tar: Optional[str] = None) -> CSRNet:
    if pth_tar is not None:
        model = CSRNet(load_weights=True)
        checkpoint = torch.load(pth_tar)
        model.load_state_dict(checkpoint["state_dict"], strict=False)
    else:
        model = CSRNet(load_weights=False)
    return model


def gen_revpers(pth_tar: Optional[str] = None, **kwargs) -> PerspectiveEstimator:
    model = PerspectiveEstimator(**kwargs)
    if pth_tar is not None:
        checkpoint = torch.load(pth_tar)
        model.load_state_dict(checkpoint["state_dict"], strict=False)
    return model
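# Sketch (not from this commit): since only one of the two networks is
# trained at a time, the frozen one can be taken out of autograd explicitly;
# freeze() is this reviewer's name, and the checkpoint path is a placeholder.
def freeze(net: nn.Module) -> nn.Module:
    for p in net.parameters():
        p.requires_grad_(False)   # no gradients for the frozen network
    return net.eval()             # and fix BatchNorm/Dropout behaviour

# csr_net = freeze(gen_csrnet("csrnet.pth.tar"))  # placeholder checkpoint
# revpers_net = gen_revpers()                     # stays trainable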

def build_train_loader():
    pass


def build_valid_loader():
    pass

def train_one_epoch(
    train_loader: DataLoader,
    revpers_net: PerspectiveEstimator,
    csr_net: CSRNet,
    model: VisionTransformerGAP,
    criterion,
    optimizer,
    scheduler,
@@ -59,16 +76,10 @@ def train_one_epoch(

    # In one epoch, for each training sample
    for i, (fname, img, gt_count) in enumerate(train_loader):
        # move the batch to the device, forward pass (revpers)
        img = img.cuda()
        out_revpers = revpers_net(img)
        # We need to perform the image transformation here...

        img = img.cpu()

        # forward pass (csrnet -- do not train)
        img = img.cuda()
        out_csrnet = csr_net(img)
        # loss w.r.t. revpers
        loss = criterion()
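# The hunk is cut off above, so `loss = criterion()` is still a stub. Below
# is a guessed completion (not code from this commit), assuming the criterion
# compares CSRNet's integrated density map against the ground-truth count:

        pred_count = out_csrnet.sum(dim=(1, 2, 3))  # density map -> count
        loss = criterion(pred_count, gt_count.cuda().float())

        optimizer.zero_grad()
        loss.backward()  # would also reach revpers_net once the image
                         # transform above is actually applied
        optimizer.step()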