diff --git a/train.py b/train.py
index a6491d1..e5fe107 100644
--- a/train.py
+++ b/train.py
@@ -278,12 +278,18 @@ def train_one_epoch(
         # loss & bpass & etc.
         if isinstance(model.module, STNet_VisionTransformerGAP):
             loss_xformer = criterion(out, gt_count)
-            loss_stn = (
-                F.mse_loss(
-                    gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
-                    gt_count_whole)
-                + F.sigmoid(
-                    gt_count.view(batch_size, -1).var(dim=1).mean())
+            mse_info_retention = F.mse_loss(
+                input =gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
+                target=gt_count_whole
+            )
+            mean_dispersion_idx = (
+                gt_count.view(batch_size, -1).var(dim=1) /
+                gt_count.view(batch_size, -1).mean(dim=1)
+            ).mean()
+            loss_stn = mse_info_retention + F.threshold(
+                mean_dispersion_idx,
+                threshold=mse_info_retention.item(),
+                value=mse_info_retention.item()
             )
             loss_stn.requires_grad = True
             optimizer.zero_grad(set_to_none=True)