From 94867bd8bfbb36a433e338392f6323a93284b315 Mon Sep 17 00:00:00 2001
From: Zhengyi Chen
Date: Mon, 4 Mar 2024 14:39:58 +0000
Subject: [PATCH] Use mse instead of l1 for info retainment loss, fix mae

---
 train.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/train.py b/train.py
index db51133..8732497 100644
--- a/train.py
+++ b/train.py
@@ -257,7 +257,7 @@ def train_one_epoch(
         # loss
         loss = criterion(out, gt_count)  # wrt. transformer
         loss += (
-            criterion(  # stn: info retainment
+            F.mse_loss(  # stn: info retainment
                 gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
                 gt_count_whole)
             + F.threshold(  # stn: perspective correction
@@ -277,8 +277,8 @@ def train_one_epoch(
         optimizer.step()

         # periodic message
-        if i % args.print_freq == 0:
-            print("Epoch {}: {}/{}".format(epoch, i, len(train_loader)))
+        # if i % args.print_freq == 0:
+        #     print("Epoch {}: {}/{}".format(epoch, i, len(train_loader)))

         if args.debug:
             break
@@ -322,15 +322,15 @@ def valid_one_epoch(test_loader, model, device, args):
         diff = torch.abs(gt_count_whole - torch.sum(pred_count)).item()
         mae += diff
         mse += diff ** 2
-        mae = mae * 1.0 / (len(test_loader) * batch_size)
-        mse = np.sqrt(mse / (len(test_loader)) * batch_size)
         if i % 5 == 0:
-            print("[valid_one_epoch] {}\t| Gt {:.2f} Pred {:.4f}\t| mae {:.4f} mse {:.4f} |".format(
+            print("[valid_one_epoch] {} | Gt {:.2f} Pred {:.4f} | mae {:.4f} mse {:.4f} |".format(
                 fname[0], torch.sum(gt_count_whole).item(), torch.sum(pred_count).item(), mae, mse
             ))
+    mae = mae * 1.0 / (len(test_loader) * batch_size)
+    mse = np.sqrt(mse / (len(test_loader)) * batch_size)
     nni.report_intermediate_result(mae)
     print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
         mae=mae, mse=mse
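
For reference, the third hunk moves the MAE/MSE normalization out of the validation loop so the accumulated sums are divided exactly once, after all batches have been seen. Below is a minimal, self-contained sketch of that aggregation (the summarize_counts helper and the sample numbers are illustrative only, not part of train.py), assuming num_samples = len(test_loader) * batch_size as in the patched code. The sketch keeps the whole divisor inside the square root, matching the MAE line; the patched expression np.sqrt(mse / (len(test_loader)) * batch_size) leaves batch_size outside the divisor due to operator precedence, which coincides with the sketch only when batch_size == 1.

    import numpy as np

    def summarize_counts(abs_errors, num_samples):
        # abs_errors: the per-image |gt_count_whole - sum(pred_count)| terms
        # that valid_one_epoch accumulates as `diff` inside its loop.
        errs = np.asarray(abs_errors, dtype=np.float64)
        mae = errs.sum() / num_samples
        # "mse" here is a root-mean-square error, as in train.py's reporting.
        mse = np.sqrt((errs ** 2).sum() / num_samples)
        return mae, mse

    # Example: three images with absolute count errors 2, 4 and 6
    # -> MAE = 4.0, MSE (RMSE) = sqrt((4 + 16 + 36) / 3) ≈ 4.32
    print(summarize_counts([2.0, 4.0, 6.0], num_samples=3))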