Zhengyi Chen 2024-03-02 17:52:10 +00:00
parent 6456d3878e
commit 359ed8c579
2 changed files with 29 additions and 18 deletions


@@ -11,6 +11,9 @@ parser.add_argument(
 )
 # Data configuration =========================================================
+parser.add_argument(
+    "--worker", type=int, default=4, help="Number of data loader processes"
+)
 parser.add_argument(
     "--train_dataset", type=str, default="ShanghaiA", help="Training dataset"
 )
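
Note: the new --worker flag presumably feeds the num_workers argument of
torch.utils.data.DataLoader; a minimal sketch of that wiring (the
TensorDataset stand-in and batch size below are illustrative, not from this
repo):

    import argparse
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--worker", type=int, default=4, help="Number of data loader processes"
    )
    args = parser.parse_args([])  # [] -> use the defaults, for demonstration

    # Stand-in data; the real crowd-counting dataset is not shown in this diff
    dataset = TensorDataset(torch.randn(8, 3), torch.randn(8, 1))
    loader = DataLoader(dataset, batch_size=4, num_workers=args.worker)
    for xb, yb in loader:
        pass  # batches are produced by args.worker background processes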
@@ -28,19 +31,16 @@ parser.add_argument(
     "--save_path", type=str, default="./save/default/",
     help="Directory to save checkpoints in"
 )
+parser.add_argument(
+    "--resume", type=str, default=None, help="Path to checkpointed model .pth"
+)
 # Model configuration ========================================================
 parser.add_argument(
-    "--load_revnet_from", type=str, default=None,
-    help="Pre-trained reverse perspective model path"
+    "--model", type=str, default="base", help="Model variant: <base|stn>"
 )
 parser.add_argument(
-    "--load_csrnet_from", type=str, default=None,
-    help="Pre-trained CSRNet model path"
-)
-parser.add_argument(
-    "--load_transcrowd_from", type=str, default=None,
-    help="Pre-trained TransCrowd model path"
+    "--pth_tar", type=str, default=None, help="Path to pre-trained model .pth"
 )
 # Optimizer configuration ====================================================
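
Note: this hunk replaces three per-backbone load flags with a single --model
selector plus one --pth_tar path. A sketch of the dispatch such a flag
typically drives; both variant constructors below are hypothetical
stand-ins, since the real "base"/"stn" models are not in this diff:

    import torch.nn as nn

    def build_model(name: str) -> nn.Module:
        variants = {
            "base": lambda: nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()),
            "stn": lambda: nn.Sequential(nn.Conv2d(3, 8, 3), nn.Tanh()),
        }
        if name not in variants:
            raise ValueError(f"Unknown model variant: {name!r}")
        return variants[name]()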


@@ -103,10 +103,15 @@ def worker(rank: int, args: Namespace):
     setup_process_group(rank, world_size)
     # Set up the device for this process
-    if args.use_ddp:
-        device = torch.device(rank if torch.cuda.is_available() else "cpu")
+    if args.use_ddp and torch.cuda.is_available():
+        device = torch.device(rank)
+    elif torch.cuda.is_available():
+        device = torch.device(args.gpus)
+    elif torch.backends.mps.is_available():
+        device = torch.device("mps")
     else:
-        device = torch.device(args.gpus if torch.cuda.is_available() else "cpu")
+        print("[!!!] Using CPU for training. This will be slow...")
+        device = torch.device("cpu")
     torch.set_default_device(device)
     # Prepare training data
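
Note: the same device-fallback chain as a standalone helper, for
readability. Field names mirror the flags above, and args.gpus is assumed to
be a single device index in this branch:

    import torch

    def pick_device(rank: int, use_ddp: bool, gpus: int) -> torch.device:
        if use_ddp and torch.cuda.is_available():
            return torch.device(rank)   # one CUDA device per DDP rank
        if torch.cuda.is_available():
            return torch.device(gpus)   # single-process CUDA run
        if torch.backends.mps.is_available():
            return torch.device("mps")  # Apple-silicon GPU fallback
        print("[!!!] Using CPU for training. This will be slow...")
        return torch.device("cpu")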
@@ -130,6 +135,11 @@ def worker(rank: int, args: Namespace):
             find_unused_parameters=True,
             gradient_as_bucket_view=True  # XXX: vital, otherwise OOM
         )
+    else:
+        model = nn.DataParallel(
+            model,
+            device_ids=args.gpus
+        )
     # criterion, optimizer, scheduler
     criterion = nn.L1Loss(size_average=False).to(device)
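
Note: a self-contained sketch of the wrap added here. DDP assumes the
process group is already initialized, and size_average=False in the context
line above is deprecated; reduction="sum" is its modern spelling:

    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    def wrap_model(model: nn.Module, use_ddp: bool, device, gpu_ids):
        model = model.to(device)
        if use_ddp:  # requires torch.distributed.init_process_group() first
            return DDP(
                model,
                find_unused_parameters=True,
                gradient_as_bucket_view=True,  # grads view bucket memory, saving a copy
            )
        # Single-process multi-GPU fallback; DDP is preferred for new code
        return nn.DataParallel(model, device_ids=gpu_ids)

    criterion = nn.L1Loss(reduction="sum")  # equivalent to size_average=False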
@@ -148,13 +158,13 @@ def worker(rank: int, args: Namespace):
     if not os.path.exists(args.save_path):
         os.makedirs(args.save_path)
-    if args.progress:
-        if os.path.isfile(args.progress):
-            print("=> Loading checkpoint '{}'".format(args.progress))
-            checkpoint = torch.load(args.progress)
+    if args.resume:
+        if os.path.isfile(args.resume):
+            print("=> Loading checkpoint '{}'".format(args.resume))
+            checkpoint = torch.load(args.resume)
             model.load_state_dict(checkpoint['state_dict'], strict=False)
             args.start_epoch = checkpoint['epoch']
-            args.best_pred = checkpoint['best_prec1']
+            args.best_pred = checkpoint['best_prec1']  # sic.
         else:
             print("=> Checkpoint not found!")
@@ -187,7 +197,7 @@ def worker(rank: int, args: Namespace):
         if not args.use_ddp or torch.distributed.get_rank() == 0:
             save_checkpoint({
                 "epoch": epoch + 1,
-                "arch": args.progress,
+                "arch": args.resume,
                 "state_dict": model.state_dict(),
                 "best_prec1": args.best_pred,
                 "optimizer": optimizer.state_dict(),
@@ -298,7 +308,8 @@ if __name__ == "__main__":
         # Use DDP, spawn worker processes
         torch_mp.spawn(
             worker,
-            args=(combined_params, ),  # rank supplied automatically as 1st param
+            args=(combined_params, ),  # rank is supplied to the callee as the 1st param
+            # NB: args above *has* to be a 1-tuple, else spawn unpacks the Namespace.
             nprocs=combined_params.world_size,
         )
     else:
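
Note: why the 1-tuple matters: torch.multiprocessing.spawn calls
fn(rank, *args), so anything passed as args is star-unpacked rather than
handed over whole. A minimal runnable demo (field names are illustrative):

    from argparse import Namespace
    import torch.multiprocessing as torch_mp

    def worker(rank: int, args: Namespace):
        print(f"rank {rank} sees train_dataset={args.train_dataset}")

    if __name__ == "__main__":
        params = Namespace(train_dataset="ShanghaiA", world_size=2)
        torch_mp.spawn(
            worker,
            args=(params,),  # 1-tuple: each worker receives (rank, params)
            nprocs=params.world_size,
        )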