"""Command-line configuration for Reverse-perspective + (TransCrowd | CSRNet).

The command line (``sys.argv``) is parsed once, at import time; the result
is exposed as the module-level ``args`` and as an independent copy,
``ret_args``.
"""
import argparse
from typing import List


def _str2bool(value):
    """Convert a command-line token to ``bool``.

    ``type=bool`` must not be used for boolean options: argparse hands the raw
    string to ``bool(...)``, and every non-empty string (including ``"False"``)
    is truthy.

    Args:
        value: The token supplied on the command line (or an already-boolean
            value such as an option's ``const``).

    Returns:
        ``True`` for true/t/yes/y/1, ``False`` for false/f/no/n/0 or an empty
        string (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: If *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    # "" maps to False to match the previous bool("") behaviour.
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value (true/false), got {value!r}"
    )


parser = argparse.ArgumentParser(
    description="Reverse-perspective + (TransCrowd | CSRNet)"
)

# Reproducibility configuration ==============================================
parser.add_argument(
    "--seed", type=int, default=None,
    help="RNG seed"
)

# Data configuration =========================================================
parser.add_argument(
    "--workers", type=int, default=4,
    help="Number of data loader processes"
)
parser.add_argument(
    "--train_dataset", type=str, default="ShanghaiA",
    help="Training dataset"
)
parser.add_argument(
    "--eval_dataset", type=str, default="ShanghaiA",
    help="Evaluation dataset"
)
parser.add_argument(
    "--print_freq", type=int, default=1,
    help="Print evaluation data per training epochs"
)
parser.add_argument(
    "--start_epoch", type=int, default=0,
    help="Epoch to start training from"
)
parser.add_argument(
    "--save_path", type=str, default="./save/default/",
    help="Directory to save checkpoints in"
)
parser.add_argument(
    "--resume", type=str, default=None,
    help="Path to checkpoint-ed model pth"
)

# Model configuration ========================================================
parser.add_argument(
    "--model", type=str, default="base",
    help="Model variant: "
)
parser.add_argument(
    "--pth_tar", type=str, default=None,
    help="Path to pre-training model pth"
)

# Optimizer configuration ====================================================
parser.add_argument(
    "--weight_decay", type=float, default=5e-4,
    help="Weight decay"
)
parser.add_argument(
    "--momentum", type=float, default=0.95,
    help="Momentum"
)
parser.add_argument(
    "--best_pred", type=float, default=1e5,
    help="Best prediction (MAE/MSE etc.)"
)

# Performance configuration ==================================================
parser.add_argument(
    "--batch_size", type=int, default=8,
    help="Number of images per batch"
)
parser.add_argument(
    "--epochs", type=int, default=250,
    help="Number of epochs to train"
)
parser.add_argument(
    "--gpus", type=str, default='0',
    help="GPU IDs to be made available for training runtime"
)

# Runtime configuration ======================================================
# Boolean options accept an explicit value (e.g. `--use_ddp false`) or may be
# given bare (`--use_ddp`), which means True.
parser.add_argument(
    "--use_ddp", type=_str2bool, nargs="?", const=True, default=False,
    help="Use DistributedDataParallel training"
)
parser.add_argument(
    "--ddp_world_size", type=int, default=1,
    help="DDP: Number of processes in Pytorch process group"
)
parser.add_argument(
    "--debug", type=_str2bool, nargs="?", const=True, default=False,
    help="Debug flag; accepts true/false (also yes/no, 1/0)"
)

# nni configuration ==========================================================
parser.add_argument(
    "--lr", type=float, default=1e-5,
    help="Learning rate"
)

args = parser.parse_args()
# Independent Namespace with the same values as `args`; copying avoids parsing
# sys.argv a second time while keeping the two objects distinct, as before.
ret_args = argparse.Namespace(**vars(args))