From c4905acf6d776f8e0dd4202eae6cbea605b82049 Mon Sep 17 00:00:00 2001 From: Zhengyi Chen Date: Sun, 3 Mar 2024 22:19:47 +0000 Subject: [PATCH] Sync --- _ShanghaiA-train.sh | 3 +-- train.py | 14 +++++++------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/_ShanghaiA-train.sh b/_ShanghaiA-train.sh index 24f5e21..c885ade 100644 --- a/_ShanghaiA-train.sh +++ b/_ShanghaiA-train.sh @@ -22,8 +22,7 @@ export TMP=/disk/scratch/${STUDENT_ID}/ source /home/${STUDENT_ID}/miniconda3/bin/activate mlp-cuda python train.py \ - --debug=True \ --model='stn' \ --save_path ./save_file/ShanghaiA \ --batch_size 4 \ - # --gpus 0,1,2,3,4,5 \ \ No newline at end of file + --gpus 0,1,2,3,4,5 diff --git a/train.py b/train.py index e04b5f2..5dd5958 100644 --- a/train.py +++ b/train.py @@ -313,7 +313,7 @@ def valid_one_epoch(test_loader, model, device, args): mae = mae * 1.0 / (len(test_loader) * batch_size) mse = np.sqrt(mse / (len(test_loader)) * batch_size) - # nni.report_intermediate_result(mae) + nni.report_intermediate_result(mae) print("* MAE {mae:.3f} | MSE {mse:.3f} *".format( mae=mae, mse=mse )) @@ -322,14 +322,14 @@ def valid_one_epoch(test_loader, model, device, args): if __name__ == "__main__": - # tuner_params = nni.get_next_parameter() - # logger.debug("Generated hyperparameters: {}", tuner_params) - # combined_params = nni.utils.merge_parameter(ret_args, tuner_params) + tuner_params = nni.get_next_parameter() + logger.debug("Generated hyperparameters: {}", tuner_params) + combined_params = nni.utils.merge_parameter(ret_args, tuner_params) if args.debug: os.nice(15) - combined_params = args - logger.debug("Parameters: {}", combined_params) + #combined_params = args + #logger.debug("Parameters: {}", combined_params) if combined_params.use_ddp: # Use DDP, spawn threads @@ -341,4 +341,4 @@ if __name__ == "__main__": ) else: # No DDP, run in current thread - worker(0, combined_params) \ No newline at end of file + worker(0, combined_params)