86 lines
2.7 KiB
Python
86 lines
2.7 KiB
Python
import argparse
|
|
from typing import List
|
|
|
|
def _str2bool(value):
    """Convert a command-line token into a real ``bool``.

    ``type=bool`` does not work with argparse: the converter receives the
    raw string, and every non-empty string (including ``"False"`` and
    ``"0"``) is truthy.  This helper maps the usual spellings explicitly.

    Args:
        value: The raw token from the command line (a ``bool`` is passed
            through unchanged).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised
            boolean spelling; argparse turns this into a usage error.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("1", "true", "t", "yes", "y", "on"):
        return True
    # The empty string stays False, matching what bool("") returned before.
    if token in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got %r" % (value,)
    )


# Command-line interface definition.  Every option has a default, so the
# parser accepts an empty argument list.
parser = argparse.ArgumentParser(
    description="Reverse-perspective + (TransCrowd | CSRNet)"
)

# Reproducibility configuration ==============================================
parser.add_argument(
    "--seed", type=int, default=None, help="RNG seed"
)

# Data configuration =========================================================
parser.add_argument(
    "--worker", type=int, default=4, help="Number of data loader processes"
)
parser.add_argument(
    "--train_dataset", type=str, default="ShanghaiA", help="Training dataset"
)
parser.add_argument(
    "--test_dataset", type=str, default="ShanghaiA", help="Evaluation dataset"
)
parser.add_argument(
    "--print_freq", type=int, default=1,
    help="Print evaluation data per <print-freq> training epochs"
)
parser.add_argument(
    "--start_epoch", type=int, default=0, help="Epoch to start training from"
)
parser.add_argument(
    "--save_path", type=str, default="./save/default/",
    help="Directory to save checkpoints in"
)
parser.add_argument(
    "--resume", type=str, default=None, help="Path to checkpoint-ed model pth"
)

# Model configuration ========================================================
parser.add_argument(
    "--model", type=str, default="base", help="Model variant: <base|stn>"
)
parser.add_argument(
    "--pth_tar", type=str, default=None, help="Path to pre-training model pth"
)

# Optimizer configuration ====================================================
parser.add_argument(
    "--weight_decay", type=float, default=5e-4, help="Weight decay"
)
parser.add_argument(
    "--momentum", type=float, default=0.95, help="Momentum"
)
parser.add_argument(
    "--best_pred", type=float, default=1e5,
    help="Best prediction (MAE/MSE etc.)"
)

# Performance configuration ==================================================
parser.add_argument(
    "--batch_size", type=int, default=8, help="Number of images per batch"
)
parser.add_argument(
    "--epochs", type=int, default=250, help="Number of epochs to train"
)
# argparse applies `type` to each token separately, so a list of ints is
# expressed as type=int with nargs="+" (e.g. `--gpus 0 1 2`).  A typing
# alias such as List[int] is not callable as a converter.
parser.add_argument(
    "--gpus", type=int, nargs="+", default=[0],
    help="GPU IDs to be made available for training runtime"
)

# Runtime configuration ======================================================
# nargs="?" with const=True lets a bare `--use_ddp` enable the option while
# `--use_ddp True` / `--use_ddp False` keep working and are parsed correctly.
parser.add_argument(
    "--use_ddp", type=_str2bool, nargs="?", const=True, default=False,
    help="Use DistributedDataParallel training"
)
parser.add_argument(
    "--ddp_world_size", type=int, default=1,
    help="DDP: Number of processes in Pytorch process group"
)

# nni configuration ==========================================================
parser.add_argument(
    "--lr", type=float, default=1e-5, help="Learning rate"
)
|
|
|
|
# Parse the process command line twice, left to right, so that `args` and
# `ret_args` are bound to two separate Namespace objects built from the
# same arguments.
args, ret_args = parser.parse_args(), parser.parse_args()
|