Call it a day

Zhengyi Chen 2024-02-28 17:26:02 +00:00
parent 04f78fbcbc
commit dcc3f57596

train.py (239 lines changed)

@@ -1,5 +1,6 @@
 import os
 import random
+import time
 from typing import Optional
 from argparse import Namespace
@@ -14,6 +15,9 @@ import numpy as np
 from model.transcrowd_gap import VisionTransformerGAP
 from arguments import args, ret_args
+import dataset
+from dataset import *
+from model.transcrowd_gap import base_patch16_384_gap, stn_patch16_384_gap

 logger = logging.getLogger("train")
@@ -33,26 +37,157 @@ def setup_process_group(
     # join point!
     torch.distributed.init_process_group(
-        backend="nccl", rank=rank, world_size=world_size
+        rank=rank, world_size=world_size
     )
-
-# TODO:
-# The shape for each batch in transcrowd is [3, 384, 384],
-# this is due to images being cropped before training.
-# To preserve image semantics wrt the entire layout, we want to apply cropping
-# i.e., as encoder input during the inference/training pipeline.
-# This should be okay since our transformations are all deterministic?
-# not sure...
-
-def build_train_loader():
-    pass
-
-def build_valid_loader():
-    pass
+
+def build_train_loader(data_keys, args):
+    train_dataset = ListDataset(
+        data_keys,
+        shuffle=True,
+        transform=transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[.485, .456, .406], std=[.229, .224, .225])
+        ]),
+        train=True,
+        batch_size=args.batch_size,
+        nr_workers=args.workers,
+        args=args
+    )
+    if args.use_ddp:
+        train_dist_sampler = torch.utils.data.distributed.DistributedSampler(
+            dataset=train_dataset
+        )
+    else:
+        train_dist_sampler = None
+    train_loader = DataLoader(
+        dataset=train_dataset,
+        sampler=train_dist_sampler,
+        batch_size=args.batch_size,
+        drop_last=False
+    )
+    return train_loader
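Note (commentary, not part of the diff): DataLoader forbids shuffle=True together with an explicit sampler, which is why shuffling is delegated to ListDataset here and, under DDP, to DistributedSampler.set_epoch in the epoch loop below. A minimal standalone sketch of how DistributedSampler partitions indices across ranks, with num_replicas/rank passed explicitly so it runs without an initialized process group:

# Standalone sketch (not from the commit): DistributedSampler index sharding.
from torch.utils.data.distributed import DistributedSampler

dataset = list(range(10))  # stand-in for ListDataset
for rank in range(2):
    sampler = DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=True)
    sampler.set_epoch(0)  # seeds the shuffle; vary per epoch for new orders
    print(rank, list(sampler))  # disjoint, equally sized index subsets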
+
+def build_test_loader(data_keys, args):
+    test_dataset = ListDataset(
+        data_keys,
+        shuffle=False,
+        transform=transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[.485, .456, .406], std=[.229, .224, .225])
+        ]),
+        args=args,
+        train=False
+    )
+    if args.use_ddp:
+        test_dist_sampler = torch.utils.data.distributed.DistributedSampler(
+            dataset=test_dataset
+        )
+    else:
+        test_dist_sampler = None
+    test_loader = DataLoader(
+        dataset=test_dataset,
+        sampler=test_dist_sampler,
+        batch_size=1
+    )
+    return test_loader
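One caveat with sharding the test set via DistributedSampler (not addressed in this commit): each rank sees only its own shard, and the sampler pads shards to equal length, so the per-rank MAE computed in valid_one_epoch below covers a subset with possible duplicates. A hedged sketch of aggregating error sums across ranks; the helper name and placement are hypothetical:

# Hypothetical helper (not in the commit): reduce per-rank error sums so every
# rank sees dataset-level MAE/RMSE. Assumes the process group is initialized.
import torch
import torch.distributed as dist

def reduce_metrics(abs_err_sum: float, sq_err_sum: float, n: int, device):
    t = torch.tensor([abs_err_sum, sq_err_sum, float(n)], device=device)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)  # sum over all ranks
    abs_err, sq_err, total = t.tolist()
    return abs_err / total, (sq_err / total) ** 0.5  # MAE, RMSE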
+
+def worker(rank: int, args: Namespace):
+    world_size = args.ddp_world_size
+    # (If DDP) join the process group
+    if args.use_ddp:
+        setup_process_group(rank, world_size)
+    # Set up device for this proc
+    if args.use_ddp:
+        device = torch.device(rank if torch.cuda.is_available() else "cpu")
+    else:
+        device = torch.device(args.gpus if torch.cuda.is_available() else "cpu")
+    torch.set_default_device(device)
+    # Prepare training data
+    train_list, test_list = unpack_npy_data(args)
+    train_data = convert_data(train_list, args, train=True)
+    test_data = convert_data(test_list, args, train=False)
+    train_loader = build_train_loader(train_data, args)
+    test_loader = build_test_loader(test_data, args)
+    # Instantiate model
+    if args.model == "stn":
+        model = stn_patch16_384_gap(args.pth_tar).to(device)
+    else:
+        model = base_patch16_384_gap(args.pth_tar).to(device)
+    if args.use_ddp:
+        model = nn.parallel.DistributedDataParallel(
+            model,
+            device_ids=[rank], output_device=rank,
+            find_unused_parameters=True,
+            gradient_as_bucket_view=True  # XXX: vital, otherwise OOM
+        )
+    # criterion, optimizer, scheduler
+    criterion = nn.L1Loss(reduction="sum").to(device)  # size_average=False is deprecated
+    optimizer = torch.optim.Adam(
+        [{"params": model.parameters(), "lr": args.lr}],
+        lr=args.lr,
+        weight_decay=args.weight_decay
+    )
+    scheduler = torch.optim.lr_scheduler.MultiStepLR(
+        optimizer, milestones=[300], gamma=.1, last_epoch=-1
+    )
+    # Checkpointing
+    if (not args.use_ddp) or torch.distributed.get_rank() == 0:
+        print("[worker-0] Saving to '{}'".format(args.save_path))
+        if not os.path.exists(args.save_path):
+            os.makedirs(args.save_path)
+    if args.progress:
+        if os.path.isfile(args.progress):
+            print("=> Loading checkpoint '{}'".format(args.progress))
+            checkpoint = torch.load(args.progress)
+            model.load_state_dict(checkpoint['state_dict'], strict=False)
+            args.start_epoch = checkpoint['epoch']
+            args.best_pred = checkpoint['best_prec1']
+        else:
+            print("=> Checkpoint not found!")
+    print("[worker-{}] Starting @ epoch {}: pred: {}".format(
+        rank, args.start_epoch, args.best_pred
+    ))
+    # For each epoch:
+    for epoch in range(args.start_epoch, args.epochs):
+        # Tell the samplers which epoch it is
+        if args.use_ddp:
+            train_loader.sampler.set_epoch(epoch)
+            test_loader.sampler.set_epoch(epoch)
+        # Train
+        start = time.time()
+        train_one_epoch(train_loader, model, criterion, optimizer, scheduler, epoch, device, args)
+        end_train = time.time()
+        # Validate
+        if epoch % 5 == 0:
+            prec1 = valid_one_epoch(test_loader, model, device, args)
+            end_valid = time.time()
+            is_best = prec1 < args.best_pred
+            args.best_pred = min(prec1, args.best_pred)
+            print("* best MAE {mae:.3f} *".format(mae=args.best_pred))
+            # Save checkpoint
+            # if not args.use_ddp or torch.distributed.get_rank() == 0:
+    # cleanup (guarded: destroy_process_group raises without an initialized group)
+    if args.use_ddp:
+        torch.distributed.destroy_process_group()
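The save path is created above, but the actual checkpoint write is still commented out in this commit. A sketch of what the rank-0 save might look like, using the same keys the resume logic reads back ('epoch', 'state_dict', 'best_prec1'); the filename is a hypothetical choice:

# Sketch only (not in the commit): rank-0 checkpoint save matching the keys
# expected by the resume logic above.
if (not args.use_ddp) or torch.distributed.get_rank() == 0:
    torch.save({
        "epoch": epoch + 1,
        "state_dict": model.state_dict(),
        "best_prec1": args.best_pred,
    }, os.path.join(args.save_path, "checkpoint.pth.tar"))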
@@ -62,6 +197,7 @@ def train_one_epoch(
     optimizer,
     scheduler,
     epoch: int,
+    device,
     args: Namespace
 ):
     # Get learning rate
@@ -70,27 +206,76 @@ def train_one_epoch(
         (epoch, epoch * len(train_loader.dataset), curr_lr)
     )
-    # Set to train mode (perspective estimator only)
-    revpers_net.train()
+    # Set to train mode
+    model.train()
+    end = time.time()
     # In one epoch, for each training sample
     for i, (fname, img, gt_count) in enumerate(train_loader):
-        # move stuff to device
-        # fpass (revpers)
-        img = img.cuda()
-
-        # loss wrt revpers
-        loss = criterion()
-    pass
+        # fpass
+        img = img.to(device)
+        out = model(img)
+        gt_count = gt_count.type(torch.FloatTensor).to(device).unsqueeze(1)
+        # loss
+        loss = criterion(out, gt_count)
+        # free grad from mem
+        optimizer.zero_grad()
+        # bpass
+        loss.backward()
+        # optimizer step
+        optimizer.step()
+        # periodic message
+        if i % args.print_freq == 0:
+            print("Epoch {}: {}/{}".format(epoch, i, len(train_loader)))
+    scheduler.step()
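A note on the schedule above (commentary, not part of the diff): because scheduler.step() is called once per epoch, the MultiStepLR milestone is in epoch units, so the learning rate stays at args.lr until epoch 300 and is then multiplied by 0.1; if args.epochs is below 300, the schedule is effectively constant. A quick standalone check with toy values:

# Standalone check (toy values): MultiStepLR decays only at the milestone.
import torch
opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
sched = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[300], gamma=0.1)
for epoch in range(301):
    opt.step()  # keep the optimizer/scheduler step order PyTorch expects
    sched.step()
print(opt.param_groups[0]["lr"])  # 1e-05 after epoch 300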
+
+def valid_one_epoch(test_loader, model, device, args):
+    print("[valid_one_epoch] Validating...")
+    batch_size = 1
+    model.eval()
+    mae = .0
+    mse = .0
+    visi = []
+    index = 0
+    for i, (fname, img, gt_count) in enumerate(test_loader):
+        img = img.to(device)
+        # Crops arrive as [1, n_crops, 3, H, W]: drop the loader's batch dim,
+        # and restore one for single-crop [3, H, W] inputs
+        if len(img.shape) == 5:
+            img = img.squeeze(0)
+        if len(img.shape) == 3:
+            img = img.unsqueeze(0)
+        with torch.no_grad():
+            out = model(img)
+            count = torch.sum(out).item()
+        gt_count = torch.sum(gt_count).item()
+        mae += abs(gt_count - count)
+        mse += abs(gt_count - count) ** 2
+        if i % 15 == 0:
+            print("[valid_one_epoch] {} Gt {:.2f} Pred {}".format(
+                fname[0], gt_count, count
+            ))
+    mae = mae * 1.0 / (len(test_loader) * batch_size)
+    # parenthesized: the old "mse / len(test_loader) * batch_size" only worked
+    # because batch_size is 1
+    mse = np.sqrt(mse / (len(test_loader) * batch_size))
+    nni.report_intermediate_result(mae)
+    print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
+        mae=mae, mse=mse
+    ))
+    return mae
-
-def valid_one_epoch():
-    pass
-
-def main(rank: int, args: Namespace):
-    pass
 
 if __name__ == "__main__":
     tuner_params = nni.get_next_parameter()
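For reference (commentary, not in the commit): the two quantities above are the standard crowd-counting metrics, MAE = (1/N) * sum(|gt_i - pred_i|) and RMSE = sqrt((1/N) * sum((gt_i - pred_i)^2)), accumulated over N = len(test_loader) images at batch size 1. A tiny numeric check of the definitions:

# Quick numeric check of the metric definitions (toy values).
import numpy as np
gt   = np.array([10.0, 20.0, 30.0])
pred = np.array([12.0, 18.0, 33.0])
mae  = np.mean(np.abs(gt - pred))          # (2 + 2 + 3) / 3 = 2.333...
rmse = np.sqrt(np.mean((gt - pred) ** 2))  # sqrt((4 + 4 + 9) / 3) = 2.380...
print(mae, rmse)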
@@ -103,10 +288,10 @@ if __name__ == "__main__":
     if combined_params.use_ddp:
         # Use DDP, spawn processes
         torch_mp.spawn(
-            main,
+            worker,
             args=(combined_params, ),  # rank supplied automatically as 1st param
             nprocs=combined_params.world_size,
         )
     else:
         # No DDP, run in the current process
-        main(None, combined_params)
+        worker(0, combined_params)
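Two closing notes (commentary, not part of the commit): torch.multiprocessing.spawn calls its target as worker(rank, *args) with rank ranging over 0..nprocs-1, which is why worker takes rank as its first parameter; also, spawn reads combined_params.world_size while worker reads args.ddp_world_size — presumably both come from arguments.py, but they look worth unifying. A minimal sketch of the spawn contract, with a toy target in place of the real worker:

# Minimal sketch of the spawn contract used above (toy target, not the real worker).
import torch.multiprocessing as torch_mp

def toy_worker(rank: int, msg: str):
    # rank is injected by spawn as the first positional argument
    print("rank {} got {}".format(rank, msg))

if __name__ == "__main__":
    torch_mp.spawn(toy_worker, args=("hello",), nprocs=2)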