# Source listing header (extraction residue): 72 lines, 2.1 KiB, Python.
import pandas as pd
|
|
import os
|
|
import matplotlib.pyplot as plt
|
|
|
|
# Result paths below are resolved against the directory the script is run from.
cur_dir = os.getcwd()

# Loss variants being compared. Result files are paired with these labels by
# position, so every file group below is generated in exactly this order.
labels = ['base', 'oldloss', 'newloss']

# Result files for the "sha" runs, one per label, in label order.
first_group = [f'out-pack/filter_test_{label}-sha.h5' for label in labels]

# Result files for the "shb" runs, one per label, in label order. Deriving the
# names from `labels` (instead of a hand-ordered list) guarantees each curve
# gets the right legend label; the old list had oldloss/newloss swapped here.
second_group = [f'out-pack/filter_test_{label}-shb.h5' for label in labels]

# Every result file, sha group first.
relative_paths = first_group + second_group
def plot_metrics(paths, series_labels, title, out_file, base_dir):
    """Plot MSE and MAE curves for several result files in one figure.

    Each file is read with ``pd.read_hdf(..., key='df')`` and must provide
    ``mse`` and ``mae`` columns; the frame's index is used as the x axis,
    which is labelled "Epochs". Two stacked subplots are drawn (MSE on top,
    MAE below) with one line per file, and the figure is saved to
    ``out_file``. The figure is not closed, so a later ``plt.show()`` still
    displays it.

    Args:
        paths: Result file paths, relative to ``base_dir``.
        series_labels: Legend label for each path, paired by position.
        title: Title shown above both subplots.
        out_file: Path the rendered figure is written to.
        base_dir: Directory that ``paths`` are resolved against.

    Returns:
        The created matplotlib Figure.

    Raises:
        ValueError: If ``paths`` and ``series_labels`` differ in length;
            ``zip`` would otherwise silently drop the unmatched entries.
    """
    if len(paths) != len(series_labels):
        raise ValueError(
            f'got {len(paths)} paths but {len(series_labels)} labels'
        )

    fig, axs = plt.subplots(2, 1, figsize=(10, 8))
    # (axis, dataframe column, y-axis label) for each subplot.
    panels = ((axs[0], 'mse', 'MSE'), (axs[1], 'mae', 'MAE'))

    for relative_path, label in zip(paths, series_labels):
        data = pd.read_hdf(os.path.join(base_dir, relative_path), key='df')
        # Legend entries use the series label, not the file name.
        for ax, column, _ in panels:
            ax.plot(data.index, data[column], linestyle='-', label=label)

    for ax, _, ylabel in panels:
        ax.set_ylabel(ylabel)
        ax.set_xlabel('Epochs')
        ax.legend()

    fig.suptitle(title)
    # Adjust spacing so the stacked subplots' labels do not overlap.
    fig.tight_layout()
    fig.savefig(out_file)
    return fig


# Output file names are kept exactly as before (including the "Shangai"
# spelling and datasets/dataset difference) so existing references still work.
plot_metrics(first_group, labels, 'Shanghai-A dataset',
             'Shangai-A_datasets.png', cur_dir)
plot_metrics(second_group, labels, 'Shanghai-B dataset',
             'Shangai-B_dataset.png', cur_dir)

# Display both figures created above.
plt.show()