lemme cook alright?
This commit is contained in:
parent
b6d2460060
commit
62df7464e4
9 changed files with 504 additions and 3 deletions
86
arguments.py
Normal file
86
arguments.py
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
import argparse
|
||||
from typing import List
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description = "Reverse-perspective + (TransCrowd | CSRNet)"
|
||||
)
|
||||
|
||||
# Reproducibility configuration ==============================================
|
||||
parser.add_argument(
|
||||
"--seed", type=int, default=None, help="RNG seed"
|
||||
)
|
||||
|
||||
# Data configuration =========================================================
|
||||
parser.add_argument(
|
||||
"--train_dataset", type=str, default="ShanghaiA", help="Training dataset"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test_dataset", type=str, default="ShanghaiA", help="Evaluation dataset"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--print_freq", type=int, default=1,
|
||||
help="Print evaluation data per <print-freq> training epochs"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--start_epoch", type=int, default=0, help="Epoch to start training from"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_path", type=str, default="./save/default/",
|
||||
help="Directory to save checkpoints in"
|
||||
)
|
||||
|
||||
# Model configuration ========================================================
|
||||
parser.add_argument(
|
||||
"--load_revnet_from", type=str, default=None,
|
||||
help="Pre-trained reverse perspective model path"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_csrnet_from", type=str, default=None,
|
||||
help="Pre-trained CSRNet model path"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_transcrowd_from", type=str, default=None,
|
||||
help="Pre-trained TransCrowd model path"
|
||||
)
|
||||
|
||||
# Optimizer configuration ====================================================
|
||||
parser.add_argument(
|
||||
"--weight_decay", type=float, default=5e-4, help="Weight decay"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--momentum", type=float, default=0.95, help="Momentum"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--best_pred", type=float, default=1e5,
|
||||
help="Best prediction (MAE/MSE etc.)"
|
||||
)
|
||||
|
||||
# Performance configuration ==================================================
|
||||
parser.add_argument(
|
||||
"--batch_size", type=int, default=8, help="Number of images per batch"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--epochs", type=int, default=250, help="Number of epochs to train"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpus", type=List[int], default=[0],
|
||||
help="GPU IDs to be made available for training runtime"
|
||||
)
|
||||
|
||||
# Runtime configuration ======================================================
|
||||
parser.add_argument(
|
||||
"--use_ddp", type=bool, default=False,
|
||||
help="Use DistributedDataParallel training"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ddp_world_size", type=int, default=1,
|
||||
help="DDP: Number of processes in Pytorch process group"
|
||||
)
|
||||
|
||||
# nni configuration ==========================================================
|
||||
parser.add_argument(
|
||||
"--lr", type=float, default=1e-5, help="Learning rate"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
ret_args = parser.parse_args()
|
||||
Loading…
Add table
Add a link
Reference in a new issue