diff --git a/train.py b/train.py index a6491d1..7d81931 100644 --- a/train.py +++ b/train.py @@ -282,8 +282,10 @@ def train_one_epoch( F.mse_loss( gt_count.view(batch_size, -1).sum(axis=1, keepdim=True), gt_count_whole) - + F.sigmoid( - gt_count.view(batch_size, -1).var(dim=1).mean()) + + ( # mean index of dispersion + gt_count.view(batch_size, -1).var(dim=1) + / gt_count.view(batch_size, -1).mean(dim=1) + ).mean() ) loss_stn.requires_grad = True optimizer.zero_grad(set_to_none=True)