"""Plot per-epoch MSE/MAE curves for the filter-test runs on Shanghai A and B.

Each ``.h5`` file is read with ``pd.read_hdf(..., key='df')``. Its index is
plotted on the x axis (labelled "Epochs"), and its ``mse`` and ``mae``
columns are plotted on the y axis.
"""
import os

import matplotlib.pyplot as plt
import pandas as pd

# Paths below are resolved against the current working directory, not
# against the location of this script.
cur_dir = os.getcwd()

# Both groups MUST list their files in the same order as `labels`, because
# each group is zipped positionally with `labels` when plotting.
relative_paths = [
    'out-pack/filter_test_base-sha.h5',
    'out-pack/filter_test_oldloss-sha.h5',
    'out-pack/filter_test_newloss-sha.h5',
    'out-pack/filter_test_base-shb.h5',
    'out-pack/filter_test_oldloss-shb.h5',
    'out-pack/filter_test_newloss-shb.h5',
]
labels = ['base', 'oldloss', 'newloss']
first_group = relative_paths[:3]   # Shanghai-A ("-sha") runs
second_group = relative_paths[3:]  # Shanghai-B ("-shb") runs


def plot_metrics(paths, curve_labels, title, out_file):
    """Plot MSE (top) and MAE (bottom) against epoch for several runs and save.

    Args:
        paths: HDF5 files, relative to ``cur_dir``, each readable with
            ``pd.read_hdf(path, key='df')`` and having ``mse``/``mae`` columns.
        curve_labels: Legend label for each file, in the same order as ``paths``.
        title: Super-title for the whole figure.
        out_file: File name the figure is saved to.

    Returns:
        The ``(fig, axs)`` pair created by ``plt.subplots``.

    Raises:
        ValueError: If ``paths`` and ``curve_labels`` differ in length. Without
            this check, ``zip`` would silently drop the unmatched entries.
    """
    if len(paths) != len(curve_labels):
        raise ValueError(
            f"got {len(paths)} paths but {len(curve_labels)} labels"
        )

    fig, axs = plt.subplots(2, 1, figsize=(10, 8))
    for relative_path, label in zip(paths, curve_labels):
        data = pd.read_hdf(os.path.join(cur_dir, relative_path), key='df')
        # One curve per run, labelled with its run name for the legend.
        axs[0].plot(data.index, data['mse'], linestyle='-', label=label)
        axs[1].plot(data.index, data['mae'], linestyle='-', label=label)

    for ax, metric in zip(axs, ('MSE', 'MAE')):
        ax.set_ylabel(metric)
        ax.set_xlabel('Epochs')
        ax.legend()

    fig.suptitle(title)
    # Adjust the layout so axis labels and the super-title don't overlap.
    fig.tight_layout()
    fig.savefig(out_file)
    return fig, axs


if __name__ == "__main__":
    # The output file names are kept exactly as before, including the
    # "Shangai" spelling and "datasets" vs "dataset". Anything that picks
    # these images up by name therefore keeps working.
    plot_metrics(first_group, labels, "Shanghai-A dataset",
                 'Shangai-A_datasets.png')
    plot_metrics(second_group, labels, "Shanghai-B dataset",
                 'Shangai-B_dataset.png')
    # Display both figures.
    plt.show()