Edited loss fn to incorporate index of dispersion

This commit is contained in:
Zhengyi Chen 2024-03-06 22:30:43 +00:00 committed by rubberhead
parent 9d2a30a226
commit c228fc27cc

View file

@ -278,12 +278,18 @@ def train_one_epoch(
# loss & bpass & etc.
if isinstance(model.module, STNet_VisionTransformerGAP):
loss_xformer = criterion(out, gt_count)
loss_stn = (
F.mse_loss(
gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
gt_count_whole)
+ F.sigmoid(
gt_count.view(batch_size, -1).var(dim=1).mean())
mse_info_retention = F.mse_loss(
input =gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
target=gt_count_whole
)
mean_dispersion_idx = (
gt_count.view(batch_size, -1).var(dim=1) /
gt_count.view(batch_size, -1).mean(dim=1)
).mean()
loss_stn = mse_info_retention + F.threshold(
mean_dispersion_idx,
threshold=mse_info_retention.item(),
value=mse_info_retention.item()
)
loss_stn.requires_grad = True
optimizer.zero_grad(set_to_none=True)