diff --git a/visualize.py b/visualize.py new file mode 100644 index 0000000..2afd6ec --- /dev/null +++ b/visualize.py @@ -0,0 +1,72 @@ +import pandas as pd +import os +import matplotlib.pyplot as plt + +cur_dir = os.getcwd() +#relative_path = 'out-pack/1784598-stn-newloss-shb/test_stats.h5' +relative_paths = ['out-pack/filter_test_base-sha.h5', 'out-pack/filter_test_oldloss-sha.h5', 'out-pack/filter_test_newloss-sha.h5', 'out-pack/filter_test_base-shb.h5','out-pack/filter_test_newloss-shb.h5','out-pack/filter_test_oldloss-shb.h5'] +labels = ['base', 'oldloss', 'newloss'] +first_group = relative_paths[:3] +second_group =relative_paths[3:] +# Plot data +fig, axs = plt.subplots(2, 1, figsize=(10, 8)) #Create subplots +for relative_path, label in zip(first_group, labels): + file_name = os.path.join(cur_dir, relative_path) + data = pd.read_hdf(file_name, key='df') + # Plot mse with label as file name + axs[0].plot(data.index, data['mse'], linestyle='-', label=label) + # plot mae with label as file name + axs[1].plot(data.index, data['mae'], linestyle='-', label=label) + +# Set labels foe x and y axes for each subplot +axs[0].set_ylabel('MSE') +axs[0].set_xlabel('Epochs') +axs[1].set_ylabel('MAE') +axs[1].set_xlabel('Epochs') + +# Add Legend +axs[0].legend() +axs[1].legend() + +# Set title for the entire plot +plt.suptitle("Shangai-A dataset") + +# Adjust layout to prevent overlapping +plt.tight_layout() + +# Save the graph +plt.savefig('Shangai-A_datasets.png') + +# Plot data +fig, axs = plt.subplots(2, 1, figsize=(10, 8)) #Create subplots +for relative_path, label in zip(second_group, labels): + file_name = os.path.join(cur_dir, relative_path) + data = pd.read_hdf(file_name, key='df') + # Plot mse with label as file name + axs[0].plot(data.index, data['mse'], linestyle='-', label=label) + # plot mae with label as file name + axs[1].plot(data.index, data['mae'], linestyle='-', label=label) + +# Set labels foe x and y axes for each subplot +axs[0].set_ylabel('MSE') +axs[0].set_xlabel('Epochs') +axs[1].set_ylabel('MAE') +axs[1].set_xlabel('Epochs') + +# Add Legend +axs[0].legend() +axs[1].legend() + +# Set title for the entire plot +plt.suptitle("Shangai-B dataset") + +# Adjust layout to prevent overlapping +plt.tight_layout() + +# Save the graph +plt.savefig('Shangai-B_dataset.png') + +# Show plot +plt.show() + +