FEATURE: export to DataFrame in hdf5
This commit is contained in:
parent
208091ce8a
commit
ae9bc34fde
2 changed files with 35 additions and 6 deletions
|
|
@ -79,6 +79,10 @@ parser.add_argument(
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--debug", type=bool, default=False
|
"--debug", type=bool, default=False
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--export_to_h5", type=bool, default=True,
|
||||||
|
help="Export training & validation statistics (.h5 of pd.DataFrame)"
|
||||||
|
)
|
||||||
|
|
||||||
# nni configuration ==========================================================
|
# nni configuration ==========================================================
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|
|
||||||
37
train.py
37
train.py
|
|
@ -14,6 +14,7 @@ import torchvision
|
||||||
import nni
|
import nni
|
||||||
import logging
|
import logging
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
from model.transcrowd_gap import VisionTransformerGAP
|
from model.transcrowd_gap import VisionTransformerGAP
|
||||||
from arguments import args, ret_args
|
from arguments import args, ret_args
|
||||||
|
|
@ -23,7 +24,15 @@ from model.transcrowd_gap import *
|
||||||
from checkpoint import save_checkpoint
|
from checkpoint import save_checkpoint
|
||||||
|
|
||||||
logger = logging.getLogger("train")
|
logger = logging.getLogger("train")
|
||||||
writer = SummaryWriter(args.save_path + "/tensorboard-run")
|
|
||||||
|
if not args.export_to_h5:
|
||||||
|
writer = SummaryWriter(args.save_path + "/tensorboard-run")
|
||||||
|
else:
|
||||||
|
train_df = pd.DataFrame(columns=["l1loss", "composite-loss"])
|
||||||
|
train_stat_file = args.save_path + "/train_stats.h5"
|
||||||
|
test_df = pd.DataFrame(columns=["mse", "mae"])
|
||||||
|
test_stat_file = args.save_path + "/test_stats.h5"
|
||||||
|
|
||||||
|
|
||||||
def setup_process_group(
|
def setup_process_group(
|
||||||
rank: int,
|
rank: int,
|
||||||
|
|
@ -264,7 +273,12 @@ def train_one_epoch(
|
||||||
out, gt_count = model(img, kpoint)
|
out, gt_count = model(img, kpoint)
|
||||||
# loss
|
# loss
|
||||||
loss = criterion(out, gt_count) # wrt. transformer
|
loss = criterion(out, gt_count) # wrt. transformer
|
||||||
writer.add_scalar("L1-loss wrt. xformer (train)", loss, epoch * i)
|
if args.export_to_h5:
|
||||||
|
train_df.loc[epoch * i, "l1loss"] = loss.item()
|
||||||
|
else:
|
||||||
|
writer.add_scalar(
|
||||||
|
"L1-loss wrt. xformer (train)", loss, epoch * i
|
||||||
|
)
|
||||||
|
|
||||||
loss += (
|
loss += (
|
||||||
F.mse_loss( # stn: info retainment
|
F.mse_loss( # stn: info retainment
|
||||||
|
|
@ -276,7 +290,10 @@ def train_one_epoch(
|
||||||
value=loss.item()
|
value=loss.item()
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
writer.add_scalar("Composite loss (train)", loss, epoch * i)
|
if args.export_to_h5:
|
||||||
|
train_df.loc[epoch * i, "composite-loss"] = loss.item()
|
||||||
|
else:
|
||||||
|
writer.add_scalar("Composite loss (train)", loss, epoch * i)
|
||||||
|
|
||||||
# free grad from mem
|
# free grad from mem
|
||||||
optimizer.zero_grad(set_to_none=True)
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
|
@ -291,7 +308,10 @@ def train_one_epoch(
|
||||||
break
|
break
|
||||||
|
|
||||||
# Flush writer
|
# Flush writer
|
||||||
writer.flush()
|
if args.export_to_h5:
|
||||||
|
train_df.to_hdf(train_stat_file, key="df", mode="a", append=True)
|
||||||
|
else:
|
||||||
|
writer.flush()
|
||||||
|
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
|
||||||
|
|
@ -346,8 +366,13 @@ def valid_one_epoch(test_loader, model, device, epoch, args):
|
||||||
|
|
||||||
mae = mae * 1.0 / (len(test_loader) * batch_size)
|
mae = mae * 1.0 / (len(test_loader) * batch_size)
|
||||||
mse = np.sqrt(mse / (len(test_loader)) * batch_size)
|
mse = np.sqrt(mse / (len(test_loader)) * batch_size)
|
||||||
writer.add_scalar("MAE (valid)", mae, epoch)
|
if args.export_to_h5:
|
||||||
writer.add_scalar("MSE (valid)", mse, epoch)
|
test_df.loc[epoch, "mae"] = mae
|
||||||
|
test_df.loc[epoch, "mse"] = mse
|
||||||
|
test_df.to_hdf(test_stat_file, key="df", mode="a", append=True)
|
||||||
|
else:
|
||||||
|
writer.add_scalar("MAE (valid)", mae, epoch)
|
||||||
|
writer.add_scalar("MSE (valid)", mse, epoch)
|
||||||
if len(xformed) != 0:
|
if len(xformed) != 0:
|
||||||
img_grid = torchvision.utils.make_grid(xformed)
|
img_grid = torchvision.utils.make_grid(xformed)
|
||||||
writer.add_image("STN: transformed image", img_grid, epoch)
|
writer.add_image("STN: transformed image", img_grid, epoch)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue