This commit is contained in:
Zhengyi Chen 2024-03-03 22:19:47 +00:00
parent fc941ebaf7
commit c4905acf6d
2 changed files with 8 additions and 9 deletions

View file

@ -313,7 +313,7 @@ def valid_one_epoch(test_loader, model, device, args):
mae = mae * 1.0 / (len(test_loader) * batch_size)
mse = np.sqrt(mse / (len(test_loader)) * batch_size)
# nni.report_intermediate_result(mae)
nni.report_intermediate_result(mae)
print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
mae=mae, mse=mse
))
@ -322,14 +322,14 @@ def valid_one_epoch(test_loader, model, device, args):
if __name__ == "__main__":
# tuner_params = nni.get_next_parameter()
# logger.debug("Generated hyperparameters: {}", tuner_params)
# combined_params = nni.utils.merge_parameter(ret_args, tuner_params)
tuner_params = nni.get_next_parameter()
logger.debug("Generated hyperparameters: {}", tuner_params)
combined_params = nni.utils.merge_parameter(ret_args, tuner_params)
if args.debug:
os.nice(15)
combined_params = args
logger.debug("Parameters: {}", combined_params)
#combined_params = args
#logger.debug("Parameters: {}", combined_params)
if combined_params.use_ddp:
# Use DDP, spawn threads
@ -341,4 +341,4 @@ if __name__ == "__main__":
)
else:
# No DDP, run in current thread
worker(0, combined_params)
worker(0, combined_params)