This commit is contained in:
Zhengyi Chen 2024-03-03 22:19:47 +00:00
parent fc941ebaf7
commit c4905acf6d
2 changed files with 8 additions and 9 deletions

View file

@ -22,8 +22,7 @@ export TMP=/disk/scratch/${STUDENT_ID}/
source /home/${STUDENT_ID}/miniconda3/bin/activate mlp-cuda source /home/${STUDENT_ID}/miniconda3/bin/activate mlp-cuda
python train.py \ python train.py \
--debug=True \
--model='stn' \ --model='stn' \
--save_path ./save_file/ShanghaiA \ --save_path ./save_file/ShanghaiA \
--batch_size 4 \ --batch_size 4 \
# --gpus 0,1,2,3,4,5 \ --gpus 0,1,2,3,4,5

View file

@ -313,7 +313,7 @@ def valid_one_epoch(test_loader, model, device, args):
mae = mae * 1.0 / (len(test_loader) * batch_size) mae = mae * 1.0 / (len(test_loader) * batch_size)
mse = np.sqrt(mse / (len(test_loader)) * batch_size) mse = np.sqrt(mse / (len(test_loader)) * batch_size)
# nni.report_intermediate_result(mae) nni.report_intermediate_result(mae)
print("* MAE {mae:.3f} | MSE {mse:.3f} *".format( print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
mae=mae, mse=mse mae=mae, mse=mse
)) ))
@ -322,14 +322,14 @@ def valid_one_epoch(test_loader, model, device, args):
if __name__ == "__main__": if __name__ == "__main__":
# tuner_params = nni.get_next_parameter() tuner_params = nni.get_next_parameter()
# logger.debug("Generated hyperparameters: {}", tuner_params) logger.debug("Generated hyperparameters: {}", tuner_params)
# combined_params = nni.utils.merge_parameter(ret_args, tuner_params) combined_params = nni.utils.merge_parameter(ret_args, tuner_params)
if args.debug: if args.debug:
os.nice(15) os.nice(15)
combined_params = args #combined_params = args
logger.debug("Parameters: {}", combined_params) #logger.debug("Parameters: {}", combined_params)
if combined_params.use_ddp: if combined_params.use_ddp:
# Use DDP, spawn threads # Use DDP, spawn threads