From 359ed8c579246ab423c34859dc2dc115f01d6f7f Mon Sep 17 00:00:00 2001
From: rubberhead
Date: Sat, 2 Mar 2024 17:52:10 +0000
Subject: [PATCH] Add --worker/--resume/--model/--pth_tar args; rework device
 selection and checkpoint resume

---
 arguments.py | 16 ++++++++--------
 train.py     | 31 +++++++++++++++++++++----------
 2 files changed, 29 insertions(+), 18 deletions(-)

diff --git a/arguments.py b/arguments.py
index ef140cd..93324b5 100644
--- a/arguments.py
+++ b/arguments.py
@@ -11,6 +11,9 @@ parser.add_argument(
 )
 
 # Data configuration =========================================================
+parser.add_argument(
+    "--worker", type=int, default=4, help="Number of data loader processes"
+)
 parser.add_argument(
     "--train_dataset", type=str, default="ShanghaiA", help="Training dataset"
 )
@@ -28,19 +31,16 @@ parser.add_argument(
     "--save_path", type=str, default="./save/default/",
     help="Directory to save checkpoints in"
 )
+parser.add_argument(
+    "--resume", type=str, default=None, help="Path to checkpointed model pth"
+)
 
 # Model configuration ========================================================
 parser.add_argument(
-    "--load_revnet_from", type=str, default=None,
-    help="Pre-trained reverse perspective model path"
+    "--model", type=str, default="base", help="Model variant: "
 )
 parser.add_argument(
-    "--load_csrnet_from", type=str, default=None,
-    help="Pre-trained CSRNet model path"
-)
-parser.add_argument(
-    "--load_transcrowd_from", type=str, default=None,
-    help="Pre-trained TransCrowd model path"
+    "--pth_tar", type=str, default=None, help="Path to pre-trained model pth"
 )
 
 # Optimizer configuration ====================================================
diff --git a/train.py b/train.py
index d4ea2ed..c1b9d0e 100644
--- a/train.py
+++ b/train.py
@@ -103,10 +103,15 @@ def worker(rank: int, args: Namespace):
         setup_process_group(rank, world_size)
 
     # Setup device for one proc
-    if args.use_ddp:
-        device = torch.device(rank if torch.cuda.is_available() else "cpu")
+    if args.use_ddp and torch.cuda.is_available():
+        device = torch.device(rank)
+    elif torch.cuda.is_available():
+        device = torch.device(args.gpus)
+    elif torch.backends.mps.is_available():
+        device = torch.device("mps")
     else:
-        device = torch.device(args.gpus if torch.cuda.is_available() else "cpu")
+        print("[!!!] Using CPU for training. This will be slow...")
+        device = torch.device("cpu")
     torch.set_default_device(device)
 
     # Prepare training data
@@ -130,6 +135,11 @@
             find_unused_parameters=True,
             gradient_as_bucket_view=True # XXX: vital, otherwise OOM
         )
+    else:
+        model = nn.DataParallel(
+            model,
+            device_ids=args.gpus
+        )
 
     # criterion, optimizer, scheduler
     criterion = nn.L1Loss(size_average=False).to(device)
@@ -148,13 +158,13 @@ def worker(rank: int, args: Namespace):
     if not os.path.exists(args.save_path):
         os.makedirs(args.save_path)
 
-    if args.progress:
-        if os.path.isfile(args.progress):
-            print("=> Loading checkpoint \'{}\'".format(args.progress))
-            checkpoint = torch.load(args.progress)
+    if args.resume:
+        if os.path.isfile(args.resume):
+            print("=> Loading checkpoint \'{}\'".format(args.resume))
+            checkpoint = torch.load(args.resume)
             model.load_state_dict(checkpoint['state_dict'], strict=False)
             args.start_epoch = checkpoint['epoch']
-            args.best_pred = checkpoint['best_prec1']
+            args.best_pred = checkpoint['best_prec1'] # sic.
         else:
             print("=> Checkpoint not found!")
 
@@ -187,7 +197,7 @@ def worker(rank: int, args: Namespace):
         if not args.use_ddp or torch.distributed.get_rank() == 0:
             save_checkpoint({
                 "epoch": epoch + 1,
-                "arch": args.progress,
+                "arch": args.resume,
                 "state_dict": model.state_dict(),
                 "best_prec1": args.best_pred,
                 "optimizer": optimizer.state_dict(),
@@ -298,7 +308,8 @@ if __name__ == "__main__":
         # Use DDP, spawn threads
         torch_mp.spawn(
             worker,
-            args=(combined_params, ), # rank supplied automatically as 1st param
+            args=(combined_params, ), # rank is supplied by spawn as the 1st param
+            # NOTE: args *has* to stay a 1-tuple, else spawn unpacks the Namespace.
             nprocs=combined_params.world_size,
         )
     else:
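
Note on the device-selection ladder added in worker(): it prefers one CUDA
device per DDP rank, then a user-selected CUDA device, then MPS, then CPU.
A minimal standalone sketch of the same pattern (pick_device and gpu_index
are illustrative names, not part of this patch; the MPS probe assumes a
torch build that ships torch.backends.mps):

    import torch

    def pick_device(rank: int, use_ddp: bool, gpu_index: int) -> torch.device:
        if use_ddp and torch.cuda.is_available():
            return torch.device(rank)       # one CUDA device per DDP rank
        if torch.cuda.is_available():
            return torch.device(gpu_index)  # single process, user-chosen GPU
        if torch.backends.mps.is_available():
            return torch.device("mps")      # Apple Silicon fallback
        print("[!!!] Using CPU for training. This will be slow...")
        return torch.device("cpu")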
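
The non-DDP branch now falls back to nn.DataParallel. device_ids expects a
list of CUDA ordinals, so this assumes --gpus parses to a list on that code
path. A minimal sketch of the wrapping (a toy nn.Linear stands in for the
real model):

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 1)  # stand-in for the real model
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        # DataParallel scatters inputs on dim 0 across the listed GPUs and
        # gathers outputs back on device_ids[0] on each forward pass.
        model = nn.DataParallel(model, device_ids=[0, 1])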
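
The --resume path pairs with save_checkpoint through the keys visible in the
hunks above ("epoch", "arch", "state_dict", "best_prec1", "optimizer"). A
minimal sketch of the round trip (toy model and optimizer; map_location is
an assumption for loading a GPU checkpoint on a CPU-only box):

    import os
    import torch
    import torch.nn as nn

    model = nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)

    # Save: mirrors the save_checkpoint payload in train.py.
    torch.save({
        "epoch": 1,
        "arch": None,
        "state_dict": model.state_dict(),
        "best_prec1": 0.0,
        "optimizer": optimizer.state_dict(),
    }, "checkpoint.pth")

    # Resume: mirrors the --resume branch; strict=False tolerates key drift.
    if os.path.isfile("checkpoint.pth"):
        ckpt = torch.load("checkpoint.pth", map_location="cpu")
        model.load_state_dict(ckpt["state_dict"], strict=False)
        start_epoch = ckpt["epoch"]
        best_pred = ckpt["best_prec1"]  # key name kept as-is, per the sic.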
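
On the spawn comment: torch.multiprocessing.spawn invokes fn(rank, *args),
so args must be a tuple; wrapping the Namespace as (combined_params, ) keeps
it from being unpacked. A minimal runnable sketch (the world_size field is
hypothetical, mirroring the Namespace usage above):

    import torch.multiprocessing as torch_mp
    from argparse import Namespace

    def worker(rank: int, args: Namespace):
        # spawn prepends the process index, so rank arrives as the 1st param.
        print(f"rank={rank}, world_size={args.world_size}")

    if __name__ == "__main__":
        params = Namespace(world_size=2)
        # args must stay a 1-tuple: spawn calls worker(rank, *args).
        torch_mp.spawn(worker, args=(params, ), nprocs=params.world_size)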