Code for the graph
This commit is contained in:
parent
97782aad6e
commit
4a914bf551
1 changed files with 72 additions and 0 deletions
72
visualize.py
Normal file
72
visualize.py
Normal file
|
|
@ -0,0 +1,72 @@
|
||||||
|
import pandas as pd
|
||||||
|
import os
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
cur_dir = os.getcwd()
|
||||||
|
#relative_path = 'out-pack/1784598-stn-newloss-shb/test_stats.h5'
|
||||||
|
relative_paths = ['out-pack/filter_test_base-sha.h5', 'out-pack/filter_test_oldloss-sha.h5', 'out-pack/filter_test_newloss-sha.h5', 'out-pack/filter_test_base-shb.h5','out-pack/filter_test_newloss-shb.h5','out-pack/filter_test_oldloss-shb.h5']
|
||||||
|
labels = ['base', 'oldloss', 'newloss']
|
||||||
|
first_group = relative_paths[:3]
|
||||||
|
second_group =relative_paths[3:]
|
||||||
|
# Plot data
|
||||||
|
fig, axs = plt.subplots(2, 1, figsize=(10, 8)) #Create subplots
|
||||||
|
for relative_path, label in zip(first_group, labels):
|
||||||
|
file_name = os.path.join(cur_dir, relative_path)
|
||||||
|
data = pd.read_hdf(file_name, key='df')
|
||||||
|
# Plot mse with label as file name
|
||||||
|
axs[0].plot(data.index, data['mse'], linestyle='-', label=label)
|
||||||
|
# plot mae with label as file name
|
||||||
|
axs[1].plot(data.index, data['mae'], linestyle='-', label=label)
|
||||||
|
|
||||||
|
# Set labels foe x and y axes for each subplot
|
||||||
|
axs[0].set_ylabel('MSE')
|
||||||
|
axs[0].set_xlabel('Epochs')
|
||||||
|
axs[1].set_ylabel('MAE')
|
||||||
|
axs[1].set_xlabel('Epochs')
|
||||||
|
|
||||||
|
# Add Legend
|
||||||
|
axs[0].legend()
|
||||||
|
axs[1].legend()
|
||||||
|
|
||||||
|
# Set title for the entire plot
|
||||||
|
plt.suptitle("Shangai-A dataset")
|
||||||
|
|
||||||
|
# Adjust layout to prevent overlapping
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
# Save the graph
|
||||||
|
plt.savefig('Shangai-A_datasets.png')
|
||||||
|
|
||||||
|
# Plot data
|
||||||
|
fig, axs = plt.subplots(2, 1, figsize=(10, 8)) #Create subplots
|
||||||
|
for relative_path, label in zip(second_group, labels):
|
||||||
|
file_name = os.path.join(cur_dir, relative_path)
|
||||||
|
data = pd.read_hdf(file_name, key='df')
|
||||||
|
# Plot mse with label as file name
|
||||||
|
axs[0].plot(data.index, data['mse'], linestyle='-', label=label)
|
||||||
|
# plot mae with label as file name
|
||||||
|
axs[1].plot(data.index, data['mae'], linestyle='-', label=label)
|
||||||
|
|
||||||
|
# Set labels foe x and y axes for each subplot
|
||||||
|
axs[0].set_ylabel('MSE')
|
||||||
|
axs[0].set_xlabel('Epochs')
|
||||||
|
axs[1].set_ylabel('MAE')
|
||||||
|
axs[1].set_xlabel('Epochs')
|
||||||
|
|
||||||
|
# Add Legend
|
||||||
|
axs[0].legend()
|
||||||
|
axs[1].legend()
|
||||||
|
|
||||||
|
# Set title for the entire plot
|
||||||
|
plt.suptitle("Shangai-B dataset")
|
||||||
|
|
||||||
|
# Adjust layout to prevent overlapping
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
# Save the graph
|
||||||
|
plt.savefig('Shangai-B_dataset.png')
|
||||||
|
|
||||||
|
# Show plot
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue