Edited loss (ongoing)

This commit is contained in:
Zhengyi Chen 2024-03-06 22:30:43 +00:00
parent 9d2a30a226
commit 12adb6b7bf

View file

@ -282,8 +282,10 @@ def train_one_epoch(
F.mse_loss(
gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
gt_count_whole)
+ F.sigmoid(
gt_count.view(batch_size, -1).var(dim=1).mean())
+ ( # mean index of dispersion
gt_count.view(batch_size, -1).var(dim=1)
/ gt_count.view(batch_size, -1).mean(dim=1)
).mean()
)
loss_stn.requires_grad = True
optimizer.zero_grad(set_to_none=True)