diff --git a/train.py b/train.py
index 7717af5..d4ea2ed 100644
--- a/train.py
+++ b/train.py
@@ -18,6 +18,7 @@ from arguments import args, ret_args
 import dataset
 from dataset import *
 from model.transcrowd_gap import base_patch16_384_gap, stn_patch16_384_gap
+from checkpoint import save_checkpoint
 
 logger = logging.getLogger("train")
 
@@ -44,14 +45,14 @@ def setup_process_group(
 def build_train_loader(data_keys, args):
     train_dataset = ListDataset(
         data_keys,
-        shuffle = True, 
+        shuffle = True,
         transform = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean=[.485, .456, .406], std=[.229, .224, .225])
-        ]), 
-        train = True, 
-        batch_size = args.batch_size, 
-        nr_workers = args.workers, 
+        ]),
+        train = True,
+        batch_size = args.batch_size,
+        nr_workers = args.workers,
         args = args
     )
     if args.use_ddp:
@@ -61,9 +62,9 @@ def build_train_loader(data_keys, args):
     else:
         train_dist_sampler = None
     train_loader = DataLoader(
-        dataset=train_dataset, 
-        sampler=train_dist_sampler, 
-        batch_size=args.batch_size, 
+        dataset=train_dataset,
+        sampler=train_dist_sampler,
+        batch_size=args.batch_size,
         drop_last=False
     )
     return train_loader
@@ -71,13 +72,13 @@ def build_train_loader(data_keys, args):
 
 def build_test_loader(data_keys, args):
     test_dataset = ListDataset(
-        data_keys, 
-        shuffle=False, 
+        data_keys,
+        shuffle=False,
         transform = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean=[.485, .456, .406], std=[.229, .224, .225])
-        ]), 
-        args=args, 
+        ]),
+        args=args,
         train=False
     )
     if args.use_ddp:
@@ -87,8 +88,8 @@ def build_test_loader(data_keys, args):
     else:
         test_dist_sampler = None
     test_loader = DataLoader(
-        dataset=test_dataset, 
-        sampler=test_dist_sampler, 
+        dataset=test_dataset,
+        sampler=test_dist_sampler,
         batch_size=1
     )
     return test_loader
@@ -114,27 +115,27 @@ def worker(rank: int, args: Namespace):
     test_data = convert_data(test_list, args, train=False)
     train_loader = build_train_loader(train_data, args)
     test_loader = build_test_loader(test_data, args)
-    
+
     # Instantiate model
     if args.model == "stn":
         model = stn_patch16_384_gap(args.pth_tar).to(device)
     else:
         model = base_patch16_384_gap(args.pth_tar).to(device)
-    
+
     if args.use_ddp:
         model = nn.parallel.DistributedDataParallel(
-            model, 
-            device_ids=[rank], output_device=rank, 
-            find_unused_parameters=True, 
+            model,
+            device_ids=[rank], output_device=rank,
+            find_unused_parameters=True,
             gradient_as_bucket_view=True # XXX: vital, otherwise OOM
         )
 
 
     # criterion, optimizer, scheduler
     criterion = nn.L1Loss(size_average=False).to(device)
     optimizer = torch.optim.Adam(
-        [{"params": model.parameters(), "lr": args.lr}], 
-        lr=args.lr, 
+        [{"params": model.parameters(), "lr": args.lr}],
+        lr=args.lr,
         weight_decay=args.weight_decay
     )
     scheduler = torch.optim.lr_scheduler.MultiStepLR(
@@ -147,7 +148,7 @@ def worker(rank: int, args: Namespace):
     if not os.path.exists(args.save_path):
         os.makedirs(args.save_path)
 
-    if args.progress: 
+    if args.progress:
         if os.path.isfile(args.progress):
             print("=> Loading checkpoint \'{}\'".format(args.progress))
             checkpoint = torch.load(args.progress)
@@ -161,7 +162,7 @@ def worker(rank: int, args: Namespace):
                 rank, args.start_epoch, args.best_pred
             ))
 
-    # For each epoch: 
+    # For each epoch:
     for epoch in range(args.start_epoch, args.epochs):
         # Tell sampler which epoch it is
         if args.use_ddp:
@@ -183,11 +184,19 @@
             print("* best MAE {mae:.3f} *".format(mae=args.best_pred))
 
         # Save checkpoint
-        # if not args.use_ddp or torch.distributed.get_rank() == 0:
-        
+        if not args.use_ddp or torch.distributed.get_rank() == 0:
+            save_checkpoint({
+                "epoch": epoch + 1,
+                "arch": args.model,
+                "state_dict": model.state_dict(),
+                "best_prec1": args.best_pred,
+                "optimizer": optimizer.state_dict(),
+            }, is_best, args.save_path)
+
 
     # cleanup
-    torch.distributed.destroy_process_group()
+    if args.use_ddp:
+        torch.distributed.destroy_process_group()
 
 
 def train_one_epoch(
@@ -197,7 +206,7 @@ def train_one_epoch(
     optimizer,
     scheduler,
     epoch: int,
-    device, 
+    device,
     args: Namespace
 ):
     # Get learning rate
@@ -256,7 +265,7 @@ def valid_one_epoch(test_loader, model, device, args):
         with torch.no_grad():
             out = model(img)
             count = torch.sum(out).item()
-            
+
             gt_count = torch.sum(gt_count).item()
             mae += abs(gt_count - count)
             mse += abs(gt_count - count) ** 2
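
Note: the patch imports save_checkpoint from a new checkpoint module whose definition is not part of this diff, and is_best is assumed to be the boolean computed from the validation MAE earlier in worker (also not visible in these hunks). For review context, here is a minimal sketch of what the helper is assumed to look like, following the common PyTorch ImageNet-example pattern; the filenames checkpoint.pth.tar and model_best.pth.tar are assumptions, not taken from the patch.

# checkpoint.py -- assumed helper, not shown in this diff.
# Follows the usual pattern: always write the latest training state,
# and keep a separate copy of the best-scoring one.
import os
import shutil

import torch


def save_checkpoint(state, is_best, save_path, filename="checkpoint.pth.tar"):
    """Serialize the training state dict; copy it aside if it is the best so far."""
    checkpoint_file = os.path.join(save_path, filename)
    torch.save(state, checkpoint_file)
    if is_best:
        shutil.copyfile(checkpoint_file, os.path.join(save_path, "model_best.pth.tar"))

With this shape, the resume branch in worker can restore epoch, state_dict, best_prec1, and the optimizer state directly from torch.load(args.progress).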