FIX: vaidation code
This commit is contained in:
parent
c74f4c7fb3
commit
ee50e84946
1 changed files with 16 additions and 20 deletions
30
train.py
30
train.py
|
|
@ -89,7 +89,7 @@ def build_test_loader(data_keys, args):
|
||||||
test_loader = DataLoader(
|
test_loader = DataLoader(
|
||||||
dataset=test_dataset,
|
dataset=test_dataset,
|
||||||
sampler=test_dist_sampler,
|
sampler=test_dist_sampler,
|
||||||
batch_size=4
|
batch_size=1
|
||||||
)
|
)
|
||||||
return test_loader
|
return test_loader
|
||||||
|
|
||||||
|
|
@ -299,29 +299,25 @@ def valid_one_epoch(test_loader, model, device, args):
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
out, gt_count = model(img, kpoint)
|
out, gt_count = model(img, kpoint)
|
||||||
if args.debug:
|
pred_count = torch.squeeze(out, 1)
|
||||||
print("out: {} | gt_count: {}".format(
|
gt_count = torch.squeeze(gt_count, 1)
|
||||||
out.shape, gt_count.shape
|
|
||||||
))
|
|
||||||
count = torch.sum(out).item()
|
|
||||||
gt_count = torch.sum(gt_count).item()
|
|
||||||
|
|
||||||
mae += abs(gt_count - count)
|
|
||||||
mse += abs(gt_count - count) ** 2
|
|
||||||
|
|
||||||
# if i % 15 == 0:
|
|
||||||
print("[valid_one_epoch] {} Gt {:.2f} Pred {}".format(
|
|
||||||
fname[0], gt_count, count
|
|
||||||
))
|
|
||||||
|
|
||||||
|
diff = torch.sum(torch.abs(gt_count - pred_count)).item()
|
||||||
|
mae += diff
|
||||||
|
mse += diff ** 2
|
||||||
mae = mae * 1.0 / (len(test_loader) * batch_size)
|
mae = mae * 1.0 / (len(test_loader) * batch_size)
|
||||||
mse = np.sqrt(mse / (len(test_loader)) * batch_size)
|
mse = np.sqrt(mse / (len(test_loader)) * batch_size)
|
||||||
|
|
||||||
nni.report_intermediate_result(mae)
|
if i % 5 == 0:
|
||||||
|
print("[valid_one_epoch] {}\t| Gt {:.2f} Pred {:.4f}\t| mae {:.4f} mse {:.4f} |".format(
|
||||||
|
fname[0], torch.sum(gt_count).item(), torch.sum(pred_count).item(),
|
||||||
|
mae, mse
|
||||||
|
))
|
||||||
|
|
||||||
|
nni.report_intermediate_result()
|
||||||
print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
|
print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
|
||||||
mae=mae, mse=mse
|
mae=mae, mse=mse
|
||||||
))
|
))
|
||||||
|
|
||||||
return mae
|
return mae
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue