Test
This commit is contained in:
parent
6456d3878e
commit
359ed8c579
2 changed files with 29 additions and 18 deletions
31
train.py
31
train.py
|
|
@ -103,10 +103,15 @@ def worker(rank: int, args: Namespace):
|
|||
setup_process_group(rank, world_size)
|
||||
|
||||
# Setup device for one proc
|
||||
if args.use_ddp:
|
||||
device = torch.device(rank if torch.cuda.is_available() else "cpu")
|
||||
if args.use_ddp and torch.cuda.is_available():
|
||||
device = torch.device(rank)
|
||||
elif torch.cuda.is_available():
|
||||
device = torch.device(args.gpus)
|
||||
elif torch.backends.mps.is_available():
|
||||
device = torch.device("mps")
|
||||
else:
|
||||
device = torch.device(args.gpus if torch.cuda.is_available() else "cpu")
|
||||
print("[!!!] Using CPU for inference. This will be slow...")
|
||||
device = torch.device("cpu")
|
||||
torch.set_default_device(device)
|
||||
|
||||
# Prepare training data
|
||||
|
|
@ -130,6 +135,11 @@ def worker(rank: int, args: Namespace):
|
|||
find_unused_parameters=True,
|
||||
gradient_as_bucket_view=True # XXX: vital, otherwise OOM
|
||||
)
|
||||
else:
|
||||
model = nn.DataParallel(
|
||||
model,
|
||||
device_ids=args.gpus
|
||||
)
|
||||
|
||||
# criterion, optimizer, scheduler
|
||||
criterion = nn.L1Loss(size_average=False).to(device)
|
||||
|
|
@ -148,13 +158,13 @@ def worker(rank: int, args: Namespace):
|
|||
if not os.path.exists(args.save_path):
|
||||
os.makedirs(args.save_path)
|
||||
|
||||
if args.progress:
|
||||
if os.path.isfile(args.progress):
|
||||
print("=> Loading checkpoint \'{}\'".format(args.progress))
|
||||
checkpoint = torch.load(args.progress)
|
||||
if args.resume:
|
||||
if os.path.isfile(args.resume):
|
||||
print("=> Loading checkpoint \'{}\'".format(args.resume))
|
||||
checkpoint = torch.load(args.resume)
|
||||
model.load_state_dict(checkpoint['state_dict'], strict=False)
|
||||
args.start_epoch = checkpoint['epoch']
|
||||
args.best_pred = checkpoint['best_prec1']
|
||||
args.best_pred = checkpoint['best_prec1'] # sic.
|
||||
else:
|
||||
print("=> Checkpoint not found!")
|
||||
|
||||
|
|
@ -187,7 +197,7 @@ def worker(rank: int, args: Namespace):
|
|||
if not args.use_ddp or torch.distributed.get_rank() == 0:
|
||||
save_checkpoint({
|
||||
"epoch": epoch + 1,
|
||||
"arch": args.progress,
|
||||
"arch": args.resume,
|
||||
"state_dict": model.state_dict(),
|
||||
"best_prec1": args.best_pred,
|
||||
"optimizer": optimizer.state_dict(),
|
||||
|
|
@ -298,7 +308,8 @@ if __name__ == "__main__":
|
|||
# Use DDP, spawn threads
|
||||
torch_mp.spawn(
|
||||
worker,
|
||||
args=(combined_params, ), # rank supplied automatically as 1st param
|
||||
args=(combined_params, ), # rank supplied at callee as 1st param
|
||||
# also above *has* to be 1-tuple else runtime expands Namespace.
|
||||
nprocs=combined_params.world_size,
|
||||
)
|
||||
else:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue