Edited loss fn to incorporate index of dispersion

Zhengyi Chen 2024-03-06 22:30:43 +00:00 committed by rubberhead
parent 9d2a30a226
commit c228fc27cc

@@ -278,12 +278,18 @@ def train_one_epoch(
         # loss & bpass & etc.
         if isinstance(model.module, STNet_VisionTransformerGAP):
             loss_xformer = criterion(out, gt_count)
-            loss_stn = (
-                F.mse_loss(
-                    gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
-                    gt_count_whole)
-                + F.sigmoid(
-                    gt_count.view(batch_size, -1).var(dim=1).mean())
+            mse_info_retention = F.mse_loss(
+                input=gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
+                target=gt_count_whole
+            )
+            mean_dispersion_idx = (
+                gt_count.view(batch_size, -1).var(dim=1) /
+                gt_count.view(batch_size, -1).mean(dim=1)
+            ).mean()
+            loss_stn = mse_info_retention + F.threshold(
+                mean_dispersion_idx,
+                threshold=mse_info_retention.item(),
+                value=mse_info_retention.item()
             )
             loss_stn.requires_grad = True
             optimizer.zero_grad(set_to_none=True)