diff --git a/train.py b/train.py
index 9beb2ec..7717af5 100644
--- a/train.py
+++ b/train.py
@@ -1,5 +1,6 @@
 import os
 import random
+import time
 from typing import Optional
 from argparse import Namespace
@@ -14,6 +15,9 @@
 import numpy as np
 from model.transcrowd_gap import VisionTransformerGAP
 from arguments import args, ret_args
+import dataset
+from dataset import *
+from model.transcrowd_gap import base_patch16_384_gap, stn_patch16_384_gap
 
 logger = logging.getLogger("train")
 
@@ -33,26 +37,157 @@ def setup_process_group(
     # join point!
     torch.distributed.init_process_group(
-        backend="nccl", rank=rank, world_size=world_size
+        rank=rank, world_size=world_size
     )
 
 
-# TODO:
-# The shape for each batch in transcrowd is [3, 384, 384],
-# this is due to images being cropped before training.
-# To preserve image semantics wrt the entire layout, we want to apply cropping
-# i.e., as encoder input during the inference/training pipeline.
-# This should be okay since our transformations are all deterministic?
-# not sure...
+
+def build_train_loader(data_keys, args):
+    train_dataset = ListDataset(
+        data_keys,
+        shuffle=True,
+        transform=transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[.485, .456, .406], std=[.229, .224, .225])
+        ]),
+        train=True,
+        batch_size=args.batch_size,
+        nr_workers=args.workers,
+        args=args
+    )
+    if args.use_ddp:
+        train_dist_sampler = torch.utils.data.distributed.DistributedSampler(
+            dataset=train_dataset
+        )
+    else:
+        train_dist_sampler = None
+    train_loader = DataLoader(
+        dataset=train_dataset,
+        sampler=train_dist_sampler,
+        batch_size=args.batch_size,
+        drop_last=False
+    )
+    return train_loader
+
+
+def build_test_loader(data_keys, args):
+    test_dataset = ListDataset(
+        data_keys,
+        shuffle=False,
+        transform=transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[.485, .456, .406], std=[.229, .224, .225])
+        ]),
+        args=args,
+        train=False
+    )
+    if args.use_ddp:
+        test_dist_sampler = torch.utils.data.distributed.DistributedSampler(
+            dataset=test_dataset
+        )
+    else:
+        test_dist_sampler = None
+    test_loader = DataLoader(
+        dataset=test_dataset,
+        sampler=test_dist_sampler,
+        batch_size=1
+    )
+    return test_loader
 
-def build_train_loader():
-    pass
+
+def worker(rank: int, args: Namespace):
+    world_size = args.world_size
+    # (If DDP) join the process group shared with the other workers
+    if args.use_ddp:
+        setup_process_group(rank, world_size)
 
-def build_valid_loader():
-    pass
+    # Set up the device for this process
+    if args.use_ddp:
+        device = torch.device(rank if torch.cuda.is_available() else "cpu")
+    else:
+        device = torch.device(args.gpus if torch.cuda.is_available() else "cpu")
+    torch.set_default_device(device)
+
+    # Prepare training and test data
+    train_list, test_list = unpack_npy_data(args)
+    train_data = convert_data(train_list, args, train=True)
+    test_data = convert_data(test_list, args, train=False)
+    train_loader = build_train_loader(train_data, args)
+    test_loader = build_test_loader(test_data, args)
+
+    # Instantiate model
+    if args.model == "stn":
+        model = stn_patch16_384_gap(args.pth_tar).to(device)
+    else:
+        model = base_patch16_384_gap(args.pth_tar).to(device)
+
+    if args.use_ddp:
+        model = nn.parallel.DistributedDataParallel(
+            model,
+            device_ids=[rank], output_device=rank,
+            find_unused_parameters=True,
+            gradient_as_bucket_view=True  # XXX: vital, otherwise OOM
+        )
+
+    # criterion, optimizer, scheduler
+    criterion = nn.L1Loss(reduction="sum").to(device)  # summed L1, i.e. the old size_average=False
+    optimizer = torch.optim.Adam(
+        model.parameters(),
+        lr=args.lr,
+        weight_decay=args.weight_decay
+    )
+    scheduler = torch.optim.lr_scheduler.MultiStepLR(
+        optimizer, milestones=[300], gamma=.1, last_epoch=-1
+    )
+
+    # Checkpointing
+    if (not args.use_ddp) or torch.distributed.get_rank() == 0:
+        print("[worker-0] Saving to '{}'".format(args.save_path))
+        if not os.path.exists(args.save_path):
+            os.makedirs(args.save_path)
+
+    if args.progress:
+        if os.path.isfile(args.progress):
+            print("=> Loading checkpoint '{}'".format(args.progress))
+            checkpoint = torch.load(args.progress)
+            model.load_state_dict(checkpoint['state_dict'], strict=False)
+            args.start_epoch = checkpoint['epoch']
+            args.best_pred = checkpoint['best_prec1']
+        else:
+            print("=> Checkpoint not found!")
+
+    print("[worker-{}] Starting @ epoch {}: pred: {}".format(
+        rank, args.start_epoch, args.best_pred
+    ))
+
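+    # NOTE: under DDP, DistributedSampler gives each rank its own shard of
+    # the data; set_epoch() below reseeds its shuffle so each epoch draws a
+    # different ordering across ranks.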
"lr": args.lr}], + lr=args.lr, + weight_decay=args.weight_decay + ) + scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, milestones=[300], gamma=.1, last_epoch=-1 + ) + + # Checkpointing + if (not args.use_ddp) or torch.distributed.get_rank() == 0: + print("[worker-0] Saving to \'{}\'".format(args.save_path)) + if not os.path.exists(args.save_path): + os.makedirs(args.save_path) + + if args.progress: + if os.path.isfile(args.progress): + print("=> Loading checkpoint \'{}\'".format(args.progress)) + checkpoint = torch.load(args.progress) + model.load_state_dict(checkpoint['state_dict'], strict=False) + args.start_epoch = checkpoint['epoch'] + args.best_pred = checkpoint['best_prec1'] + else: + print("=> Checkpoint not found!") + + print("[worker-{}] Starting @ epoch {}: pred: {}".format( + rank, args.start_epoch, args.best_pred + )) + + # For each epoch: + for epoch in range(args.start_epoch, args.epochs): + # Tell sampler which epoch it is + if args.use_ddp: + train_loader.sampler.set_epoch(epoch) + test_loader.sampler.set_epoch(epoch) + + # Train + start = time.time() + train_one_epoch(train_loader, model, criterion, optimizer, scheduler, epoch, device, args) + end_train = time.time() + + # Validate + if epoch % 5 == 0: + prec1 = valid_one_epoch(test_loader, model, device, args) + end_valid = time.time() + is_best = prec1 < args.best_pred + args.best_pred = min(prec1, args.best_pred) + + print("* best MAE {mae:.3f} *".format(mae=args.best_pred)) + + # Save checkpoint + # if not args.use_ddp or torch.distributed.get_rank() == 0: + + + # cleanup + torch.distributed.destroy_process_group() def train_one_epoch( @@ -62,6 +197,7 @@ def train_one_epoch( optimizer, scheduler, epoch: int, + device, args: Namespace ): # Get learning rate @@ -70,27 +206,76 @@ def train_one_epoch( (epoch, epoch * len(train_loader.dataset), curr_lr) ) - # Set to train mode (perspective estimator only) - revpers_net.train() - end = time.time() + # Set to train mode + model.train() # In one epoch, for each training sample for i, (fname, img, gt_count) in enumerate(train_loader): - # move stuff to device - # fpass (revpers) - img = img.cuda() + # fpass + img = img.to(device) + out = model(img) + gt_count = gt_count.type(torch.FloatTensor).to(device).unsqueeze(1) - # loss wrt revpers - loss = criterion() + # loss + loss = criterion(out, gt_count) + + # free grad from mem + optimizer.zero_grad() + + # bpass + loss.backward() + + # optimizer + optimizer.step() + + # periodic message + if i % args.print_freq == 0: + print("Epoch {}: {}/{}".format(epoch, i, len(train_loader))) + + scheduler.step() - pass +def valid_one_epoch(test_loader, model, device, args): + print("[valid_one_epoch] Validating...") + batch_size = 1 + model.eval() -def valid_one_epoch(): - pass + mae = .0 + mse = .0 + visi = [] + index = 0 + + for i, (fname, img, gt_count) in enumerate(test_loader): + img = img.to(device) + # XXX: what do this do + if len(img.shape) == 5: + img = img.squeeze(0) + if len(img.shape) == 3: + img = img.unsqueeze(0) + + with torch.no_grad(): + out = model(img) + count = torch.sum(out).item() + + gt_count = torch.sum(gt_count).item() + mae += abs(gt_count - count) + mse += abs(gt_count - count) ** 2 + + if i % 15 == 0: + print("[valid_one_epoch] {} Gt {:.2f} Pred {}".format( + fname[0], gt_count, count + )) + + mae = mae * 1.0 / (len(test_loader) * batch_size) + mse = np.sqrt(mse / (len(test_loader)) * batch_size) + + nni.report_intermediate_result(mae) + print("* MAE {mae:.3f} | MSE {mse:.3f} *".format( + mae=mae, 
+    mae = mae * 1.0 / (len(test_loader) * batch_size)
+    mse = np.sqrt(mse / (len(test_loader) * batch_size))
+
+    nni.report_intermediate_result(mae)
+    print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
+        mae=mae,
+        mse=mse
+    ))
+
+    return mae
 
-def main(rank: int, args: Namespace):
-    pass
 
 if __name__ == "__main__":
     tuner_params = nni.get_next_parameter()
@@ -103,10 +288,10 @@
     if combined_params.use_ddp:
         # Use DDP, spawn threads
         torch_mp.spawn(
-            main,
+            worker,
             args=(combined_params, ),
             # rank supplied automatically as 1st param
             nprocs=combined_params.world_size,
         )
     else:
         # No DDP, run in current thread
-        main(None, combined_params)
\ No newline at end of file
+        worker(0, combined_params)
\ No newline at end of file