Edited loss fn to incorporate index of dispersion
This commit is contained in:
parent
9d2a30a226
commit
c228fc27cc
1 changed files with 12 additions and 6 deletions
18
train.py
18
train.py
|
|
@ -278,12 +278,18 @@ def train_one_epoch(
|
||||||
# loss & bpass & etc.
|
# loss & bpass & etc.
|
||||||
if isinstance(model.module, STNet_VisionTransformerGAP):
|
if isinstance(model.module, STNet_VisionTransformerGAP):
|
||||||
loss_xformer = criterion(out, gt_count)
|
loss_xformer = criterion(out, gt_count)
|
||||||
loss_stn = (
|
mse_info_retention = F.mse_loss(
|
||||||
F.mse_loss(
|
input =gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
|
||||||
gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
|
target=gt_count_whole
|
||||||
gt_count_whole)
|
)
|
||||||
+ F.sigmoid(
|
mean_dispersion_idx = (
|
||||||
gt_count.view(batch_size, -1).var(dim=1).mean())
|
gt_count.view(batch_size, -1).var(dim=1) /
|
||||||
|
gt_count.view(batch_size, -1).mean(dim=1)
|
||||||
|
).mean()
|
||||||
|
loss_stn = mse_info_retention + F.threshold(
|
||||||
|
mean_dispersion_idx,
|
||||||
|
threshold=mse_info_retention.item(),
|
||||||
|
value=mse_info_retention.item()
|
||||||
)
|
)
|
||||||
loss_stn.requires_grad = True
|
loss_stn.requires_grad = True
|
||||||
optimizer.zero_grad(set_to_none=True)
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue