Zhengyi Chen 2024-03-06 03:26:28 +00:00
parent ae9bc34fde
commit 0d35d607fe
2 changed files with 36 additions and 30 deletions

@@ -52,7 +52,7 @@ def pre_dataset_sh():
     # np.random.seed(0)
     # random.seed(0)
-    for _, img_path in tqdm(img_paths, desc="Preprocessing Data"):
+    for img_path in tqdm(img_paths, desc="Preprocessing Data"):
         img_data = cv2.imread(img_path)
         mat = io.loadmat(
             img_path

@ -28,9 +28,9 @@ logger = logging.getLogger("train")
if not args.export_to_h5: if not args.export_to_h5:
writer = SummaryWriter(args.save_path + "/tensorboard-run") writer = SummaryWriter(args.save_path + "/tensorboard-run")
else: else:
train_df = pd.DataFrame(columns=["l1loss", "composite-loss"]) train_df = pd.DataFrame(columns=["l1loss", "composite-loss"], dtype=float)
train_stat_file = args.save_path + "/train_stats.h5" train_stat_file = args.save_path + "/train_stats.h5"
test_df = pd.DataFrame(columns=["mse", "mae"]) test_df = pd.DataFrame(columns=["mse", "mae"], dtype=float)
test_stat_file = args.save_path + "/test_stats.h5" test_stat_file = args.save_path + "/test_stats.h5"
@@ -268,32 +268,32 @@ def train_one_epoch(
         device_type = "cuda"
         # Desperate measure to reduce mem footprint...
-        with torch.autocast(device_type):
-            # fpass
-            out, gt_count = model(img, kpoint)
-            # loss
-            loss = criterion(out, gt_count)  # wrt. transformer
-            if args.export_to_h5:
-                train_df.loc[epoch * i, "l1loss"] = loss.item()
-            else:
-                writer.add_scalar(
-                    "L1-loss wrt. xformer (train)", loss, epoch * i
-                )
-            loss += (
-                F.mse_loss(  # stn: info retainment
-                    gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
-                    gt_count_whole)
-                + F.threshold(  # stn: perspective correction
-                    gt_count.view(batch_size, -1).var(dim=1).mean(),
-                    threshold=loss.item(),
-                    value=loss.item()
-                )
-            )
-            if args.export_to_h5:
-                train_df.loc[epoch * i, "composite-loss"] = loss.item()
-            else:
-                writer.add_scalar("Composite loss (train)", loss, epoch * i)
+        # with torch.autocast(device_type):
+        # fpass
+        out, gt_count = model(img, kpoint)
+        # loss
+        loss = criterion(out, gt_count)  # wrt. transformer
+        if args.export_to_h5:
+            train_df.loc[epoch * i, "l1loss"] = float(loss.item())
+        else:
+            writer.add_scalar(
+                "L1-loss wrt. xformer (train)", loss, epoch * i
+            )
+        loss += (
+            F.mse_loss(  # stn: info retainment
+                gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
+                gt_count_whole)
+            + F.threshold(  # stn: perspective correction
+                gt_count.view(batch_size, -1).var(dim=1).mean(),
+                threshold=loss.item(),
+                value=loss.item()
+            )
+        )
+        if args.export_to_h5:
+            train_df.loc[epoch * i, "composite-loss"] = float(loss.item())
+        else:
+            writer.add_scalar("Composite loss (train)", loss, epoch * i)
         # free grad from mem
         optimizer.zero_grad(set_to_none=True)
@@ -364,6 +364,9 @@ def valid_one_epoch(test_loader, model, device, epoch, args):
             torch.sum(pred_count).item()
         ))
+        if args.debug:
+            break
     mae = mae * 1.0 / (len(test_loader) * batch_size)
     mse = np.sqrt(mse / (len(test_loader)) * batch_size)
     if args.export_to_h5:
@@ -373,14 +376,17 @@ def valid_one_epoch(test_loader, model, device, epoch, args):
     else:
         writer.add_scalar("MAE (valid)", mae, epoch)
         writer.add_scalar("MSE (valid)", mse, epoch)
-    if len(xformed) != 0:
+    if len(xformed) != 0 and not args.export_to_h5:
         img_grid = torchvision.utils.make_grid(xformed)
         writer.add_image("STN: transformed image", img_grid, epoch)
+    if not args.export_to_h5:
+        writer.flush()
     nni.report_intermediate_result(mae)
     print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
         mae=mae, mse=mse
     ))
-    writer.flush()
     return mae
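
For readers skimming the train_one_epoch hunk above, here is a minimal, self-contained sketch of the composite loss it assembles: a base L1 count loss, plus an MSE "info retainment" term and a thresholded-variance "perspective correction" term. The tensor shapes and dummy inputs below are illustrative assumptions, not values taken from the repository.

# Minimal sketch (assumed shapes, dummy data) of the composite loss pattern above.
import torch
import torch.nn.functional as F

batch_size = 4
out = torch.rand(batch_size, 4)                    # predicted per-patch counts (assumed shape)
gt_count = torch.rand(batch_size, 4)               # ground-truth per-patch counts (assumed shape)
gt_count_whole = torch.rand(batch_size, 1) * 4.0   # whole-image ground-truth counts (assumed)

criterion = torch.nn.L1Loss()
loss = criterion(out, gt_count)                    # base L1 loss wrt. the transformer output

loss = loss + (
    F.mse_loss(                                    # stn: info retainment
        gt_count.view(batch_size, -1).sum(dim=1, keepdim=True),
        gt_count_whole)
    + F.threshold(                                 # stn: perspective correction,
        gt_count.view(batch_size, -1).var(dim=1).mean(),
        threshold=loss.item(),                     # variance below the current loss is
        value=loss.item())                         # clamped to the current loss value
)
print(float(loss.item()))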