Edited loss fn to incorporate index of dispersion
This commit is contained in:
parent
9d2a30a226
commit
c228fc27cc
1 changed files with 12 additions and 6 deletions
18
train.py
18
train.py
|
|
@ -278,12 +278,18 @@ def train_one_epoch(
|
|||
# loss & bpass & etc.
|
||||
if isinstance(model.module, STNet_VisionTransformerGAP):
|
||||
loss_xformer = criterion(out, gt_count)
|
||||
loss_stn = (
|
||||
F.mse_loss(
|
||||
gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
|
||||
gt_count_whole)
|
||||
+ F.sigmoid(
|
||||
gt_count.view(batch_size, -1).var(dim=1).mean())
|
||||
mse_info_retention = F.mse_loss(
|
||||
input =gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
|
||||
target=gt_count_whole
|
||||
)
|
||||
mean_dispersion_idx = (
|
||||
gt_count.view(batch_size, -1).var(dim=1) /
|
||||
gt_count.view(batch_size, -1).mean(dim=1)
|
||||
).mean()
|
||||
loss_stn = mse_info_retention + F.threshold(
|
||||
mean_dispersion_idx,
|
||||
threshold=mse_info_retention.item(),
|
||||
value=mse_info_retention.item()
|
||||
)
|
||||
loss_stn.requires_grad = True
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue