Use mse instead of l1 for info retainment loss, fix mae

Zhengyi Chen 2024-03-04 14:39:58 +00:00
parent d46a027e3f
commit 94867bd8bf

@@ -257,7 +257,7 @@ def train_one_epoch(
         # loss
         loss = criterion(out, gt_count)  # wrt. transformer
         loss += (
-            criterion(      # stn: info retainment
+            F.mse_loss(     # stn: info retainment
                 gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
                 gt_count_whole)
             + F.threshold(  # stn: perspective correction
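
This hunk swaps the shared `criterion` (L1, per the commit title) for an explicit `F.mse_loss` on the STN info-retainment term, so a mismatch between the summed patch counts and the whole-image count is penalized quadratically rather than linearly. A minimal sketch of the difference; the tensor shapes are assumptions for the demo, since the real `gt_count` layout is not shown in this diff:

    import torch
    import torch.nn.functional as F

    batch_size = 4
    # assumed layout: per-patch GT counts after the STN crop
    gt_count = torch.rand(batch_size, 16, 16)
    # whole-image counts, perturbed so the two loss terms differ visibly
    gt_count_whole = (gt_count.view(batch_size, -1).sum(dim=1, keepdim=True)
                      + 0.5 * torch.randn(batch_size, 1))

    agg = gt_count.view(batch_size, -1).sum(dim=1, keepdim=True)

    l1 = F.l1_loss(agg, gt_count_whole)    # old term: |agg - whole|, linear in the error
    mse = F.mse_loss(agg, gt_count_whole)  # new term: (agg - whole)^2, quadratic in the error
    print(l1.item(), mse.item())
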
@@ -277,8 +277,8 @@ def train_one_epoch(
         optimizer.step()
         # periodic message
-        if i % args.print_freq == 0:
-            print("Epoch {}: {}/{}".format(epoch, i, len(train_loader)))
+        # if i % args.print_freq == 0:
+        #     print("Epoch {}: {}/{}".format(epoch, i, len(train_loader)))
         if args.debug:
             break
@@ -322,15 +322,15 @@ def valid_one_epoch(test_loader, model, device, args):
         diff = torch.abs(gt_count_whole - torch.sum(pred_count)).item()
         mae += diff
         mse += diff ** 2
-        mae = mae * 1.0 / (len(test_loader) * batch_size)
-        mse = np.sqrt(mse / (len(test_loader)) * batch_size)
         if i % 5 == 0:
-            print("[valid_one_epoch] {}\t| Gt {:.2f} Pred {:.4f}\t| mae {:.4f} mse {:.4f} |".format(
+            print("[valid_one_epoch] {} | Gt {:.2f} Pred {:.4f} | mae {:.4f} mse {:.4f} |".format(
                 fname[0], torch.sum(gt_count_whole).item(), torch.sum(pred_count).item(),
                 mae, mse
             ))
+    mae = mae * 1.0 / (len(test_loader) * batch_size)
+    mse = np.sqrt(mse / (len(test_loader) * batch_size))
     nni.report_intermediate_result(mae)
     print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
         mae=mae, mse=mse
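
The "fix mae" half of the commit moves the metric normalization out of the validation loop: dividing the running `mae`/`mse` sums on every iteration (the removed lines) shrinks the accumulators repeatedly, whereas the fix accumulates first and normalizes once, dividing both sums by the total sample count `len(test_loader) * batch_size` before taking the square root. A standalone sketch of the intended computation; `errors` and `n_samples` are illustrative names, not identifiers from the repository:

    import numpy as np

    def mae_rmse(errors, n_samples):
        """MAE and RMSE from per-sample absolute count errors (illustrative)."""
        errors = np.asarray(errors, dtype=np.float64)
        mae = errors.sum() / n_samples                   # normalize once, after the loop
        rmse = np.sqrt((errors ** 2).sum() / n_samples)  # sqrt of the *mean* squared error
        return mae, rmse

    # e.g. 3 batches of batch_size 2 -> n_samples = 6
    print(mae_rmse([1.0, 2.0, 0.5, 1.5, 3.0, 0.0], n_samples=6))

Note that the quantity reported as "MSE" here is actually a root-mean-squared error because of the `np.sqrt`, which matches how crowd-counting results are conventionally reported.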