...
This commit is contained in:
parent ae9bc34fde
commit 0d35d607fe

2 changed files with 36 additions and 30 deletions
@@ -52,7 +52,7 @@ def pre_dataset_sh():
 
     # np.random.seed(0)
     # random.seed(0)
-    for _, img_path in tqdm(img_paths, desc="Preprocessing Data"):
+    for img_path in tqdm(img_paths, desc="Preprocessing Data"):
        img_data = cv2.imread(img_path)
        mat = io.loadmat(
            img_path
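Note on the loop change above: the tuple unpacking (for _, img_path in ...) is gone, so img_paths is now expected to be a flat sequence of file paths rather than (index, path) pairs. A minimal standalone sketch of the shape the loop now assumes; the glob pattern is illustrative and not taken from the repository:

from glob import glob

import cv2
from tqdm import tqdm

# Illustrative only: img_paths as a flat, sorted list of image files.
img_paths = sorted(glob("data/images/*.jpg"))

for img_path in tqdm(img_paths, desc="Preprocessing Data"):
    img_data = cv2.imread(img_path)  # HxWx3 BGR array, or None if the read fails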

train.py (64 lines changed)
@@ -28,9 +28,9 @@ logger = logging.getLogger("train")
 if not args.export_to_h5:
     writer = SummaryWriter(args.save_path + "/tensorboard-run")
 else:
-    train_df = pd.DataFrame(columns=["l1loss", "composite-loss"])
+    train_df = pd.DataFrame(columns=["l1loss", "composite-loss"], dtype=float)
     train_stat_file = args.save_path + "/train_stats.h5"
-    test_df = pd.DataFrame(columns=["mse", "mae"])
+    test_df = pd.DataFrame(columns=["mse", "mae"], dtype=float)
     test_stat_file = args.save_path + "/test_stats.h5"
 
 
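The dtype=float arguments are not cosmetic: a DataFrame constructed from column names alone defaults to object columns, and the per-step losses are later assigned row by row and dumped to the *.h5 stat files. Keeping the columns numeric from the start avoids object-typed columns in that export. A small standalone illustration, not repository code:

import pandas as pd

# Column names only: the empty frame defaults to object columns.
df_obj = pd.DataFrame(columns=["l1loss", "composite-loss"])
print(df_obj.dtypes)  # object, object

# With dtype=float the columns stay numeric as rows are filled in.
df_num = pd.DataFrame(columns=["l1loss", "composite-loss"], dtype=float)
df_num.loc[0, "l1loss"] = 0.42
print(df_num.dtypes)  # float64, float64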
@@ -268,32 +268,32 @@ def train_one_epoch(
         device_type = "cuda"
 
         # Desperate measure to reduce mem footprint...
-        with torch.autocast(device_type):
-            # fpass
-            out, gt_count = model(img, kpoint)
-            # loss
-            loss = criterion(out, gt_count)  # wrt. transformer
-            if args.export_to_h5:
-                train_df.loc[epoch * i, "l1loss"] = loss.item()
-            else:
-                writer.add_scalar(
-                    "L1-loss wrt. xformer (train)", loss, epoch * i
-                )
+        # with torch.autocast(device_type):
+        # fpass
+        out, gt_count = model(img, kpoint)
+        # loss
+        loss = criterion(out, gt_count)  # wrt. transformer
+        if args.export_to_h5:
+            train_df.loc[epoch * i, "l1loss"] = float(loss.item())
+        else:
+            writer.add_scalar(
+                "L1-loss wrt. xformer (train)", loss, epoch * i
+            )
 
-            loss += (
-                F.mse_loss(  # stn: info retainment
-                    gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
-                    gt_count_whole)
-                + F.threshold(  # stn: perspective correction
-                    gt_count.view(batch_size, -1).var(dim=1).mean(),
-                    threshold=loss.item(),
-                    value=loss.item()
-                )
-            )
-            if args.export_to_h5:
-                train_df.loc[epoch * i, "composite-loss"] = loss.item()
-            else:
-                writer.add_scalar("Composite loss (train)", loss, epoch * i)
+        loss += (
+            F.mse_loss(  # stn: info retainment
+                gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
+                gt_count_whole)
+            + F.threshold(  # stn: perspective correction
+                gt_count.view(batch_size, -1).var(dim=1).mean(),
+                threshold=loss.item(),
+                value=loss.item()
+            )
+        )
+        if args.export_to_h5:
+            train_df.loc[epoch * i, "composite-loss"] = float(loss.item())
+        else:
+            writer.add_scalar("Composite loss (train)", loss, epoch * i)
 
         # free grad from mem
         optimizer.zero_grad(set_to_none=True)
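Two details in this hunk are easy to miss. Commenting out torch.autocast means the forward pass and both loss terms now run at the tensors' default precision instead of under mixed precision, and the perspective-correction term still goes through F.threshold, which returns its input wherever it exceeds the threshold and the given value otherwise. A standalone check of that behaviour, not repository code:

import torch
import torch.nn.functional as F

# F.threshold(x, threshold, value): elementwise x if x > threshold, else value.
x = torch.tensor([0.2, 1.5, 3.0])
print(F.threshold(x, threshold=1.0, value=1.0))  # tensor([1.0000, 1.5000, 3.0000])

# With threshold=value=loss.item(), as in the diff, the variance penalty is
# effectively floored at the current L1 loss before being added to it.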
@@ -364,6 +364,9 @@ def valid_one_epoch(test_loader, model, device, epoch, args):
             torch.sum(pred_count).item()
         ))
 
+        if args.debug:
+            break
+
     mae = mae * 1.0 / (len(test_loader) * batch_size)
     mse = np.sqrt(mse / (len(test_loader)) * batch_size)
     if args.export_to_h5:
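The new args.debug branch stops validation after the first batch, which turns a full validation pass into a cheap smoke test. Presumably it is backed by a plain boolean flag; a hypothetical sketch of such a flag (the real argparse setup lives elsewhere in train.py and may differ):

import argparse

parser = argparse.ArgumentParser(description="train")
# Hypothetical definition, shown only to make args.debug concrete.
parser.add_argument("--debug", action="store_true",
                    help="run a single validation batch and stop")
args = parser.parse_args([])  # empty argv just for illustration
print(args.debug)  # False unless --debug is passed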
@@ -373,14 +376,17 @@ def valid_one_epoch(test_loader, model, device, epoch, args):
     else:
         writer.add_scalar("MAE (valid)", mae, epoch)
         writer.add_scalar("MSE (valid)", mse, epoch)
-    if len(xformed) != 0:
+    if len(xformed) != 0 and not args.export_to_h5:
         img_grid = torchvision.utils.make_grid(xformed)
         writer.add_image("STN: transformed image", img_grid, epoch)
 
+    if not args.export_to_h5:
+        writer.flush()
+
     nni.report_intermediate_result(mae)
     print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
         mae=mae, mse=mse
     ))
-    writer.flush()
     return mae
 
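Gating the remaining writer calls on not args.export_to_h5 matches the setup hunk at the top of train.py: the SummaryWriter is only created in the TensorBoard branch, so an unconditional writer.flush() would presumably hit a name that was never bound when stats are exported to HDF5 instead. The guarded pattern in isolation, as a sketch rather than repository code:

from torch.utils.tensorboard import SummaryWriter

export_to_h5 = False  # stand-in for args.export_to_h5

if not export_to_h5:
    writer = SummaryWriter("runs/example")  # illustrative log dir

# Every later writer access repeats the same guard, so HDF5-only runs
# never touch an undefined writer.
if not export_to_h5:
    writer.add_scalar("MAE (valid)", 12.3, 0)
    writer.flush()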