# mlp-project/train.py
import os
import random
import time
from typing import Optional
from argparse import Namespace
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as torch_mp
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torchvision import transforms
import nni
import logging
import numpy as np
import pandas as pd
from network.transcrowd_gap import (
VisionTransformerGAP,
STNet_VisionTransformerGAP,
base_patch16_384_gap,
stn_patch16_384_gap
)
from arguments import args, ret_args
import dataset
from dataset import *
# from model.transcrowd_gap import *
from checkpoint import save_checkpoint
logger = logging.getLogger("train")
if not args.export_to_h5:
writer = SummaryWriter(args.save_path + "/tensorboard-run")
else:
train_df = pd.DataFrame(columns=["l1loss", "composite-loss"], dtype=float)
train_stat_file = args.save_path + "/train_stats.h5"
test_df = pd.DataFrame(columns=["mse", "mae"], dtype=float)
test_stat_file = args.save_path + "/test_stats.h5"
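# DDP rendezvous: every spawned rank sets the same MASTER_ADDR/MASTER_PORT and
# then blocks in init_process_group() until all world_size ranks have joined.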
def setup_process_group(
rank: int,
world_size: int,
master_addr: str = "localhost",
    master_port: Optional[int] = None
):
os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = (
        str(random.randint(40000, 65535))
if master_port is None
else str(master_port)
)
# join point!
torch.distributed.init_process_group(
rank=rank, world_size=world_size
)
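# Data loaders: ListDataset (imported from dataset) yields
# (fname, img, kpoint, gt_count) tuples; under DDP a DistributedSampler shards
# the dataset across ranks so each process sees a distinct subset per epoch.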
def build_train_loader(data_keys, args):
train_dataset = ListDataset(
data_keys,
shuffle = True,
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[.485, .456, .406], std=[.229, .224, .225])
]),
train = True,
batch_size = args.batch_size,
nr_workers = args.workers,
args = args
)
if args.use_ddp:
train_dist_sampler = torch.utils.data.distributed.DistributedSampler(
dataset=train_dataset
)
else:
train_dist_sampler = None
train_loader = DataLoader(
dataset=train_dataset,
sampler=train_dist_sampler,
batch_size=args.batch_size,
drop_last=False
)
return train_loader
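# The test loader mirrors the train loader but keeps the original ordering
# (shuffle=False) and evaluates one image per step (batch_size=1).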
def build_test_loader(data_keys, args):
test_dataset = ListDataset(
data_keys,
shuffle=False,
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[.485, .456, .406], std=[.229, .224, .225])
]),
args=args,
train=False
)
if args.use_ddp:
test_dist_sampler = torch.utils.data.distributed.DistributedSampler(
dataset=test_dataset
)
else:
test_dist_sampler = None
test_loader = DataLoader(
dataset=test_dataset,
sampler=test_dist_sampler,
batch_size=1
)
return test_loader
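# Per-process worker: rank is 0 for single-process runs, or the process index
# assigned by torch.multiprocessing.spawn when args.use_ddp is set. Handles
# device selection, data/model/optimizer setup, the epoch loop, and checkpointing.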
def worker(rank: int, args: Namespace):
world_size = args.ddp_world_size
    # (If DDP) join the process group shared by all worker processes
if args.use_ddp:
setup_process_group(rank, world_size)
# Setup device for one proc
if args.use_ddp and torch.cuda.is_available():
device = torch.device(rank)
elif torch.cuda.is_available():
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
args.gpus = [int(gpu_id) for gpu_id in args.gpus.split(",")]
device = None
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
print("[!!!] Using CPU for inference. This will be slow...")
device = torch.device("cpu")
if device is not None:
torch.set_default_device(device)
# Prepare training data
train_list, test_list = unpack_npy_data(args)
train_data = convert_data(train_list, args, train=True)
test_data = convert_data(test_list, args, train=False)
train_loader = build_train_loader(train_data, args)
test_loader = build_test_loader(test_data, args)
# Instantiate model
if args.model == "stn":
model = stn_patch16_384_gap(args.pth_tar)
else:
model = base_patch16_384_gap(args.pth_tar)
if args.use_ddp:
model = nn.parallel.DistributedDataParallel(
model,
device_ids=[rank], output_device=rank,
find_unused_parameters=True,
gradient_as_bucket_view=True # XXX: vital, otherwise OOM
)
else:
model = nn.DataParallel(
model,
device_ids=args.gpus
)
if device is not None:
model = model.to(device)
elif torch.cuda.is_available():
model = model.cuda()
# criterion, optimizer, scheduler
criterion = nn.L1Loss(reduction="sum")
if device is not None:
criterion = criterion.to(device)
elif torch.cuda.is_available():
criterion = criterion.cuda()
optimizer = torch.optim.Adam(
[{"params": model.parameters(), "lr": args.lr}],
lr=args.lr,
weight_decay=args.weight_decay
)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer, milestones=[50], gamma=.1, last_epoch=-1
)
# Checkpointing
if (not args.use_ddp) or torch.distributed.get_rank() == 0:
print("[worker-0] Saving to \'{}\'".format(args.save_path))
if not os.path.exists(args.save_path):
os.makedirs(args.save_path)
if args.resume:
if os.path.isfile(args.resume):
print("=> Loading checkpoint \'{}\'".format(args.resume))
checkpoint = torch.load(args.resume)
model.load_state_dict(checkpoint['state_dict'], strict=False)
args.start_epoch = checkpoint['epoch']
args.best_pred = checkpoint['best_prec1'] # sic.
else:
print("=> Checkpoint not found!")
print("[worker-{}] Starting @ epoch {}: pred: {}".format(
rank, args.start_epoch, args.best_pred
))
# For each epoch:
for epoch in range(args.start_epoch, args.epochs):
# Tell sampler which epoch it is
if args.use_ddp:
train_loader.sampler.set_epoch(epoch)
test_loader.sampler.set_epoch(epoch)
# Train
start = time.time()
train_one_epoch(train_loader, model, criterion, optimizer, scheduler, epoch, device, args)
end_train = time.time()
# Validate
if epoch % 5 == 0 or args.debug:
prec1 = valid_one_epoch(test_loader, model, device, epoch, args)
end_valid = time.time()
is_best = prec1 < args.best_pred
args.best_pred = min(prec1, args.best_pred)
print("* best MAE {mae:.3f} *".format(mae=args.best_pred))
# Save checkpoint
if not args.use_ddp or torch.distributed.get_rank() == 0:
save_checkpoint({
"epoch": epoch + 1,
"arch": args.resume,
"state_dict": model.state_dict(),
"best_prec1": args.best_pred,
"optimizer": optimizer.state_dict(),
}, is_best, args.save_path)
# cleanup
if args.use_ddp:
torch.distributed.destroy_process_group()
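# One training epoch. The plain ViT model is trained with a summed L1 loss
# between the predicted and ground-truth patch counts. The STN variant builds
# an additional composite loss on the ground-truth patch counts returned by
# the model: an MSE between their sum and the whole-image count, plus their
# mean index of dispersion (variance / mean). That loss is backpropagated into
# the STN parameters with the graph retained before the L1 loss is
# backpropagated through the entire network.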
def train_one_epoch(
train_loader: DataLoader,
model: VisionTransformerGAP,
criterion,
optimizer,
scheduler,
epoch: int,
device,
args: Namespace
):
# Get learning rate
curr_lr = optimizer.param_groups[0]["lr"]
print("Epoch %d, processed %d samples, lr %.10f" %
(epoch, epoch * len(train_loader.dataset), curr_lr)
)
# Set to train mode
model.train()
# In one epoch, for each training sample
for i, (fname, img, kpoint, gt_count_whole) in enumerate(train_loader):
kpoint = kpoint.type(torch.FloatTensor)
gt_count_whole = gt_count_whole.type(torch.FloatTensor).unsqueeze(1)
batch_size = img.size(0)
# send to device
if device is not None:
img = img.to(device)
kpoint = kpoint.to(device)
gt_count_whole = gt_count_whole.to(device)
device_type = device.type
elif torch.cuda.is_available():
img = img.cuda()
kpoint = kpoint.cuda()
gt_count_whole = gt_count_whole.cuda()
device_type = "cuda"
# fpass
out, gt_count = model(img, kpoint)
# loss & bpass & etc.
if isinstance(model.module, STNet_VisionTransformerGAP):
loss_xformer = criterion(out, gt_count)
loss_stn = (
F.mse_loss(
gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
gt_count_whole)
+ ( # mean index of dispersion
gt_count.view(batch_size, -1).var(dim=1)
/ gt_count.view(batch_size, -1).mean(dim=1)
).mean()
)
loss_stn.requires_grad = True
optimizer.zero_grad(set_to_none=True)
# Accum first for STN
loss_stn.backward(
inputs=list(model.module.stnet.parameters()), retain_graph=True
)
# Avoid double accum
for param in model.module.stnet.parameters():
param.grad = None
# Then, backward for entire net
loss_xformer.backward()
            if args.export_to_h5:
                train_df.loc[epoch * len(train_loader) + i, "l1loss"] = float(loss_xformer.item())
                train_df.loc[epoch * len(train_loader) + i, "composite-loss"] = float(loss_stn.item())
            else:
                writer.add_scalar("l1loss", loss_xformer, epoch * len(train_loader) + i)
                writer.add_scalar("composite-loss", loss_stn, epoch * len(train_loader) + i)
else:
loss = criterion(out, gt_count)
            if args.export_to_h5:
                train_df.loc[epoch * len(train_loader) + i, "l1loss"] = float(loss.item())
            else:
                writer.add_scalar("l1loss", loss, epoch * len(train_loader) + i)
# free grad from mem
optimizer.zero_grad(set_to_none=True)
# bpass
loss.backward()
# optimizer
optimizer.step()
if args.debug:
break
# Flush writer
if args.export_to_h5:
train_df.to_hdf(train_stat_file, key="df", mode="a", append=True)
else:
writer.flush()
scheduler.step()
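# Validation: accumulates the per-image absolute error between the summed
# predicted count and the whole-image ground truth, then reports MAE and the
# root of the mean squared error over the test set, and pushes the MAE to NNI
# as an intermediate result.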
def valid_one_epoch(test_loader, model, device, epoch, args):
print("[valid_one_epoch] Validating...")
batch_size = 1
model.eval()
mae = .0
mse = .0
visi = []
index = 0
xformed = []
for i, (fname, img, kpoint, gt_count_whole) in enumerate(test_loader):
kpoint = kpoint.type(torch.FloatTensor)
gt_count_whole = gt_count_whole.type(torch.FloatTensor)
if device is not None:
img = img.to(device)
kpoint = kpoint.to(device)
gt_count_whole = gt_count_whole.to(device)
elif torch.cuda.is_available():
img = img.cuda()
kpoint = kpoint.cuda()
gt_count_whole = gt_count_whole.cuda()
        # XXX: does this ever happen?
if len(img.shape) == 5:
img = img.squeeze(0)
if len(img.shape) == 3:
img = img.unsqueeze(0)
with torch.no_grad():
out, gt_count = model(img, kpoint)
pred_count = torch.squeeze(out, 1)
gt_count = torch.squeeze(gt_count, 1)
diff = torch.abs(gt_count_whole - torch.sum(pred_count)).item()
mae += diff
mse += diff ** 2
if i % 5 == 0:
print("[valid_one_epoch] {} | Gt {:.2f} Pred {:.4f} |".format(
fname[0],
torch.sum(gt_count_whole).item(),
torch.sum(pred_count).item()
))
if args.debug:
break
    mae = mae * 1.0 / (len(test_loader) * batch_size)
    mse = np.sqrt(mse / (len(test_loader) * batch_size))
if args.export_to_h5:
test_df.loc[epoch, "mae"] = mae
test_df.loc[epoch, "mse"] = mse
test_df.to_hdf(test_stat_file, key="df", mode="a", append=True)
else:
writer.add_scalar("MAE (valid)", mae, epoch)
writer.add_scalar("MSE (valid)", mse, epoch)
if len(xformed) != 0 and not args.export_to_h5:
img_grid = torchvision.utils.make_grid(xformed)
writer.add_image("STN: transformed image", img_grid, epoch)
if not args.export_to_h5:
writer.flush()
nni.report_intermediate_result(mae)
print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
mae=mae, mse=mse
))
return mae
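# Entry point: pulls the next hyperparameter set from the NNI tuner, merges it
# into the parsed CLI arguments, then either spawns one worker per rank (DDP)
# or runs a single worker in the current process.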
if __name__ == "__main__":
tuner_params = nni.get_next_parameter()
logger.debug("Generated hyperparameters: {}", tuner_params)
combined_params = nni.utils.merge_parameter(ret_args, tuner_params)
if args.debug:
os.nice(15)
if combined_params.use_ddp:
        # Use DDP: spawn one worker process per rank
torch_mp.spawn(
worker,
args=(combined_params, ), # rank supplied at callee as 1st param
            # NOTE: args must stay a 1-tuple, otherwise spawn would unpack the Namespace.
nprocs=combined_params.ddp_world_size,
)
else:
        # No DDP: run the worker in the current process
worker(0, combined_params)
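# Rough usage sketch (flag spellings are assumed from the attribute names used
# above; the actual definitions live in the arguments module):
#   python train.py --save_path runs/exp1 --batch_size 8 --epochs 100
#   python train.py --use_ddp --ddp_world_size 4 ...   # multi-GPU via DDP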