TEST: use tensorboard for stuff
parent 2d31162c58
commit 524ee03187

10 changed files with 212 additions and 12 deletions

29  train.py
@@ -9,6 +9,8 @@ import torch
 import torch.nn as nn
 import torch.multiprocessing as torch_mp
 from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+import torchvision
 import nni
 import logging
 import numpy as np
@@ -17,10 +19,11 @@ from model.transcrowd_gap import VisionTransformerGAP
 from arguments import args, ret_args
 import dataset
 from dataset import *
-from model.transcrowd_gap import base_patch16_384_gap, stn_patch16_384_gap
+from model.transcrowd_gap import *
 from checkpoint import save_checkpoint

 logger = logging.getLogger("train")
+writer = SummaryWriter(args.save_path + "/tensorboard-run")

 def setup_process_group(
     rank: int,
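Note: the hunk above creates a module-level SummaryWriter pointed at args.save_path + "/tensorboard-run". A minimal, standalone sketch of that SummaryWriter lifecycle (the log directory and metric below are placeholders, not this repo's code):

from torch.utils.tensorboard import SummaryWriter

# Placeholder log dir; the repo uses args.save_path + "/tensorboard-run".
writer = SummaryWriter(log_dir="runs/example")

for epoch in range(3):
    fake_mae = 10.0 / (epoch + 1)                       # stand-in for a real metric
    writer.add_scalar("MAE (valid)", fake_mae, epoch)   # (tag, scalar_value, global_step)

writer.flush()   # push buffered events to disk so `tensorboard --logdir runs` can read them
writer.close()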
@@ -196,7 +199,7 @@ def worker(rank: int, args: Namespace):

         # Validate
         if epoch % 5 == 0 or args.debug:
-            prec1 = valid_one_epoch(test_loader, model, device, args)
+            prec1 = valid_one_epoch(test_loader, model, device, epoch, args)
             end_valid = time.time()
             is_best = prec1 < args.best_pred
             args.best_pred = min(prec1, args.best_pred)
@@ -255,11 +258,14 @@ def train_one_epoch(
         gt_count_whole = gt_count_whole.cuda()
         device_type = "cuda"

         # Desperate measure to reduce mem footprint...
         with torch.autocast(device_type):
             # fpass
             out, gt_count = model(img, kpoint)
             # loss
             loss = criterion(out, gt_count)  # wrt. transformer
+            writer.add_scalar("L1-loss wrt. xformer (train)", loss, epoch * i)
+
             loss += (
                 F.mse_loss(  # stn: info retainment
                     gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
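Note: the hunk above wraps the forward pass in torch.autocast to cut activation memory and logs the L1 loss each step. A self-contained sketch of that pattern, with a toy model and random data standing in for the repo's transformer and loaders:

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

device_type = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_type)

writer = SummaryWriter("runs/autocast-example")   # placeholder log dir
model = nn.Linear(16, 1).to(device)               # toy stand-in for the ViT model
criterion = nn.L1Loss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Random batches standing in for the crowd-counting loaders.
batches = [(torch.randn(4, 16), torch.randn(4, 1)) for _ in range(10)]

for step, (x, y) in enumerate(batches):
    x, y = x.to(device), y.to(device)
    # Lower-precision forward pass to reduce the activation memory footprint.
    with torch.autocast(device_type):
        out = model(x)
        loss = criterion(out, y)
    writer.add_scalar("L1-loss (train)", loss.item(), step)  # one point per optimizer step
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    # Real fp16 training on CUDA usually adds torch.cuda.amp.GradScaler; omitted for brevity.

writer.flush()
writer.close()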
@@ -270,6 +276,7 @@ def train_one_epoch(
                     value=loss.item()
                 )
             )
+            writer.add_scalar("Composite loss (train)", loss, epoch * i)

             # free grad from mem
             optimizer.zero_grad(set_to_none=True)
@@ -283,10 +290,13 @@ def train_one_epoch(
         if args.debug:
             break

+    # Flush writer
+    writer.flush()
+
     scheduler.step()


-def valid_one_epoch(test_loader, model, device, args):
+def valid_one_epoch(test_loader, model, device, epoch, args):
     print("[valid_one_epoch] Validating...")
     batch_size = 1
     model.eval()
@@ -295,6 +305,7 @@ def valid_one_epoch(test_loader, model, device, args):
     mse = .0
     visi = []
     index = 0
+    xformed = []

     for i, (fname, img, kpoint, gt_count_whole) in enumerate(test_loader):
         kpoint = kpoint.type(torch.FloatTensor)
@@ -324,6 +335,10 @@ def valid_one_epoch(test_loader, model, device, args):
         mse += diff ** 2

+        if i % 5 == 0:
+            if isinstance(model, STNet_VisionTransformerGAP):
+                with torch.no_grad():
+                    img_xformed = model.stnet(img).to("cpu")
+                xformed.append(img_xformed)
         print("[valid_one_epoch] {} | Gt {:.2f} Pred {:.4f} |".format(
             fname[0],
             torch.sum(gt_count_whole).item(),
@@ -332,10 +347,16 @@ def valid_one_epoch(test_loader, model, device, args):

     mae = mae * 1.0 / (len(test_loader) * batch_size)
     mse = np.sqrt(mse / (len(test_loader)) * batch_size)
+    writer.add_scalar("MAE (valid)", mae, epoch)
+    writer.add_scalar("MSE (valid)", mse, epoch)
+    if len(xformed) != 0:
+        img_grid = torchvision.utils.make_grid(xformed)
+        writer.add_image("STN: transformed image", img_grid, epoch)
     nni.report_intermediate_result(mae)
     print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
         mae=mae, mse=mse
     ))
+    writer.flush()
     return mae


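Note: the hunk above tiles the collected STN outputs with torchvision.utils.make_grid and writes the result via writer.add_image. A standalone sketch of that pair of calls, with random images standing in for the transformed crops:

import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("runs/image-grid-example")    # placeholder log dir

# Eight random RGB images standing in for the STN-transformed samples collected above.
images = torch.rand(8, 3, 64, 64)
grid = torchvision.utils.make_grid(images, nrow=4)   # tiles the batch into one (C, H, W) image
writer.add_image("STN: transformed image", grid, global_step=0)

writer.flush()
writer.close()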
@@ -353,7 +374,7 @@ if __name__ == "__main__":
             worker,
             args=(combined_params, ),  # rank supplied at callee as 1st param
             # also above *has* to be 1-tuple else runtime expands Namespace.
-            nprocs=combined_params.world_size,
+            nprocs=combined_params.ddp_world_size,
         )
     else:
         # No DDP, run in current thread
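Note on the args=(combined_params, ) comment above: torch.multiprocessing.spawn calls fn(rank, *args), so the extra arguments must stay wrapped in a tuple. A minimal sketch of that call, with a plain dict standing in for the repo's Namespace:

import torch.multiprocessing as torch_mp

def worker(rank, cfg):
    # spawn injects rank as the first positional argument;
    # everything in the args tuple below follows it.
    print(f"[rank {rank}] cfg = {cfg!r}")

if __name__ == "__main__":
    cfg = {"ddp_world_size": 2}      # plain dict standing in for the repo's Namespace
    torch_mp.spawn(
        worker,
        args=(cfg,),                 # kept as a 1-tuple, per the comment in the diff above
        nprocs=cfg["ddp_world_size"],
    )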