Code for the graph

This commit is contained in:
Mouleesh 2024-03-19 17:16:56 +00:00
parent 97782aad6e
commit 4a914bf551

72
visualize.py Normal file
View file

@ -0,0 +1,72 @@
import pandas as pd
import os
import matplotlib.pyplot as plt
cur_dir = os.getcwd()
#relative_path = 'out-pack/1784598-stn-newloss-shb/test_stats.h5'
relative_paths = ['out-pack/filter_test_base-sha.h5', 'out-pack/filter_test_oldloss-sha.h5', 'out-pack/filter_test_newloss-sha.h5', 'out-pack/filter_test_base-shb.h5','out-pack/filter_test_newloss-shb.h5','out-pack/filter_test_oldloss-shb.h5']
labels = ['base', 'oldloss', 'newloss']
first_group = relative_paths[:3]
second_group =relative_paths[3:]
# Plot data
fig, axs = plt.subplots(2, 1, figsize=(10, 8)) #Create subplots
for relative_path, label in zip(first_group, labels):
file_name = os.path.join(cur_dir, relative_path)
data = pd.read_hdf(file_name, key='df')
# Plot mse with label as file name
axs[0].plot(data.index, data['mse'], linestyle='-', label=label)
# plot mae with label as file name
axs[1].plot(data.index, data['mae'], linestyle='-', label=label)
# Set labels foe x and y axes for each subplot
axs[0].set_ylabel('MSE')
axs[0].set_xlabel('Epochs')
axs[1].set_ylabel('MAE')
axs[1].set_xlabel('Epochs')
# Add Legend
axs[0].legend()
axs[1].legend()
# Set title for the entire plot
plt.suptitle("Shangai-A dataset")
# Adjust layout to prevent overlapping
plt.tight_layout()
# Save the graph
plt.savefig('Shangai-A_datasets.png')
# Plot data
fig, axs = plt.subplots(2, 1, figsize=(10, 8)) #Create subplots
for relative_path, label in zip(second_group, labels):
file_name = os.path.join(cur_dir, relative_path)
data = pd.read_hdf(file_name, key='df')
# Plot mse with label as file name
axs[0].plot(data.index, data['mse'], linestyle='-', label=label)
# plot mae with label as file name
axs[1].plot(data.index, data['mae'], linestyle='-', label=label)
# Set labels foe x and y axes for each subplot
axs[0].set_ylabel('MSE')
axs[0].set_xlabel('Epochs')
axs[1].set_ylabel('MAE')
axs[1].set_xlabel('Epochs')
# Add Legend
axs[0].legend()
axs[1].legend()
# Set title for the entire plot
plt.suptitle("Shangai-B dataset")
# Adjust layout to prevent overlapping
plt.tight_layout()
# Save the graph
plt.savefig('Shangai-B_dataset.png')
# Show plot
plt.show()