From 0d35d607fe605380d18acd95a74817ceb1543433 Mon Sep 17 00:00:00 2001 From: rubberhead Date: Wed, 6 Mar 2024 03:26:28 +0000 Subject: [PATCH] ... --- preprocess_data.py | 2 +- train.py | 64 +++++++++++++++++++++++++--------------------- 2 files changed, 36 insertions(+), 30 deletions(-) diff --git a/preprocess_data.py b/preprocess_data.py index 8bcc7e2..170be8e 100644 --- a/preprocess_data.py +++ b/preprocess_data.py @@ -52,7 +52,7 @@ def pre_dataset_sh(): # np.random.seed(0) # random.seed(0) - for _, img_path in tqdm(img_paths, desc="Preprocessing Data"): + for img_path in tqdm(img_paths, desc="Preprocessing Data"): img_data = cv2.imread(img_path) mat = io.loadmat( img_path diff --git a/train.py b/train.py index 247fe4f..474ac53 100644 --- a/train.py +++ b/train.py @@ -28,9 +28,9 @@ logger = logging.getLogger("train") if not args.export_to_h5: writer = SummaryWriter(args.save_path + "/tensorboard-run") else: - train_df = pd.DataFrame(columns=["l1loss", "composite-loss"]) + train_df = pd.DataFrame(columns=["l1loss", "composite-loss"], dtype=float) train_stat_file = args.save_path + "/train_stats.h5" - test_df = pd.DataFrame(columns=["mse", "mae"]) + test_df = pd.DataFrame(columns=["mse", "mae"], dtype=float) test_stat_file = args.save_path + "/test_stats.h5" @@ -268,32 +268,32 @@ def train_one_epoch( device_type = "cuda" # Desperate measure to reduce mem footprint... - with torch.autocast(device_type): - # fpass - out, gt_count = model(img, kpoint) - # loss - loss = criterion(out, gt_count) # wrt. transformer - if args.export_to_h5: - train_df.loc[epoch * i, "l1loss"] = loss.item() - else: - writer.add_scalar( - "L1-loss wrt. xformer (train)", loss, epoch * i - ) - - loss += ( - F.mse_loss( # stn: info retainment - gt_count.view(batch_size, -1).sum(axis=1, keepdim=True), - gt_count_whole) - + F.threshold( # stn: perspective correction - gt_count.view(batch_size, -1).var(dim=1).mean(), - threshold=loss.item(), - value=loss.item() - ) + # with torch.autocast(device_type): + # fpass + out, gt_count = model(img, kpoint) + # loss + loss = criterion(out, gt_count) # wrt. transformer + if args.export_to_h5: + train_df.loc[epoch * i, "l1loss"] = float(loss.item()) + else: + writer.add_scalar( + "L1-loss wrt. xformer (train)", loss, epoch * i ) - if args.export_to_h5: - train_df.loc[epoch * i, "composite-loss"] = loss.item() - else: - writer.add_scalar("Composite loss (train)", loss, epoch * i) + + loss += ( + F.mse_loss( # stn: info retainment + gt_count.view(batch_size, -1).sum(axis=1, keepdim=True), + gt_count_whole) + + F.threshold( # stn: perspective correction + gt_count.view(batch_size, -1).var(dim=1).mean(), + threshold=loss.item(), + value=loss.item() + ) + ) + if args.export_to_h5: + train_df.loc[epoch * i, "composite-loss"] = float(loss.item()) + else: + writer.add_scalar("Composite loss (train)", loss, epoch * i) # free grad from mem optimizer.zero_grad(set_to_none=True) @@ -364,6 +364,9 @@ def valid_one_epoch(test_loader, model, device, epoch, args): torch.sum(pred_count).item() )) + if args.debug: + break + mae = mae * 1.0 / (len(test_loader) * batch_size) mse = np.sqrt(mse / (len(test_loader)) * batch_size) if args.export_to_h5: @@ -373,14 +376,17 @@ def valid_one_epoch(test_loader, model, device, epoch, args): else: writer.add_scalar("MAE (valid)", mae, epoch) writer.add_scalar("MSE (valid)", mse, epoch) - if len(xformed) != 0: + if len(xformed) != 0 and not args.export_to_h5: img_grid = torchvision.utils.make_grid(xformed) writer.add_image("STN: transformed image", img_grid, epoch) + + if not args.export_to_h5: + writer.flush() + nni.report_intermediate_result(mae) print("* MAE {mae:.3f} | MSE {mse:.3f} *".format( mae=mae, mse=mse )) - writer.flush() return mae