Loss revamp & Renamed model to network

Zhengyi Chen 2024-03-06 20:44:37 +00:00
parent 0d35d607fe
commit 9d2a30a226
7 changed files with 44 additions and 32 deletions


@@ -16,11 +16,16 @@ import logging
 import numpy as np
 import pandas as pd
-from model.transcrowd_gap import VisionTransformerGAP
+from network.transcrowd_gap import (
+    VisionTransformerGAP,
+    STNet_VisionTransformerGAP,
+    base_patch16_384_gap,
+    stn_patch16_384_gap
+)
 from arguments import args, ret_args
 import dataset
 from dataset import *
-from model.transcrowd_gap import *
+# from model.transcrowd_gap import *
 from checkpoint import save_checkpoint

 logger = logging.getLogger("train")
@@ -171,7 +176,7 @@ def worker(rank: int, args: Namespace):
         weight_decay=args.weight_decay
     )
     scheduler = torch.optim.lr_scheduler.MultiStepLR(
-        optimizer, milestones=[300], gamma=.1, last_epoch=-1
+        optimizer, milestones=[50], gamma=.1, last_epoch=-1
     )

     # Checkpointing
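MultiStepLR multiplies the learning rate by gamma at every milestone, so moving the milestone from epoch 300 to epoch 50 makes the 10x decay fire much earlier in training. A minimal standalone sketch of that behavior (the dummy parameter and SGD optimizer are placeholders, not this repo's setup):

import torch

param = torch.nn.Parameter(torch.zeros(1))  # dummy parameter to schedule
optimizer = torch.optim.SGD([param], lr=1e-3)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[50], gamma=.1, last_epoch=-1
)

for epoch in range(60):
    optimizer.step()
    scheduler.step()
    if epoch in (48, 49):
        # prints [0.001] at epoch 48 and [0.0001] at epoch 49:
        # the decay fires once 50 scheduler steps have elapsed
        print(epoch, scheduler.get_last_lr())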
@@ -267,39 +272,49 @@ def train_one_epoch(
         gt_count_whole = gt_count_whole.cuda()
         device_type = "cuda"

+        # Desperate measure to reduce mem footprint...
+        # with torch.autocast(device_type):
         # fpass
         out, gt_count = model(img, kpoint)

-        # loss
-        loss = criterion(out, gt_count) # wrt. transformer
-        if args.export_to_h5:
-            train_df.loc[epoch * i, "l1loss"] = float(loss.item())
-        else:
-            writer.add_scalar(
-                "L1-loss wrt. xformer (train)", loss, epoch * i
-            )
-
-        loss += (
-            F.mse_loss( # stn: info retainment
-                gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
-                gt_count_whole)
-            + F.threshold( # stn: perspective correction
-                gt_count.view(batch_size, -1).var(dim=1).mean(),
-                threshold=loss.item(),
-                value=loss.item()
-            )
-        )
-        if args.export_to_h5:
-            train_df.loc[epoch * i, "composite-loss"] = float(loss.item())
-        else:
-            writer.add_scalar("Composite loss (train)", loss, epoch * i)
+        # loss & bpass & etc.
+        if isinstance(model.module, STNet_VisionTransformerGAP):
+            loss_xformer = criterion(out, gt_count)
+            loss_stn = (
+                F.mse_loss(
+                    gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
+                    gt_count_whole)
+                + F.sigmoid(
+                    gt_count.view(batch_size, -1).var(dim=1).mean())
+            )
+            loss_stn.requires_grad = True
+            optimizer.zero_grad(set_to_none=True)
+            # Accum first for STN
+            loss_stn.backward(
+                inputs=list(model.module.stnet.parameters()), retain_graph=True
+            )
+            # Avoid double accum
+            for param in model.module.stnet.parameters():
+                param.grad = None
+            # Then, backward for entire net
+            loss_xformer.backward()
+            if args.export_to_h5:
+                train_df.loc[epoch * i, "l1loss"] = float(loss_xformer.item())
+                train_df.loc[epoch * i, "composite-loss"] = float(loss_stn.item())
+            else:
+                writer.add_scalar("l1loss", loss_xformer, epoch * i)
+                writer.add_scalar("composite-loss", loss_stn, epoch * i)
+        else:
+            loss = criterion(out, gt_count)
+            if args.export_to_h5:
+                train_df.loc[epoch * i, "l1loss"] = float(loss.item())
+            else:
+                writer.add_scalar("l1loss", loss, epoch * i)

         # free grad from mem
         optimizer.zero_grad(set_to_none=True)
         # bpass
         loss.backward()
         # optimizer
         optimizer.step()
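When the STN variant is active, the revamped step drives two objectives through one forward pass: the existing criterion on the transformer output (loss_xformer), and an STN term (loss_stn) built from an MSE between the summed patch counts and the whole-image count plus a sigmoid-squashed patch variance. loss_stn is backpropagated only into the STN parameters via backward(inputs=...), and retain_graph=True keeps the graph alive so loss_xformer can be backpropagated afterwards. A self-contained sketch of that selective-backward pattern (the stn/head modules and tensor shapes below are toy stand-ins, not this repo's classes):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-ins: a "spatial transformer" front end and a counting head.
stn = nn.Linear(8, 8)
head = nn.Linear(8, 4)                  # predicts 4 patch counts per image
optimizer = torch.optim.Adam(
    list(stn.parameters()) + list(head.parameters()), lr=1e-4
)

img = torch.randn(2, 8)                 # batch of 2 inputs
gt_patch = torch.rand(2, 4)             # per-patch ground-truth counts
gt_whole = gt_patch.sum(dim=1, keepdim=True)

pred = head(stn(img))

# Main objective on the patch predictions.
loss_main = F.l1_loss(pred, gt_patch)
# Auxiliary objective: summed patches should match the whole-image count;
# the mean patch variance is squashed through a sigmoid as a regularizer.
loss_aux = (
    F.mse_loss(pred.sum(dim=1, keepdim=True), gt_whole)
    + torch.sigmoid(pred.var(dim=1).mean())
)

optimizer.zero_grad(set_to_none=True)
# Accumulate the auxiliary loss into the STN parameters only;
# retain_graph=True keeps the graph alive for the second backward.
loss_aux.backward(inputs=list(stn.parameters()), retain_graph=True)
# Then backpropagate the main loss through the entire net.
loss_main.backward()
optimizer.step()

Note that the commit itself sets the STN grads back to None between the two backward calls ("Avoid double accum"), so only loss_xformer's gradients survive into optimizer.step(); the sketch keeps both contributions to show the default accumulation behavior.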
@@ -355,9 +370,6 @@ def valid_one_epoch(test_loader, model, device, epoch, args):
             mse += diff ** 2
             if i % 5 == 0:
-                # with torch.no_grad():
-                #     img_xformed = model.stnet(img).to("cpu")
-                #     xformed.append(img_xformed)
                 print("[valid_one_epoch] {} | Gt {:.2f} Pred {:.4f} |".format(
                     fname[0],
                     torch.sum(gt_count_whole).item(),