From 12adb6b7bf76eea3461b513b202f5869226b4ac5 Mon Sep 17 00:00:00 2001 From: Zhengyi Chen Date: Wed, 6 Mar 2024 22:30:43 +0000 Subject: [PATCH] Edited loss (ongoing) --- train.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index a6491d1..7d81931 100644 --- a/train.py +++ b/train.py @@ -282,8 +282,10 @@ def train_one_epoch( F.mse_loss( gt_count.view(batch_size, -1).sum(axis=1, keepdim=True), gt_count_whole) - + F.sigmoid( - gt_count.view(batch_size, -1).var(dim=1).mean()) + + ( # mean index of dispersion + gt_count.view(batch_size, -1).var(dim=1) + / gt_count.view(batch_size, -1).mean(dim=1) + ).mean() ) loss_stn.requires_grad = True optimizer.zero_grad(set_to_none=True)