Use mse instead of l1 for info retainment loss, fix mae
This commit is contained in:
parent
d46a027e3f
commit
94867bd8bf
1 changed files with 6 additions and 6 deletions
12
train.py
12
train.py
|
|
@ -257,7 +257,7 @@ def train_one_epoch(
|
|||
# loss
|
||||
loss = criterion(out, gt_count) # wrt. transformer
|
||||
loss += (
|
||||
criterion( # stn: info retainment
|
||||
F.mse_loss( # stn: info retainment
|
||||
gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
|
||||
gt_count_whole)
|
||||
+ F.threshold( # stn: perspective correction
|
||||
|
|
@ -277,8 +277,8 @@ def train_one_epoch(
|
|||
optimizer.step()
|
||||
|
||||
# periodic message
|
||||
if i % args.print_freq == 0:
|
||||
print("Epoch {}: {}/{}".format(epoch, i, len(train_loader)))
|
||||
# if i % args.print_freq == 0:
|
||||
# print("Epoch {}: {}/{}".format(epoch, i, len(train_loader)))
|
||||
|
||||
if args.debug:
|
||||
break
|
||||
|
|
@ -322,15 +322,15 @@ def valid_one_epoch(test_loader, model, device, args):
|
|||
diff = torch.abs(gt_count_whole - torch.sum(pred_count)).item()
|
||||
mae += diff
|
||||
mse += diff ** 2
|
||||
mae = mae * 1.0 / (len(test_loader) * batch_size)
|
||||
mse = np.sqrt(mse / (len(test_loader)) * batch_size)
|
||||
|
||||
if i % 5 == 0:
|
||||
print("[valid_one_epoch] {}\t| Gt {:.2f} Pred {:.4f}\t| mae {:.4f} mse {:.4f} |".format(
|
||||
print("[valid_one_epoch] {} | Gt {:.2f} Pred {:.4f} | mae {:.4f} mse {:.4f} |".format(
|
||||
fname[0], torch.sum(gt_count_whole).item(), torch.sum(pred_count).item(),
|
||||
mae, mse
|
||||
))
|
||||
|
||||
mae = mae * 1.0 / (len(test_loader) * batch_size)
|
||||
mse = np.sqrt(mse / (len(test_loader)) * batch_size)
|
||||
nni.report_intermediate_result(mae)
|
||||
print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
|
||||
mae=mae, mse=mse
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue