From 9d2a30a22608e688303e93843e68beeca2f67e0a Mon Sep 17 00:00:00 2001
From: Zhengyi Chen
Date: Wed, 6 Mar 2024 20:44:37 +0000
Subject: [PATCH] Revamp loss & rename model/ to network/

---
 {model => network}/csrnet.py              |  0
 {model => network}/glue.py                |  0
 {model => network}/reverse_perspective.py |  0
 {model => network}/revpers_csrnet.py      |  0
 {model => network}/stn.py                 |  0
 {model => network}/transcrowd_gap.py      |  0
 train.py                                  | 76 +++++++++++++----------
 7 files changed, 44 insertions(+), 32 deletions(-)
 rename {model => network}/csrnet.py (100%)
 rename {model => network}/glue.py (100%)
 rename {model => network}/reverse_perspective.py (100%)
 rename {model => network}/revpers_csrnet.py (100%)
 rename {model => network}/stn.py (100%)
 rename {model => network}/transcrowd_gap.py (100%)

diff --git a/model/csrnet.py b/network/csrnet.py
similarity index 100%
rename from model/csrnet.py
rename to network/csrnet.py
diff --git a/model/glue.py b/network/glue.py
similarity index 100%
rename from model/glue.py
rename to network/glue.py
diff --git a/model/reverse_perspective.py b/network/reverse_perspective.py
similarity index 100%
rename from model/reverse_perspective.py
rename to network/reverse_perspective.py
diff --git a/model/revpers_csrnet.py b/network/revpers_csrnet.py
similarity index 100%
rename from model/revpers_csrnet.py
rename to network/revpers_csrnet.py
diff --git a/model/stn.py b/network/stn.py
similarity index 100%
rename from model/stn.py
rename to network/stn.py
diff --git a/model/transcrowd_gap.py b/network/transcrowd_gap.py
similarity index 100%
rename from model/transcrowd_gap.py
rename to network/transcrowd_gap.py
diff --git a/train.py b/train.py
index 474ac53..a6491d1 100644
--- a/train.py
+++ b/train.py
@@ -16,11 +16,16 @@
 import logging
 import numpy as np
 import pandas as pd
-from model.transcrowd_gap import VisionTransformerGAP
+from network.transcrowd_gap import (
+    VisionTransformerGAP,
+    STNet_VisionTransformerGAP,
+    base_patch16_384_gap,
+    stn_patch16_384_gap
+)
 from arguments import args, ret_args
 import dataset
 from dataset import *
-from model.transcrowd_gap import *
+# from model.transcrowd_gap import *
 from checkpoint import save_checkpoint
 
 logger = logging.getLogger("train")
@@ -171,7 +176,7 @@ def worker(rank: int, args: Namespace):
         weight_decay=args.weight_decay
     )
     scheduler = torch.optim.lr_scheduler.MultiStepLR(
-        optimizer, milestones=[300], gamma=.1, last_epoch=-1
+        optimizer, milestones=[50], gamma=.1, last_epoch=-1
     )
 
     # Checkpointing
@@ -267,39 +272,49 @@ def train_one_epoch(
             gt_count_whole = gt_count_whole.cuda()
             device_type = "cuda"
 
-        # Desperate measure to reduce mem footprint...
-        # with torch.autocast(device_type):
         # fpass
         out, gt_count = model(img, kpoint)
-        # loss
-        loss = criterion(out, gt_count) # wrt. transformer
-        if args.export_to_h5:
-            train_df.loc[epoch * i, "l1loss"] = float(loss.item())
-        else:
-            writer.add_scalar(
-                "L1-loss wrt. xformer (train)", loss, epoch * i
+
+        # loss & bpass
+        if isinstance(model.module, STNet_VisionTransformerGAP):
+            loss_xformer = criterion(out, gt_count)
+            loss_stn = (
+                F.mse_loss(
+                    gt_count.view(batch_size, -1).sum(dim=1, keepdim=True),
+                    gt_count_whole)
+                + torch.sigmoid(
+                    gt_count.view(batch_size, -1).var(dim=1).mean())
             )
-
-        loss += (
-            F.mse_loss( # stn: info retainment
-                gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
-                gt_count_whole)
-            + F.threshold( # stn: perspective correction
-                gt_count.view(batch_size, -1).var(dim=1).mean(),
-                threshold=loss.item(),
-                value=loss.item()
+            stn_param_ids = {id(p) for p in model.module.stnet.parameters()}
+            optimizer.zero_grad(set_to_none=True)
+            # Accum STN grads first
+            loss_stn.backward(
+                inputs=list(model.module.stnet.parameters()), retain_graph=True
             )
-        )
-        if args.export_to_h5:
-            train_df.loc[epoch * i, "composite-loss"] = float(loss.item())
+            # Then bpass the xformer loss into the non-STN params only,
+            # so the STN grads above are not accumulated into twice
+            loss_xformer.backward(inputs=[
+                p for p in model.parameters() if id(p) not in stn_param_ids
+            ])
+
+            if args.export_to_h5:
+                train_df.loc[epoch * i, "l1loss"] = float(loss_xformer.item())
+                train_df.loc[epoch * i, "composite-loss"] = float(loss_stn.item())
+            else:
+                writer.add_scalar("l1loss", loss_xformer, epoch * i)
+                writer.add_scalar("composite-loss", loss_stn, epoch * i)
         else:
-            writer.add_scalar("Composite loss (train)", loss, epoch * i)
+            loss = criterion(out, gt_count)
+            if args.export_to_h5:
+                train_df.loc[epoch * i, "l1loss"] = float(loss.item())
+            else:
+                writer.add_scalar("l1loss", loss, epoch * i)
 
-        # free grad from mem
-        optimizer.zero_grad(set_to_none=True)
+            # free grads from mem
+            optimizer.zero_grad(set_to_none=True)
 
-        # bpass
-        loss.backward()
+            # bpass
+            loss.backward()
 
         # optimizer
         optimizer.step()
@@ -355,9 +370,6 @@ def valid_one_epoch(test_loader, model, device, epoch, args):
             mse += diff ** 2
 
             if i % 5 == 0:
-                # with torch.no_grad():
-                #     img_xformed = model.stnet(img).to("cpu")
-                #     xformed.append(img_xformed)
                 print("[valid_one_epoch] {} | Gt {:.2f} Pred {:.4f} |".format(
                     fname[0],
                     torch.sum(gt_count_whole).item(),
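Note: the STN branch above routes each loss to a parameter subset via
Tensor.backward(inputs=...), which restricts gradient accumulation to the
listed leaves; this replaces the earlier clear-then-backward sequence, which
discarded the STN gradients it had just computed. Below is a minimal,
self-contained sketch of that pattern, assuming the hunk's intent is that
loss_stn update only the STN while loss_xformer updates the rest; the
modules, shapes, and names (stn, head) are hypothetical stand-ins, not the
classes in network/.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    stn = nn.Linear(8, 8)   # stand-in for the spatial-transformer front end
    head = nn.Linear(8, 4)  # stand-in for the ViT regression head

    img = torch.randn(2, 8)
    gt_patch_counts = torch.randn(2, 4)  # per-patch ground-truth counts
    gt_count_whole = torch.randn(2, 1)   # whole-image ground-truth count

    warped = stn(img)            # STN output feeds the head
    pred_counts = head(warped)   # per-patch count predictions

    # Composite STN loss, mirroring the hunk above: the warped quantities
    # should still sum to the whole-image count (information retainment),
    # plus a bounded variance penalty (perspective correction).
    loss_stn = (
        F.mse_loss(warped.sum(dim=1, keepdim=True), gt_count_whole)
        + torch.sigmoid(warped.var(dim=1).mean())
    )
    loss_xformer = F.l1_loss(pred_counts, gt_patch_counts)

    stn_params = list(stn.parameters())
    stn_ids = {id(p) for p in stn_params}
    all_params = stn_params + list(head.parameters())

    # 1) Route the STN loss into the STN leaves only; retain_graph keeps
    #    the graph alive for the second backward below.
    loss_stn.backward(inputs=stn_params, retain_graph=True)

    # 2) Route the transformer loss into every non-STN leaf, so the STN
    #    gradients from step 1 are not accumulated into a second time.
    loss_xformer.backward(inputs=[p for p in all_params if id(p) not in stn_ids])

    assert all(p.grad is not None for p in all_params)

Running the two backwards against disjoint inputs= sets keeps the losses
from double-accumulating into the shared STN leaves, at the cost of one
retained graph.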