Test
This commit is contained in:
parent
6456d3878e
commit
359ed8c579
2 changed files with 29 additions and 18 deletions
16
arguments.py
16
arguments.py
|
|
@ -11,6 +11,9 @@ parser.add_argument(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Data configuration =========================================================
|
# Data configuration =========================================================
|
||||||
|
parser.add_argument(
|
||||||
|
"--worker", type=int, default=4, help="Number of data loader processes"
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--train_dataset", type=str, default="ShanghaiA", help="Training dataset"
|
"--train_dataset", type=str, default="ShanghaiA", help="Training dataset"
|
||||||
)
|
)
|
||||||
|
|
@ -28,19 +31,16 @@ parser.add_argument(
|
||||||
"--save_path", type=str, default="./save/default/",
|
"--save_path", type=str, default="./save/default/",
|
||||||
help="Directory to save checkpoints in"
|
help="Directory to save checkpoints in"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--resume", type=str, default=None, help="Path to checkpoint-ed model pth"
|
||||||
|
)
|
||||||
|
|
||||||
# Model configuration ========================================================
|
# Model configuration ========================================================
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--load_revnet_from", type=str, default=None,
|
"--model", type=str, default="base", help="Model variant: <base|stn>"
|
||||||
help="Pre-trained reverse perspective model path"
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--load_csrnet_from", type=str, default=None,
|
"--pth_tar", type=str, default=None, help="Path to pre-training model pth"
|
||||||
help="Pre-trained CSRNet model path"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--load_transcrowd_from", type=str, default=None,
|
|
||||||
help="Pre-trained TransCrowd model path"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Optimizer configuration ====================================================
|
# Optimizer configuration ====================================================
|
||||||
|
|
|
||||||
31
train.py
31
train.py
|
|
@ -103,10 +103,15 @@ def worker(rank: int, args: Namespace):
|
||||||
setup_process_group(rank, world_size)
|
setup_process_group(rank, world_size)
|
||||||
|
|
||||||
# Setup device for one proc
|
# Setup device for one proc
|
||||||
if args.use_ddp:
|
if args.use_ddp and torch.cuda.is_available():
|
||||||
device = torch.device(rank if torch.cuda.is_available() else "cpu")
|
device = torch.device(rank)
|
||||||
|
elif torch.cuda.is_available():
|
||||||
|
device = torch.device(args.gpus)
|
||||||
|
elif torch.backends.mps.is_available():
|
||||||
|
device = torch.device("mps")
|
||||||
else:
|
else:
|
||||||
device = torch.device(args.gpus if torch.cuda.is_available() else "cpu")
|
print("[!!!] Using CPU for inference. This will be slow...")
|
||||||
|
device = torch.device("cpu")
|
||||||
torch.set_default_device(device)
|
torch.set_default_device(device)
|
||||||
|
|
||||||
# Prepare training data
|
# Prepare training data
|
||||||
|
|
@ -130,6 +135,11 @@ def worker(rank: int, args: Namespace):
|
||||||
find_unused_parameters=True,
|
find_unused_parameters=True,
|
||||||
gradient_as_bucket_view=True # XXX: vital, otherwise OOM
|
gradient_as_bucket_view=True # XXX: vital, otherwise OOM
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
model = nn.DataParallel(
|
||||||
|
model,
|
||||||
|
device_ids=args.gpus
|
||||||
|
)
|
||||||
|
|
||||||
# criterion, optimizer, scheduler
|
# criterion, optimizer, scheduler
|
||||||
criterion = nn.L1Loss(size_average=False).to(device)
|
criterion = nn.L1Loss(size_average=False).to(device)
|
||||||
|
|
@ -148,13 +158,13 @@ def worker(rank: int, args: Namespace):
|
||||||
if not os.path.exists(args.save_path):
|
if not os.path.exists(args.save_path):
|
||||||
os.makedirs(args.save_path)
|
os.makedirs(args.save_path)
|
||||||
|
|
||||||
if args.progress:
|
if args.resume:
|
||||||
if os.path.isfile(args.progress):
|
if os.path.isfile(args.resume):
|
||||||
print("=> Loading checkpoint \'{}\'".format(args.progress))
|
print("=> Loading checkpoint \'{}\'".format(args.resume))
|
||||||
checkpoint = torch.load(args.progress)
|
checkpoint = torch.load(args.resume)
|
||||||
model.load_state_dict(checkpoint['state_dict'], strict=False)
|
model.load_state_dict(checkpoint['state_dict'], strict=False)
|
||||||
args.start_epoch = checkpoint['epoch']
|
args.start_epoch = checkpoint['epoch']
|
||||||
args.best_pred = checkpoint['best_prec1']
|
args.best_pred = checkpoint['best_prec1'] # sic.
|
||||||
else:
|
else:
|
||||||
print("=> Checkpoint not found!")
|
print("=> Checkpoint not found!")
|
||||||
|
|
||||||
|
|
@ -187,7 +197,7 @@ def worker(rank: int, args: Namespace):
|
||||||
if not args.use_ddp or torch.distributed.get_rank() == 0:
|
if not args.use_ddp or torch.distributed.get_rank() == 0:
|
||||||
save_checkpoint({
|
save_checkpoint({
|
||||||
"epoch": epoch + 1,
|
"epoch": epoch + 1,
|
||||||
"arch": args.progress,
|
"arch": args.resume,
|
||||||
"state_dict": model.state_dict(),
|
"state_dict": model.state_dict(),
|
||||||
"best_prec1": args.best_pred,
|
"best_prec1": args.best_pred,
|
||||||
"optimizer": optimizer.state_dict(),
|
"optimizer": optimizer.state_dict(),
|
||||||
|
|
@ -298,7 +308,8 @@ if __name__ == "__main__":
|
||||||
# Use DDP, spawn threads
|
# Use DDP, spawn threads
|
||||||
torch_mp.spawn(
|
torch_mp.spawn(
|
||||||
worker,
|
worker,
|
||||||
args=(combined_params, ), # rank supplied automatically as 1st param
|
args=(combined_params, ), # rank supplied at callee as 1st param
|
||||||
|
# also above *has* to be 1-tuple else runtime expands Namespace.
|
||||||
nprocs=combined_params.world_size,
|
nprocs=combined_params.world_size,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue