Loss revamp & Renamed model to network
This commit is contained in:
parent
0d35d607fe
commit
9d2a30a226
7 changed files with 44 additions and 32 deletions
76
train.py
76
train.py
|
|
@ -16,11 +16,16 @@ import logging
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
from model.transcrowd_gap import VisionTransformerGAP
|
from network.transcrowd_gap import (
|
||||||
|
VisionTransformerGAP,
|
||||||
|
STNet_VisionTransformerGAP,
|
||||||
|
base_patch16_384_gap,
|
||||||
|
stn_patch16_384_gap
|
||||||
|
)
|
||||||
from arguments import args, ret_args
|
from arguments import args, ret_args
|
||||||
import dataset
|
import dataset
|
||||||
from dataset import *
|
from dataset import *
|
||||||
from model.transcrowd_gap import *
|
# from model.transcrowd_gap import *
|
||||||
from checkpoint import save_checkpoint
|
from checkpoint import save_checkpoint
|
||||||
|
|
||||||
logger = logging.getLogger("train")
|
logger = logging.getLogger("train")
|
||||||
|
|
@ -171,7 +176,7 @@ def worker(rank: int, args: Namespace):
|
||||||
weight_decay=args.weight_decay
|
weight_decay=args.weight_decay
|
||||||
)
|
)
|
||||||
scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
||||||
optimizer, milestones=[300], gamma=.1, last_epoch=-1
|
optimizer, milestones=[50], gamma=.1, last_epoch=-1
|
||||||
)
|
)
|
||||||
|
|
||||||
# Checkpointing
|
# Checkpointing
|
||||||
|
|
@ -267,39 +272,49 @@ def train_one_epoch(
|
||||||
gt_count_whole = gt_count_whole.cuda()
|
gt_count_whole = gt_count_whole.cuda()
|
||||||
device_type = "cuda"
|
device_type = "cuda"
|
||||||
|
|
||||||
# Desperate measure to reduce mem footprint...
|
|
||||||
# with torch.autocast(device_type):
|
|
||||||
# fpass
|
# fpass
|
||||||
out, gt_count = model(img, kpoint)
|
out, gt_count = model(img, kpoint)
|
||||||
# loss
|
|
||||||
loss = criterion(out, gt_count) # wrt. transformer
|
# loss & bpass & etc.
|
||||||
if args.export_to_h5:
|
if isinstance(model.module, STNet_VisionTransformerGAP):
|
||||||
train_df.loc[epoch * i, "l1loss"] = float(loss.item())
|
loss_xformer = criterion(out, gt_count)
|
||||||
else:
|
loss_stn = (
|
||||||
writer.add_scalar(
|
F.mse_loss(
|
||||||
"L1-loss wrt. xformer (train)", loss, epoch * i
|
gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
|
||||||
|
gt_count_whole)
|
||||||
|
+ F.sigmoid(
|
||||||
|
gt_count.view(batch_size, -1).var(dim=1).mean())
|
||||||
)
|
)
|
||||||
|
loss_stn.requires_grad = True
|
||||||
loss += (
|
optimizer.zero_grad(set_to_none=True)
|
||||||
F.mse_loss( # stn: info retainment
|
# Accum first for STN
|
||||||
gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
|
loss_stn.backward(
|
||||||
gt_count_whole)
|
inputs=list(model.module.stnet.parameters()), retain_graph=True
|
||||||
+ F.threshold( # stn: perspective correction
|
|
||||||
gt_count.view(batch_size, -1).var(dim=1).mean(),
|
|
||||||
threshold=loss.item(),
|
|
||||||
value=loss.item()
|
|
||||||
)
|
)
|
||||||
)
|
# Avoid double accum
|
||||||
if args.export_to_h5:
|
for param in model.module.stnet.parameters():
|
||||||
train_df.loc[epoch * i, "composite-loss"] = float(loss.item())
|
param.grad = None
|
||||||
|
# Then, backward for entire net
|
||||||
|
loss_xformer.backward()
|
||||||
|
|
||||||
|
if args.export_to_h5:
|
||||||
|
train_df.loc[epoch * i, "l1loss"] = float(loss_xformer.item())
|
||||||
|
train_df.loc[epoch * i, "composite-loss"] = float(loss_stn.item())
|
||||||
|
else:
|
||||||
|
writer.add_scalar("l1loss", loss_xformer, epoch * i)
|
||||||
|
writer.add_scalar("composite-loss", loss_stn, epoch * i)
|
||||||
else:
|
else:
|
||||||
writer.add_scalar("Composite loss (train)", loss, epoch * i)
|
loss = criterion(out, gt_count)
|
||||||
|
if args.export_to_h5:
|
||||||
|
train_df.loc[epoch * i, "l1loss"] = float(loss.item())
|
||||||
|
else:
|
||||||
|
writer.add_scalar("l1loss", loss, epoch * i)
|
||||||
|
|
||||||
# free grad from mem
|
# free grad from mem
|
||||||
optimizer.zero_grad(set_to_none=True)
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
# bpass
|
# bpass
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
# optimizer
|
# optimizer
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
@ -355,9 +370,6 @@ def valid_one_epoch(test_loader, model, device, epoch, args):
|
||||||
mse += diff ** 2
|
mse += diff ** 2
|
||||||
|
|
||||||
if i % 5 == 0:
|
if i % 5 == 0:
|
||||||
# with torch.no_grad():
|
|
||||||
# img_xformed = model.stnet(img).to("cpu")
|
|
||||||
# xformed.append(img_xformed)
|
|
||||||
print("[valid_one_epoch] {} | Gt {:.2f} Pred {:.4f} |".format(
|
print("[valid_one_epoch] {} | Gt {:.2f} Pred {:.4f} |".format(
|
||||||
fname[0],
|
fname[0],
|
||||||
torch.sum(gt_count_whole).item(),
|
torch.sum(gt_count_whole).item(),
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue