Edited loss (ongoing)
This commit is contained in:
parent
9d2a30a226
commit
12adb6b7bf
1 changed files with 4 additions and 2 deletions
6
train.py
6
train.py
|
|
@ -282,8 +282,10 @@ def train_one_epoch(
|
||||||
F.mse_loss(
|
F.mse_loss(
|
||||||
gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
|
gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
|
||||||
gt_count_whole)
|
gt_count_whole)
|
||||||
+ F.sigmoid(
|
+ ( # mean index of dispersion
|
||||||
gt_count.view(batch_size, -1).var(dim=1).mean())
|
gt_count.view(batch_size, -1).var(dim=1)
|
||||||
|
/ gt_count.view(batch_size, -1).mean(dim=1)
|
||||||
|
).mean()
|
||||||
)
|
)
|
||||||
loss_stn.requires_grad = True
|
loss_stn.requires_grad = True
|
||||||
optimizer.zero_grad(set_to_none=True)
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue