This commit is contained in:
Zhengyi Chen 2024-03-02 17:52:10 +00:00
parent 6456d3878e
commit 359ed8c579
2 changed files with 29 additions and 18 deletions

View file

@ -103,10 +103,15 @@ def worker(rank: int, args: Namespace):
setup_process_group(rank, world_size)
# Setup device for one proc
if args.use_ddp:
device = torch.device(rank if torch.cuda.is_available() else "cpu")
if args.use_ddp and torch.cuda.is_available():
device = torch.device(rank)
elif torch.cuda.is_available():
device = torch.device(args.gpus)
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device(args.gpus if torch.cuda.is_available() else "cpu")
print("[!!!] Using CPU for inference. This will be slow...")
device = torch.device("cpu")
torch.set_default_device(device)
# Prepare training data
@ -130,6 +135,11 @@ def worker(rank: int, args: Namespace):
find_unused_parameters=True,
gradient_as_bucket_view=True # XXX: vital, otherwise OOM
)
else:
model = nn.DataParallel(
model,
device_ids=args.gpus
)
# criterion, optimizer, scheduler
criterion = nn.L1Loss(size_average=False).to(device)
@ -148,13 +158,13 @@ def worker(rank: int, args: Namespace):
if not os.path.exists(args.save_path):
os.makedirs(args.save_path)
if args.progress:
if os.path.isfile(args.progress):
print("=> Loading checkpoint \'{}\'".format(args.progress))
checkpoint = torch.load(args.progress)
if args.resume:
if os.path.isfile(args.resume):
print("=> Loading checkpoint \'{}\'".format(args.resume))
checkpoint = torch.load(args.resume)
model.load_state_dict(checkpoint['state_dict'], strict=False)
args.start_epoch = checkpoint['epoch']
args.best_pred = checkpoint['best_prec1']
args.best_pred = checkpoint['best_prec1'] # sic.
else:
print("=> Checkpoint not found!")
@ -187,7 +197,7 @@ def worker(rank: int, args: Namespace):
if not args.use_ddp or torch.distributed.get_rank() == 0:
save_checkpoint({
"epoch": epoch + 1,
"arch": args.progress,
"arch": args.resume,
"state_dict": model.state_dict(),
"best_prec1": args.best_pred,
"optimizer": optimizer.state_dict(),
@ -298,7 +308,8 @@ if __name__ == "__main__":
# Use DDP, spawn threads
torch_mp.spawn(
worker,
args=(combined_params, ), # rank supplied automatically as 1st param
args=(combined_params, ), # rank supplied at callee as 1st param
# also above *has* to be 1-tuple else runtime expands Namespace.
nprocs=combined_params.world_size,
)
else: