TEST: use tensorboard for stuff

This commit is contained in:
Zhengyi Chen 2024-03-04 20:32:26 +00:00
parent 2d31162c58
commit 524ee03187
10 changed files with 212 additions and 12 deletions

View file

@ -9,6 +9,8 @@ import torch
import torch.nn as nn
import torch.multiprocessing as torch_mp
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvision
import nni
import logging
import numpy as np
@ -17,10 +19,11 @@ from model.transcrowd_gap import VisionTransformerGAP
from arguments import args, ret_args
import dataset
from dataset import *
from model.transcrowd_gap import base_patch16_384_gap, stn_patch16_384_gap
from model.transcrowd_gap import *
from checkpoint import save_checkpoint
logger = logging.getLogger("train")
writer = SummaryWriter(args.save_path + "/tensorboard-run")
def setup_process_group(
rank: int,
@ -196,7 +199,7 @@ def worker(rank: int, args: Namespace):
# Validate
if epoch % 5 == 0 or args.debug:
prec1 = valid_one_epoch(test_loader, model, device, args)
prec1 = valid_one_epoch(test_loader, model, device, epoch, args)
end_valid = time.time()
is_best = prec1 < args.best_pred
args.best_pred = min(prec1, args.best_pred)
@ -255,11 +258,14 @@ def train_one_epoch(
gt_count_whole = gt_count_whole.cuda()
device_type = "cuda"
# Desperate measure to reduce mem footprint...
with torch.autocast(device_type):
# fpass
out, gt_count = model(img, kpoint)
# loss
loss = criterion(out, gt_count) # wrt. transformer
writer.add_scalar("L1-loss wrt. xformer (train)", loss, epoch * i)
loss += (
F.mse_loss( # stn: info retainment
gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
@ -270,6 +276,7 @@ def train_one_epoch(
value=loss.item()
)
)
writer.add_scalar("Composite loss (train)", loss, epoch * i)
# free grad from mem
optimizer.zero_grad(set_to_none=True)
@ -283,10 +290,13 @@ def train_one_epoch(
if args.debug:
break
# Flush writer
writer.flush()
scheduler.step()
def valid_one_epoch(test_loader, model, device, args):
def valid_one_epoch(test_loader, model, device, epoch, args):
print("[valid_one_epoch] Validating...")
batch_size = 1
model.eval()
@ -295,6 +305,7 @@ def valid_one_epoch(test_loader, model, device, args):
mse = .0
visi = []
index = 0
xformed = []
for i, (fname, img, kpoint, gt_count_whole) in enumerate(test_loader):
kpoint = kpoint.type(torch.FloatTensor)
@ -324,6 +335,10 @@ def valid_one_epoch(test_loader, model, device, args):
mse += diff ** 2
if i % 5 == 0:
if isinstance(model, STNet_VisionTransformerGAP):
with torch.no_grad():
img_xformed = model.stnet(img).to("cpu")
xformed.append(img_xformed)
print("[valid_one_epoch] {} | Gt {:.2f} Pred {:.4f} |".format(
fname[0],
torch.sum(gt_count_whole).item(),
@ -332,10 +347,16 @@ def valid_one_epoch(test_loader, model, device, args):
mae = mae * 1.0 / (len(test_loader) * batch_size)
mse = np.sqrt(mse / (len(test_loader)) * batch_size)
writer.add_scalar("MAE (valid)", mae, epoch)
writer.add_scalar("MSE (valid)", mse, epoch)
if len(xformed) != 0:
img_grid = torchvision.utils.make_grid(xformed)
writer.add_image("STN: transformed image", img_grid, epoch)
nni.report_intermediate_result(mae)
print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
mae=mae, mse=mse
))
writer.flush()
return mae
@ -353,7 +374,7 @@ if __name__ == "__main__":
worker,
args=(combined_params, ), # rank supplied at callee as 1st param
# also above *has* to be 1-tuple else runtime expands Namespace.
nprocs=combined_params.world_size,
nprocs=combined_params.ddp_world_size,
)
else:
# No DDP, run in current thread