Revamp loss & rename model package to network

Zhengyi Chen 2024-03-06 20:44:37 +00:00
parent 0d35d607fe
commit 9d2a30a226
7 changed files with 44 additions and 32 deletions

@@ -16,11 +16,16 @@ import logging
 import numpy as np
 import pandas as pd
-from model.transcrowd_gap import VisionTransformerGAP
+from network.transcrowd_gap import (
+    VisionTransformerGAP,
+    STNet_VisionTransformerGAP,
+    base_patch16_384_gap,
+    stn_patch16_384_gap
+)
 from arguments import args, ret_args
 import dataset
 from dataset import *
-from model.transcrowd_gap import *
+# from model.transcrowd_gap import *
 from checkpoint import save_checkpoint
 logger = logging.getLogger("train")
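
The star import from the old model package is replaced by explicit names from the renamed network package, including two constructor helpers. A minimal sketch of how they would presumably be consumed; only the helper names appear in this diff, so the no-argument calls are an assumption:

    from network.transcrowd_gap import base_patch16_384_gap, stn_patch16_384_gap

    # Hypothetical usage; exact signatures are not shown in this commit.
    plain_net = base_patch16_384_gap()  # -> VisionTransformerGAP
    stn_net = stn_patch16_384_gap()     # -> STNet_VisionTransformerGAP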
@@ -171,7 +176,7 @@ def worker(rank: int, args: Namespace):
         weight_decay=args.weight_decay
     )
     scheduler = torch.optim.lr_scheduler.MultiStepLR(
-        optimizer, milestones=[300], gamma=.1, last_epoch=-1
+        optimizer, milestones=[50], gamma=.1, last_epoch=-1
     )
     # Checkpointing
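
With the milestone moved from 300 to 50, the scheduler now applies its 10x learning-rate cut at epoch 50. A minimal sketch of the resulting schedule; the base lr of 1e-4 is an assumed placeholder, not taken from this diff:

    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.SGD(params, lr=1e-4)  # assumed base lr
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[50], gamma=.1, last_epoch=-1
    )

    for epoch in range(60):
        optimizer.step()   # dummy update so the scheduler warning stays quiet
        scheduler.step()

    print(optimizer.param_groups[0]["lr"])  # ~1e-05 after the epoch-50 cut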
@@ -267,39 +272,49 @@ def train_one_epoch(
         gt_count_whole = gt_count_whole.cuda()
         device_type = "cuda"
         # Desperate measure to reduce mem footprint...
         # with torch.autocast(device_type):
         # fpass
         out, gt_count = model(img, kpoint)
-        # loss
-        loss = criterion(out, gt_count)  # wrt. transformer
-        if args.export_to_h5:
-            train_df.loc[epoch * i, "l1loss"] = float(loss.item())
-        else:
-            writer.add_scalar(
-                "L1-loss wrt. xformer (train)", loss, epoch * i
-            )
-        loss += (
-            F.mse_loss(  # stn: info retainment
-                gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
-                gt_count_whole)
-            + F.threshold(  # stn: perspective correction
-                gt_count.view(batch_size, -1).var(dim=1).mean(),
-                threshold=loss.item(),
-                value=loss.item()
-            )
-        )
-        if args.export_to_h5:
-            train_df.loc[epoch * i, "composite-loss"] = float(loss.item())
-        else:
-            writer.add_scalar("Composite loss (train)", loss, epoch * i)
-        # free grad from mem
-        optimizer.zero_grad(set_to_none=True)
-        # bpass
-        loss.backward()
+        # loss & bpass & etc.
+        if isinstance(model.module, STNet_VisionTransformerGAP):
+            loss_xformer = criterion(out, gt_count)
+            loss_stn = (
+                F.mse_loss(
+                    gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
+                    gt_count_whole)
+                + F.sigmoid(
+                    gt_count.view(batch_size, -1).var(dim=1).mean())
+            )
+            loss_stn.requires_grad = True
+            optimizer.zero_grad(set_to_none=True)
+            # Accum first for STN
+            loss_stn.backward(
+                inputs=list(model.module.stnet.parameters()), retain_graph=True
+            )
+            # Avoid double accum
+            for param in model.module.stnet.parameters():
+                param.grad = None
+            # Then, backward for entire net
+            loss_xformer.backward()
+            if args.export_to_h5:
+                train_df.loc[epoch * i, "l1loss"] = float(loss_xformer.item())
+                train_df.loc[epoch * i, "composite-loss"] = float(loss_stn.item())
+            else:
+                writer.add_scalar("l1loss", loss_xformer, epoch * i)
+                writer.add_scalar("composite-loss", loss_stn, epoch * i)
+        else:
+            loss = criterion(out, gt_count)
+            if args.export_to_h5:
+                train_df.loc[epoch * i, "l1loss"] = float(loss.item())
+            else:
+                writer.add_scalar("l1loss", loss, epoch * i)
+            # free grad from mem
+            optimizer.zero_grad(set_to_none=True)
+            # bpass
+            loss.backward()
         # optimizer
         optimizer.step()
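
The STN branch stages its backward passes: Tensor.backward(inputs=...) accumulates gradients only into the listed tensors, and retain_graph=True keeps the graph alive for the second pass. A minimal self-contained sketch of that pattern, with toy modules standing in for the STN and the transformer:

    import torch
    import torch.nn as nn

    stn = nn.Linear(4, 4)         # stand-in for model.module.stnet
    xformer = nn.Linear(4, 1)     # stand-in for the ViT backbone

    x = torch.randn(8, 4)
    out = xformer(stn(x))
    aux_loss = out.var()          # plays the role of loss_stn
    main_loss = out.abs().mean()  # plays the role of loss_xformer

    # Gradients land only in stn's parameters; xformer's stay None.
    aux_loss.backward(inputs=list(stn.parameters()), retain_graph=True)
    assert next(xformer.parameters()).grad is None

    # Clearing stn's grads mirrors the "Avoid double accum" step, then the
    # main loss back-propagates through the whole graph.
    for p in stn.parameters():
        p.grad = None
    main_loss.backward()

With inputs= the gradients for every other leaf are simply not written, which is what lets the two losses target different parts of the network in one forward pass.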
@@ -355,9 +370,6 @@ def valid_one_epoch(test_loader, model, device, epoch, args):
         mse += diff ** 2
         if i % 5 == 0:
-            # with torch.no_grad():
-            #     img_xformed = model.stnet(img).to("cpu")
-            #     xformed.append(img_xformed)
             print("[valid_one_epoch] {} | Gt {:.2f} Pred {:.4f} |".format(
                 fname[0],
                 torch.sum(gt_count_whole).item(),