Use mse instead of l1 for info retainment loss, fix mae
This commit is contained in:
parent
d46a027e3f
commit
94867bd8bf
1 changed files with 6 additions and 6 deletions
12
train.py
12
train.py
|
|
@ -257,7 +257,7 @@ def train_one_epoch(
|
||||||
# loss
|
# loss
|
||||||
loss = criterion(out, gt_count) # wrt. transformer
|
loss = criterion(out, gt_count) # wrt. transformer
|
||||||
loss += (
|
loss += (
|
||||||
criterion( # stn: info retainment
|
F.mse_loss( # stn: info retainment
|
||||||
gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
|
gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
|
||||||
gt_count_whole)
|
gt_count_whole)
|
||||||
+ F.threshold( # stn: perspective correction
|
+ F.threshold( # stn: perspective correction
|
||||||
|
|
@ -277,8 +277,8 @@ def train_one_epoch(
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
# periodic message
|
# periodic message
|
||||||
if i % args.print_freq == 0:
|
# if i % args.print_freq == 0:
|
||||||
print("Epoch {}: {}/{}".format(epoch, i, len(train_loader)))
|
# print("Epoch {}: {}/{}".format(epoch, i, len(train_loader)))
|
||||||
|
|
||||||
if args.debug:
|
if args.debug:
|
||||||
break
|
break
|
||||||
|
|
@ -322,15 +322,15 @@ def valid_one_epoch(test_loader, model, device, args):
|
||||||
diff = torch.abs(gt_count_whole - torch.sum(pred_count)).item()
|
diff = torch.abs(gt_count_whole - torch.sum(pred_count)).item()
|
||||||
mae += diff
|
mae += diff
|
||||||
mse += diff ** 2
|
mse += diff ** 2
|
||||||
mae = mae * 1.0 / (len(test_loader) * batch_size)
|
|
||||||
mse = np.sqrt(mse / (len(test_loader)) * batch_size)
|
|
||||||
|
|
||||||
if i % 5 == 0:
|
if i % 5 == 0:
|
||||||
print("[valid_one_epoch] {}\t| Gt {:.2f} Pred {:.4f}\t| mae {:.4f} mse {:.4f} |".format(
|
print("[valid_one_epoch] {} | Gt {:.2f} Pred {:.4f} | mae {:.4f} mse {:.4f} |".format(
|
||||||
fname[0], torch.sum(gt_count_whole).item(), torch.sum(pred_count).item(),
|
fname[0], torch.sum(gt_count_whole).item(), torch.sum(pred_count).item(),
|
||||||
mae, mse
|
mae, mse
|
||||||
))
|
))
|
||||||
|
|
||||||
|
mae = mae * 1.0 / (len(test_loader) * batch_size)
|
||||||
|
mse = np.sqrt(mse / (len(test_loader)) * batch_size)
|
||||||
nni.report_intermediate_result(mae)
|
nni.report_intermediate_result(mae)
|
||||||
print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
|
print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
|
||||||
mae=mae, mse=mse
|
mae=mae, mse=mse
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue