Zhengyi Chen 2024-03-02 17:52:10 +00:00
parent 6456d3878e
commit 359ed8c579
2 changed files with 29 additions and 18 deletions

View file

@@ -11,6 +11,9 @@ parser.add_argument(
 )
 # Data configuration =========================================================
+parser.add_argument(
+    "--worker", type=int, default=4, help="Number of data loader processes"
+)
 parser.add_argument(
     "--train_dataset", type=str, default="ShanghaiA", help="Training dataset"
 )
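
The new --worker flag presumably feeds DataLoader(num_workers=...). A minimal sketch of that wiring, assuming a standard PyTorch DataLoader (the dataset below is a stand-in, not from this repo):

    import argparse
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    parser = argparse.ArgumentParser()
    parser.add_argument("--worker", type=int, default=4,
                        help="Number of data loader processes")
    args = parser.parse_args(["--worker", "2"])

    # Stand-in dataset; num_workers=0 would load batches in the main process.
    dataset = TensorDataset(torch.randn(16, 3, 64, 64))
    loader = DataLoader(dataset, batch_size=4, shuffle=True,
                        num_workers=args.worker)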
@@ -28,19 +31,16 @@ parser.add_argument(
     "--save_path", type=str, default="./save/default/",
     help="Directory to save checkpoints in"
 )
+parser.add_argument(
+    "--resume", type=str, default=None, help="Path to checkpointed model .pth"
+)
 # Model configuration ========================================================
 parser.add_argument(
-    "--load_revnet_from", type=str, default=None,
-    help="Pre-trained reverse perspective model path"
+    "--model", type=str, default="base", help="Model variant: <base|stn>"
 )
-parser.add_argument(
-    "--load_csrnet_from", type=str, default=None,
-    help="Pre-trained CSRNet model path"
-)
 parser.add_argument(
-    "--load_transcrowd_from", type=str, default=None,
-    help="Pre-trained TransCrowd model path"
+    "--pth_tar", type=str, default=None, help="Path to pre-trained model .pth"
 )
 # Optimizer configuration ====================================================

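The three per-architecture load flags collapse into a generic pair: --model selects the variant and --pth_tar points at its pre-trained weights, while --resume covers mid-training checkpoints. A hedged sketch of how dispatch on --model could look (the variant registry below is illustrative; the real base/stn classes are not in this diff):

    import argparse
    import torch
    import torch.nn as nn

    # Stand-ins for the actual model variants.
    VARIANTS = {
        "base": lambda: nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()),
        "stn":  lambda: nn.Sequential(nn.Conv2d(3, 8, 3), nn.Tanh()),
    }

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="base")
    parser.add_argument("--pth_tar", type=str, default=None)
    args = parser.parse_args(["--model", "stn"])

    model = VARIANTS[args.model]()
    if args.pth_tar:  # warm-start from pre-trained weights if given
        state = torch.load(args.pth_tar, map_location="cpu")
        model.load_state_dict(state, strict=False)
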
View file

@@ -103,10 +103,15 @@ def worker(rank: int, args: Namespace):
     setup_process_group(rank, world_size)
     # Setup device for one proc
-    if args.use_ddp:
-        device = torch.device(rank if torch.cuda.is_available() else "cpu")
+    if args.use_ddp and torch.cuda.is_available():
+        device = torch.device(rank)
+    elif torch.cuda.is_available():
+        device = torch.device(args.gpus)
+    elif torch.backends.mps.is_available():
+        device = torch.device("mps")
     else:
-        device = torch.device(args.gpus if torch.cuda.is_available() else "cpu")
+        print("[!!!] Using CPU for training. This will be slow...")
+        device = torch.device("cpu")
     torch.set_default_device(device)
     # Prepare training data
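
The device pick now degrades CUDA -> MPS -> CPU rather than straight to CPU. The same ladder in isolation, as a sketch (rank and gpus carry the diff's meaning: a per-process GPU index under DDP, a fixed index otherwise):

    import torch

    def pick_device(rank: int, use_ddp: bool, gpus: int) -> torch.device:
        """CUDA first (one GPU per rank under DDP), then Apple MPS, then CPU."""
        if use_ddp and torch.cuda.is_available():
            return torch.device(rank)    # rank doubles as the CUDA index
        if torch.cuda.is_available():
            return torch.device(gpus)    # single-process CUDA index
        if torch.backends.mps.is_available():
            return torch.device("mps")   # Apple-silicon fallback
        return torch.device("cpu")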
@@ -130,6 +135,11 @@ def worker(rank: int, args: Namespace):
             find_unused_parameters=True,
             gradient_as_bucket_view=True  # XXX: vital, otherwise OOM
         )
+    else:
+        model = nn.DataParallel(
+            model,
+            device_ids=args.gpus
+        )
     # criterion, optimizer, scheduler
     criterion = nn.L1Loss(size_average=False).to(device)
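
One caveat on the new fallback: nn.DataParallel expects device_ids to be an iterable of device indices, so if args.gpus is a single int (as torch.device(args.gpus) above suggests) it would need wrapping in a list. A sketch of the branch under that assumption:

    import torch.nn as nn

    def wrap_model(model: nn.Module, use_ddp: bool, gpus) -> nn.Module:
        if use_ddp:
            # gradient_as_bucket_view shares gradient storage with the
            # communication buckets (the "otherwise OOM" note above).
            return nn.parallel.DistributedDataParallel(
                model,
                find_unused_parameters=True,
                gradient_as_bucket_view=True,
            )
        ids = [gpus] if isinstance(gpus, int) else list(gpus)
        return nn.DataParallel(model, device_ids=ids)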
@@ -148,13 +158,13 @@ def worker(rank: int, args: Namespace):
     if not os.path.exists(args.save_path):
         os.makedirs(args.save_path)
-    if args.progress:
-        if os.path.isfile(args.progress):
-            print("=> Loading checkpoint '{}'".format(args.progress))
-            checkpoint = torch.load(args.progress)
+    if args.resume:
+        if os.path.isfile(args.resume):
+            print("=> Loading checkpoint '{}'".format(args.resume))
+            checkpoint = torch.load(args.resume)
             model.load_state_dict(checkpoint['state_dict'], strict=False)
             args.start_epoch = checkpoint['epoch']
-            args.best_pred = checkpoint['best_prec1']
+            args.best_pred = checkpoint['best_prec1']  # sic.
         else:
             print("=> Checkpoint not found!")
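
Both checkpoint paths, the resume load here and the rank-0 save in the next hunk, now key off args.resume; the legacy best_prec1 key is deliberately kept (hence the "sic." comment). A round-trip sketch of that dict layout (save_checkpoint itself is not shown in this diff, so plain torch.save stands in):

    import torch

    CKPT = "checkpoint.pth.tar"  # illustrative path

    def save_checkpoint(model, optimizer, epoch, best_pred):
        torch.save({
            "epoch": epoch + 1,
            "state_dict": model.state_dict(),
            "best_prec1": best_pred,  # legacy key, kept for old checkpoints
            "optimizer": optimizer.state_dict(),
        }, CKPT)

    def resume(model, optimizer):
        ckpt = torch.load(CKPT, map_location="cpu")
        model.load_state_dict(ckpt["state_dict"], strict=False)
        optimizer.load_state_dict(ckpt["optimizer"])
        return ckpt["epoch"], ckpt["best_prec1"]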
@@ -187,7 +197,7 @@ def worker(rank: int, args: Namespace):
         if not args.use_ddp or torch.distributed.get_rank() == 0:
             save_checkpoint({
                 "epoch": epoch + 1,
-                "arch": args.progress,
+                "arch": args.resume,
                 "state_dict": model.state_dict(),
                 "best_prec1": args.best_pred,
                 "optimizer": optimizer.state_dict(),
@@ -298,7 +308,8 @@ if __name__ == "__main__":
         # Use DDP, spawn processes
         torch_mp.spawn(
             worker,
-            args=(combined_params, ),  # rank supplied automatically as 1st param
+            args=(combined_params, ),  # rank supplied at callee as 1st param
+            # also the above *has* to be a 1-tuple, else the runtime expands the Namespace
             nprocs=combined_params.world_size,
         )
     else:
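
The 1-tuple comment deserves spelling out: torch.multiprocessing.spawn(fn, args=...) invokes fn(rank, *args) in every child process, so args must be a tuple and rank never appears in it. A minimal standalone sketch:

    import torch.multiprocessing as torch_mp
    from argparse import Namespace

    def worker(rank: int, args: Namespace):
        # spawn() injects rank itself; everything in args= follows it.
        print(f"rank {rank}: lr={args.lr}")

    if __name__ == "__main__":
        params = Namespace(lr=1e-4, world_size=2)
        torch_mp.spawn(worker, args=(params,), nprocs=params.world_size)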