# mlp-project/train.py

import os
import random
import time
import logging
from typing import Optional
from argparse import Namespace
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F  # needed for F.mse_loss / F.threshold in the composite loss
import torch.multiprocessing as torch_mp
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torchvision import transforms  # Compose/ToTensor/Normalize used by the loader builders
import nni
import nni.utils  # nni.utils.merge_parameter is used in __main__
import numpy as np
import pandas as pd
from model.transcrowd_gap import VisionTransformerGAP
from arguments import args, ret_args
import dataset
from dataset import *
from model.transcrowd_gap import *
from checkpoint import save_checkpoint
logger = logging.getLogger("train")
if not args.export_to_h5:
writer = SummaryWriter(args.save_path + "/tensorboard-run")
else:
train_df = pd.DataFrame(columns=["l1loss", "composite-loss"], dtype=float)
train_stat_file = args.save_path + "/train_stats.h5"
test_df = pd.DataFrame(columns=["mse", "mae"], dtype=float)
test_stat_file = args.save_path + "/test_stats.h5"
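# Statistics go either to TensorBoard (default) or, with export_to_h5 set, to two
# HDF5 files written via pandas. A minimal read-back sketch for the latter,
# assuming the save_path layout above:
#
#   df = pd.read_hdf(args.save_path + "/train_stats.h5", key="df")
#   print(df[["l1loss", "composite-loss"]].describe())
#
# The TensorBoard run can be inspected with:
#   tensorboard --logdir <save_path>/tensorboard-run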
def setup_process_group(
    rank: int,
    world_size: int,
    master_addr: str = "localhost",
    master_port: Optional[int] = None
):
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = (
        # Valid TCP ports end at 65535.
        str(random.randint(40000, 65535))
        if master_port is None
        else str(master_port)
    )
    # Rendezvous point: every rank blocks here until the full group has joined.
torch.distributed.init_process_group(
rank=rank, world_size=world_size
)
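# setup_process_group() is executed once per spawned worker (see __main__);
# MASTER_ADDR / MASTER_PORT are the standard torch.distributed rendezvous
# variables. With backend left unspecified, recent PyTorch versions choose it
# automatically (NCCL for CUDA, Gloo otherwise). A single-node sketch, assuming
# two ranks:
#
#   torch_mp.spawn(worker, args=(combined_params,), nprocs=2)
#   # -> each copy ends up calling setup_process_group(rank, world_size=2)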
def build_train_loader(data_keys, args):
    train_dataset = ListDataset(
        data_keys,
        shuffle=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]),
        train=True,
        batch_size=args.batch_size,
        nr_workers=args.workers,
        args=args
    )
if args.use_ddp:
train_dist_sampler = torch.utils.data.distributed.DistributedSampler(
dataset=train_dataset
)
else:
train_dist_sampler = None
train_loader = DataLoader(
dataset=train_dataset,
sampler=train_dist_sampler,
batch_size=args.batch_size,
drop_last=False
)
return train_loader
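# Under DDP the DistributedSampler partitions the training set so every rank sees
# a distinct shard; worker() calls sampler.set_epoch(epoch) each epoch so the
# shard assignment is reshuffled between epochs.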
def build_test_loader(data_keys, args):
    test_dataset = ListDataset(
        data_keys,
        shuffle=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]),
        args=args,
        train=False
    )
if args.use_ddp:
test_dist_sampler = torch.utils.data.distributed.DistributedSampler(
dataset=test_dataset
)
else:
test_dist_sampler = None
test_loader = DataLoader(
dataset=test_dataset,
sampler=test_dist_sampler,
batch_size=1
)
return test_loader
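# Note that with DDP the test set is sharded the same way, so each rank's
# valid_one_epoch() computes MAE/MSE over its own shard only; the metrics are not
# all-reduced across ranks anywhere in this script.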
def worker(rank: int, args: Namespace):
world_size = args.ddp_world_size
    # (If DDP) join the process group together with the other workers
if args.use_ddp:
setup_process_group(rank, world_size)
# Setup device for one proc
if args.use_ddp and torch.cuda.is_available():
device = torch.device(rank)
    elif torch.cuda.is_available():
        # NOTE: CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is
        # first initialised; args.gpus is reused below as DataParallel's device_ids.
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        args.gpus = [int(gpu_id) for gpu_id in args.gpus.split(",")]
        device = None
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
print("[!!!] Using CPU for inference. This will be slow...")
device = torch.device("cpu")
if device is not None:
torch.set_default_device(device)
# Prepare training data
train_list, test_list = unpack_npy_data(args)
train_data = convert_data(train_list, args, train=True)
test_data = convert_data(test_list, args, train=False)
train_loader = build_train_loader(train_data, args)
test_loader = build_test_loader(test_data, args)
# Instantiate model
if args.model == "stn":
model = stn_patch16_384_gap(args.pth_tar)
else:
model = base_patch16_384_gap(args.pth_tar)
if args.use_ddp:
model = nn.parallel.DistributedDataParallel(
model,
device_ids=[rank], output_device=rank,
find_unused_parameters=True,
gradient_as_bucket_view=True # XXX: vital, otherwise OOM
)
else:
model = nn.DataParallel(
model,
device_ids=args.gpus
)
if device is not None:
model = model.to(device)
elif torch.cuda.is_available():
model = model.cuda()
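    # DistributedDataParallel runs one process per GPU and all-reduces gradients
    # in buckets; gradient_as_bucket_view=True lets the gradient tensors alias
    # those buckets, saving roughly one gradient-sized copy of memory (hence the
    # OOM note above). Without DDP the script falls back to nn.DataParallel.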
# criterion, optimizer, scheduler
criterion = nn.L1Loss(reduction="sum")
if device is not None:
criterion = criterion.to(device)
elif torch.cuda.is_available():
criterion = criterion.cuda()
optimizer = torch.optim.Adam(
[{"params": model.parameters(), "lr": args.lr}],
lr=args.lr,
weight_decay=args.weight_decay
)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer, milestones=[300], gamma=.1, last_epoch=-1
)
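    # Single Adam param group; the lr given inside the param group takes precedence
    # over the optimizer-level default. MultiStepLR lowers the learning rate by
    # gamma=0.1 once, when epoch 300 is reached (scheduler.step() runs per epoch).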
# Checkpointing
if (not args.use_ddp) or torch.distributed.get_rank() == 0:
print("[worker-0] Saving to \'{}\'".format(args.save_path))
if not os.path.exists(args.save_path):
os.makedirs(args.save_path)
if args.resume:
if os.path.isfile(args.resume):
print("=> Loading checkpoint \'{}\'".format(args.resume))
checkpoint = torch.load(args.resume)
model.load_state_dict(checkpoint['state_dict'], strict=False)
args.start_epoch = checkpoint['epoch']
args.best_pred = checkpoint['best_prec1'] # sic.
else:
print("=> Checkpoint not found!")
print("[worker-{}] Starting @ epoch {}: pred: {}".format(
rank, args.start_epoch, args.best_pred
))
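    # Resuming restores the model weights (strict=False tolerates missing or
    # unexpected keys), the starting epoch and the best MAE so far; the optimizer
    # state is stored in checkpoints below but is not reloaded here.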
# For each epoch:
for epoch in range(args.start_epoch, args.epochs):
# Tell sampler which epoch it is
if args.use_ddp:
train_loader.sampler.set_epoch(epoch)
test_loader.sampler.set_epoch(epoch)
# Train
start = time.time()
train_one_epoch(train_loader, model, criterion, optimizer, scheduler, epoch, device, args)
end_train = time.time()
# Validate
if epoch % 5 == 0 or args.debug:
prec1 = valid_one_epoch(test_loader, model, device, epoch, args)
end_valid = time.time()
is_best = prec1 < args.best_pred
args.best_pred = min(prec1, args.best_pred)
print("* best MAE {mae:.3f} *".format(mae=args.best_pred))
# Save checkpoint
if not args.use_ddp or torch.distributed.get_rank() == 0:
save_checkpoint({
"epoch": epoch + 1,
"arch": args.resume,
"state_dict": model.state_dict(),
"best_prec1": args.best_pred,
"optimizer": optimizer.state_dict(),
}, is_best, args.save_path)
# cleanup
if args.use_ddp:
torch.distributed.destroy_process_group()
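# worker() is the per-process entry point: under DDP it is spawned once per rank
# (see __main__ below), otherwise it is called directly as worker(0, args).
# Checkpoints are only ever written by rank 0.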
def train_one_epoch(
train_loader: DataLoader,
model: VisionTransformerGAP,
criterion,
optimizer,
scheduler,
epoch: int,
device,
args: Namespace
):
# Get learning rate
curr_lr = optimizer.param_groups[0]["lr"]
print("Epoch %d, processed %d samples, lr %.10f" %
(epoch, epoch * len(train_loader.dataset), curr_lr)
)
# Set to train mode
model.train()
# In one epoch, for each training sample
for i, (fname, img, kpoint, gt_count_whole) in enumerate(train_loader):
kpoint = kpoint.type(torch.FloatTensor)
gt_count_whole = gt_count_whole.type(torch.FloatTensor).unsqueeze(1)
batch_size = img.size(0)
# send to device
if device is not None:
img = img.to(device)
kpoint = kpoint.to(device)
gt_count_whole = gt_count_whole.to(device)
device_type = device.type
elif torch.cuda.is_available():
img = img.cuda()
kpoint = kpoint.cuda()
gt_count_whole = gt_count_whole.cuda()
device_type = "cuda"
# Desperate measure to reduce mem footprint...
# with torch.autocast(device_type):
# fpass
out, gt_count = model(img, kpoint)
# loss
loss = criterion(out, gt_count) # wrt. transformer
        # Index scalars by global step so entries from different epochs do not collide.
        global_step = epoch * len(train_loader) + i
        if args.export_to_h5:
            train_df.loc[global_step, "l1loss"] = float(loss.item())
        else:
            writer.add_scalar(
                "L1-loss wrt. xformer (train)", loss, global_step
            )
loss += (
F.mse_loss( # stn: info retainment
gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
gt_count_whole)
+ F.threshold( # stn: perspective correction
gt_count.view(batch_size, -1).var(dim=1).mean(),
threshold=loss.item(),
value=loss.item()
)
)
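        # Composite objective: L1 between predicted and target patch counts, plus an
        # MSE term tying the summed patch counts to the whole-image count ("info
        # retainment"), plus the patch-count variance clamped from below at the
        # current L1 value ("perspective correction" for the STN variant).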
        if args.export_to_h5:
            train_df.loc[global_step, "composite-loss"] = float(loss.item())
        else:
            writer.add_scalar("Composite loss (train)", loss, global_step)
# free grad from mem
optimizer.zero_grad(set_to_none=True)
# bpass
loss.backward()
# optimizer
optimizer.step()
if args.debug:
break
    # Persist the epoch's stats. train_df accumulates rows across epochs, so the
    # whole table is rewritten; appending it each epoch would duplicate rows.
    if args.export_to_h5:
        train_df.to_hdf(train_stat_file, key="df", mode="w")
    else:
        writer.flush()
scheduler.step()
def valid_one_epoch(test_loader, model, device, epoch, args):
print("[valid_one_epoch] Validating...")
batch_size = 1
model.eval()
mae = .0
mse = .0
visi = []
index = 0
xformed = []
for i, (fname, img, kpoint, gt_count_whole) in enumerate(test_loader):
kpoint = kpoint.type(torch.FloatTensor)
gt_count_whole = gt_count_whole.type(torch.FloatTensor)
if device is not None:
img = img.to(device)
kpoint = kpoint.to(device)
gt_count_whole = gt_count_whole.to(device)
elif torch.cuda.is_available():
img = img.cuda()
kpoint = kpoint.cuda()
gt_count_whole = gt_count_whole.cuda()
        # XXX: does this ever happen?
if len(img.shape) == 5:
img = img.squeeze(0)
if len(img.shape) == 3:
img = img.unsqueeze(0)
with torch.no_grad():
out, gt_count = model(img, kpoint)
pred_count = torch.squeeze(out, 1)
gt_count = torch.squeeze(gt_count, 1)
diff = torch.abs(gt_count_whole - torch.sum(pred_count)).item()
mae += diff
mse += diff ** 2
if i % 5 == 0:
# with torch.no_grad():
# img_xformed = model.stnet(img).to("cpu")
# xformed.append(img_xformed)
print("[valid_one_epoch] {} | Gt {:.2f} Pred {:.4f} |".format(
fname[0],
torch.sum(gt_count_whole).item(),
torch.sum(pred_count).item()
))
if args.debug:
break
    mae = mae * 1.0 / (len(test_loader) * batch_size)
    mse = np.sqrt(mse / (len(test_loader) * batch_size))
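    # Standard crowd-counting metrics: MAE is the mean absolute error of the
    # per-image counts; the reported "MSE" is in fact the root mean squared error
    # (RMSE), as is conventional in the counting literature.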
    if args.export_to_h5:
        test_df.loc[epoch, "mae"] = mae
        test_df.loc[epoch, "mse"] = mse
        # Rewrite the whole table; appending the cumulative frame would duplicate rows.
        test_df.to_hdf(test_stat_file, key="df", mode="w")
else:
writer.add_scalar("MAE (valid)", mae, epoch)
writer.add_scalar("MSE (valid)", mse, epoch)
if len(xformed) != 0 and not args.export_to_h5:
img_grid = torchvision.utils.make_grid(xformed)
writer.add_image("STN: transformed image", img_grid, epoch)
if not args.export_to_h5:
writer.flush()
nni.report_intermediate_result(mae)
print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
mae=mae, mse=mse
))
return mae
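# valid_one_epoch() doubles as the NNI feedback channel: the MAE is reported via
# nni.report_intermediate_result() so a tuner/assessor can track (or early-stop)
# the trial; no nni.report_final_result() is issued anywhere in this script.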
if __name__ == "__main__":
tuner_params = nni.get_next_parameter()
    logger.debug("Generated hyperparameters: %s", tuner_params)
combined_params = nni.utils.merge_parameter(ret_args, tuner_params)
if args.debug:
os.nice(15)
    if combined_params.use_ddp:
        # Use DDP: spawn one process per rank.
        torch_mp.spawn(
            worker,
            # spawn() prepends the rank as the first positional argument; args must
            # be a tuple, otherwise the Namespace itself would be expanded.
            args=(combined_params, ),
            nprocs=combined_params.ddp_world_size,
        )
    else:
        # No DDP: run directly in the current process.
        worker(0, combined_params)
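# A minimal launch sketch (flag names are assumed to mirror arguments.py and are
# not verified here):
#
#   python train.py --use_ddp --ddp_world_size 2 --batch_size 8
#
# When an NNI experiment is active, nni.get_next_parameter() returns the trial's
# hyperparameters and merge_parameter() overlays them onto the CLI arguments.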