# mlp-project/train.py
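"""Training entry point for the TransCrowd GAP vision transformer (base or STN
variant): L1 count-regression loss, optional DistributedDataParallel training,
and NNI hyperparameter tuning."""
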
import os
import random
import time
import logging
from typing import Optional
from argparse import Namespace

import timm
import torch
import torch.nn as nn
import torch.distributed
import torch.multiprocessing as torch_mp
from torch.utils.data import DataLoader
from torchvision import transforms
import nni
import numpy as np

from arguments import args, ret_args
import dataset
from dataset import *
from model.transcrowd_gap import (
    VisionTransformerGAP,
    base_patch16_384_gap,
    stn_patch16_384_gap,
)
logger = logging.getLogger("train")
def setup_process_group(
    rank: int,
    world_size: int,
    master_addr: str = "localhost",
    master_port: Optional[int] = None
):
    os.environ["MASTER_ADDR"] = master_addr
    # Every rank must agree on the same port; a per-process random choice would
    # break the rendezvous, so fall back to PyTorch's conventional default.
    os.environ["MASTER_PORT"] = str(29500 if master_port is None else master_port)
    # Rendezvous point: init_process_group blocks until all `world_size`
    # processes have joined. Leaving `backend` unset lets recent PyTorch pick
    # the default (NCCL for CUDA, Gloo for CPU).
    torch.distributed.init_process_group(
        rank=rank, world_size=world_size
    )
def build_train_loader(data_keys, args):
    train_dataset = ListDataset(
        data_keys,
        shuffle=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]),
        train=True,
        batch_size=args.batch_size,
        nr_workers=args.workers,
        args=args
    )
if args.use_ddp:
train_dist_sampler = torch.utils.data.distributed.DistributedSampler(
dataset=train_dataset
)
else:
train_dist_sampler = None
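    # With DDP the DistributedSampler shards (and, by default, shuffles) the
    # data per rank, so shuffling is not also requested from the DataLoader.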
train_loader = DataLoader(
dataset=train_dataset,
sampler=train_dist_sampler,
batch_size=args.batch_size,
drop_last=False
)
return train_loader
def build_test_loader(data_keys, args):
    test_dataset = ListDataset(
        data_keys,
        shuffle=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]),
        args=args,
        train=False
    )
if args.use_ddp:
test_dist_sampler = torch.utils.data.distributed.DistributedSampler(
dataset=test_dataset
)
else:
test_dist_sampler = None
test_loader = DataLoader(
dataset=test_dataset,
sampler=test_dist_sampler,
batch_size=1
)
return test_loader
def worker(rank: int, args: Namespace):
world_size = args.ddp_world_size
    # (If DDP) set up the process group and wait for all ranks to join
    if args.use_ddp:
        setup_process_group(rank, world_size)
    # Set up the device for this process
    if args.use_ddp:
        device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.gpus if torch.cuda.is_available() else "cpu")
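    # Newly constructed tensors default to `device` from here on
    # (torch.set_default_device requires PyTorch >= 2.0).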
torch.set_default_device(device)
# Prepare training data
train_list, test_list = unpack_npy_data(args)
train_data = convert_data(train_list, args, train=True)
test_data = convert_data(test_list, args, train=False)
train_loader = build_train_loader(train_data, args)
test_loader = build_test_loader(test_data, args)
# Instantiate model
if args.model == "stn":
model = stn_patch16_384_gap(args.pth_tar).to(device)
else:
model = base_patch16_384_gap(args.pth_tar).to(device)
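    # DDP wrapping: find_unused_parameters tolerates parameters that receive no
    # gradient on a given step, and gradient_as_bucket_view makes gradients views
    # into the communication buckets, avoiding an extra copy (see the OOM note below).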
if args.use_ddp:
model = nn.parallel.DistributedDataParallel(
model,
device_ids=[rank], output_device=rank,
find_unused_parameters=True,
gradient_as_bucket_view=True # XXX: vital, otherwise OOM
)
# criterion, optimizer, scheduler
    # size_average is deprecated; reduction="sum" keeps the summed (un-averaged) L1 count loss
    criterion = nn.L1Loss(reduction="sum").to(device)
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay
    )
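    # MultiStepLR multiplies the learning rate by 0.1 once epoch 300 is reached;
    # scheduler.step() is called once per epoch at the end of train_one_epoch.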
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer, milestones=[300], gamma=.1, last_epoch=-1
)
# Checkpointing
if (not args.use_ddp) or torch.distributed.get_rank() == 0:
print("[worker-0] Saving to \'{}\'".format(args.save_path))
if not os.path.exists(args.save_path):
os.makedirs(args.save_path)
if args.progress:
if os.path.isfile(args.progress):
print("=> Loading checkpoint \'{}\'".format(args.progress))
checkpoint = torch.load(args.progress)
model.load_state_dict(checkpoint['state_dict'], strict=False)
args.start_epoch = checkpoint['epoch']
args.best_pred = checkpoint['best_prec1']
else:
print("=> Checkpoint not found!")
print("[worker-{}] Starting @ epoch {}: pred: {}".format(
rank, args.start_epoch, args.best_pred
))
# For each epoch:
for epoch in range(args.start_epoch, args.epochs):
# Tell sampler which epoch it is
if args.use_ddp:
train_loader.sampler.set_epoch(epoch)
test_loader.sampler.set_epoch(epoch)
# Train
start = time.time()
train_one_epoch(train_loader, model, criterion, optimizer, scheduler, epoch, device, args)
end_train = time.time()
# Validate
if epoch % 5 == 0:
prec1 = valid_one_epoch(test_loader, model, device, args)
end_valid = time.time()
is_best = prec1 < args.best_pred
args.best_pred = min(prec1, args.best_pred)
print("* best MAE {mae:.3f} *".format(mae=args.best_pred))
# Save checkpoint
# if not args.use_ddp or torch.distributed.get_rank() == 0:
    # Clean up the process group (only created when DDP is enabled)
    if args.use_ddp:
        torch.distributed.destroy_process_group()
def train_one_epoch(
train_loader: DataLoader,
model: VisionTransformerGAP,
criterion,
optimizer,
scheduler,
epoch: int,
device,
args: Namespace
):
# Get learning rate
curr_lr = optimizer.param_groups[0]["lr"]
print("Epoch %d, processed %d samples, lr %.10f" %
(epoch, epoch * len(train_loader.dataset), curr_lr)
)
# Set to train mode
model.train()
# In one epoch, for each training sample
for i, (fname, img, gt_count) in enumerate(train_loader):
        # Forward pass
        img = img.to(device)
        out = model(img)
        # Targets as a float column vector [B, 1] to match the model output
        gt_count = gt_count.float().to(device).unsqueeze(1)
        # Summed L1 loss between predicted and ground-truth counts
        loss = criterion(out, gt_count)
        # Reset gradients, backpropagate, and update the parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
# periodic message
if i % args.print_freq == 0:
print("Epoch {}: {}/{}".format(epoch, i, len(train_loader)))
scheduler.step()
def valid_one_epoch(test_loader, model, device, args):
print("[valid_one_epoch] Validating...")
batch_size = 1
model.eval()
mae = .0
mse = .0
visi = []
index = 0
for i, (fname, img, gt_count) in enumerate(test_loader):
img = img.to(device)
        # Normalize the input to 4D [N, C, H, W]: a 5D tensor is likely a stack
        # of pre-cropped patches with an extra leading batch dim, a 3D tensor a
        # single image missing its batch dim.
        if len(img.shape) == 5:
            img = img.squeeze(0)
        if len(img.shape) == 3:
            img = img.unsqueeze(0)
with torch.no_grad():
out = model(img)
count = torch.sum(out).item()
gt_count = torch.sum(gt_count).item()
mae += abs(gt_count - count)
mse += abs(gt_count - count) ** 2
if i % 15 == 0:
print("[valid_one_epoch] {} Gt {:.2f} Pred {}".format(
fname[0], gt_count, count
))
    mae = mae / (len(test_loader) * batch_size)
    mse = np.sqrt(mse / (len(test_loader) * batch_size))
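    # Report MAE to NNI as this trial's intermediate result for the current
    # validation round (NNI also expects a final report once training ends).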
nni.report_intermediate_result(mae)
print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
mae=mae, mse=mse
))
return mae
if __name__ == "__main__":
tuner_params = nni.get_next_parameter()
logger.debug("Generated hyperparameters: {}", tuner_params)
combined_params = Namespace(
nni.utils.merge_parameter(ret_args, tuner_params)
) # Namespaces have better ergonomics, notably a struct-like access syntax.
logger.debug("Parameters: {}", combined_params)
    if combined_params.use_ddp:
        # Use DDP: spawn one worker process per rank
        torch_mp.spawn(
            worker,
            args=(combined_params, ),  # rank is supplied automatically as the 1st param
            nprocs=combined_params.ddp_world_size,  # matches args.ddp_world_size read in worker()
        )
    else:
        # No DDP: run in the current process
        worker(0, combined_params)
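
# Example invocations (assumed CLI: flag names mirror the attributes used above
# but are illustrative, not verified against arguments.py):
#   python train.py --gpus 0 --batch_size 8 --lr 1e-5 --epochs 600
#   python train.py --use_ddp --ddp_world_size 4 --batch_size 8 --lr 1e-5
# Under NNI, this script is set as the trial command in the experiment config
# and nni.get_next_parameter() pulls the tuned hyperparameters.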