mlp-project/arguments.py
2024-02-25 21:03:32 +00:00

86 lines
2.6 KiB
Python

import argparse
from typing import List
# Command-line options for the reverse-perspective + (TransCrowd | CSRNet)
# runs. The parser object is built at module import time.


def _str2bool(value: str) -> bool:
    """Convert a command-line string such as ``"true"`` or ``"0"`` to a bool.

    ``type=bool`` is unusable with argparse: ``bool("False")`` is ``True``
    because any non-empty string is truthy, so ``--use_ddp False`` would
    silently enable the option.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean
            spelling; argparse reports this as a usage error.
    """
    lowered = value.strip().lower()
    if lowered in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if lowered in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean (true/false, yes/no, 1/0), got {value!r}"
    )


parser = argparse.ArgumentParser(
    description="Reverse-perspective + (TransCrowd | CSRNet)"
)
# Reproducibility configuration ==============================================
parser.add_argument(
    "--seed", type=int, default=None, help="RNG seed"
)
# Data configuration =========================================================
parser.add_argument(
    "--train_dataset", type=str, default="ShanghaiA", help="Training dataset"
)
parser.add_argument(
    "--test_dataset", type=str, default="ShanghaiA", help="Evaluation dataset"
)
parser.add_argument(
    "--print_freq", type=int, default=1,
    help="Print evaluation data every <print_freq> training epochs"
)
parser.add_argument(
    "--start_epoch", type=int, default=0, help="Epoch to start training from"
)
parser.add_argument(
    "--save_path", type=str, default="./save/default/",
    help="Directory to save checkpoints in"
)
# Model configuration ========================================================
parser.add_argument(
    "--load_revnet_from", type=str, default=None,
    help="Pre-trained reverse perspective model path"
)
parser.add_argument(
    "--load_csrnet_from", type=str, default=None,
    help="Pre-trained CSRNet model path"
)
parser.add_argument(
    "--load_transcrowd_from", type=str, default=None,
    help="Pre-trained TransCrowd model path"
)
# Optimizer configuration ====================================================
parser.add_argument(
    "--weight_decay", type=float, default=5e-4, help="Weight decay"
)
parser.add_argument(
    "--momentum", type=float, default=0.95, help="Momentum"
)
parser.add_argument(
    "--best_pred", type=float, default=1e5,
    help="Best prediction (MAE/MSE etc.)"
)
# Performance configuration ==================================================
parser.add_argument(
    "--batch_size", type=int, default=8, help="Number of images per batch"
)
parser.add_argument(
    "--epochs", type=int, default=250, help="Number of epochs to train"
)
parser.add_argument(
    # One or more integer IDs, e.g. ``--gpus 0 1``. The previous
    # ``type=List[int]`` could never work: argparse calls ``type`` on each
    # string and ``typing.List`` cannot be instantiated, so every explicit
    # ``--gpus`` value was rejected.
    "--gpus", type=int, nargs="+", default=[0],
    help="GPU IDs to be made available for training runtime"
)
# Runtime configuration ======================================================
parser.add_argument(
    # Accepts a bare ``--use_ddp`` (-> True) as well as an explicit
    # ``--use_ddp true|false``: existing ``--use_ddp True`` invocations keep
    # working, while ``--use_ddp False`` now really means False.
    "--use_ddp", type=_str2bool, nargs="?", const=True, default=False,
    help="Use DistributedDataParallel training"
)
parser.add_argument(
    "--ddp_world_size", type=int, default=1,
    help="DDP: Number of processes in Pytorch process group"
)
# nni configuration ==========================================================
parser.add_argument(
    "--lr", type=float, default=1e-5, help="Learning rate"
)
# Parse the command line exactly once. ``ret_args`` is kept as a separate
# Namespace holding the same values (a shallow copy of ``args``), presumably
# because other modules import it under that name; rebinding an attribute on
# one Namespace does not affect the other. Previously the whole command line
# was re-parsed just to build this second object.
args = parser.parse_args()
ret_args = argparse.Namespace(**vars(args))