FIX: vaidation code

This commit is contained in:
Zhengyi Chen 2024-03-03 23:56:17 +00:00
parent c74f4c7fb3
commit ee50e84946

View file

@ -89,7 +89,7 @@ def build_test_loader(data_keys, args):
test_loader = DataLoader(
dataset=test_dataset,
sampler=test_dist_sampler,
batch_size=4
batch_size=1
)
return test_loader
@ -299,30 +299,26 @@ def valid_one_epoch(test_loader, model, device, args):
with torch.no_grad():
out, gt_count = model(img, kpoint)
if args.debug:
print("out: {} | gt_count: {}".format(
out.shape, gt_count.shape
))
count = torch.sum(out).item()
gt_count = torch.sum(gt_count).item()
mae += abs(gt_count - count)
mse += abs(gt_count - count) ** 2
# if i % 15 == 0:
print("[valid_one_epoch] {} Gt {:.2f} Pred {}".format(
fname[0], gt_count, count
))
pred_count = torch.squeeze(out, 1)
gt_count = torch.squeeze(gt_count, 1)
diff = torch.sum(torch.abs(gt_count - pred_count)).item()
mae += diff
mse += diff ** 2
mae = mae * 1.0 / (len(test_loader) * batch_size)
mse = np.sqrt(mse / (len(test_loader)) * batch_size)
nni.report_intermediate_result(mae)
print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
mae=mae, mse=mse
))
if i % 5 == 0:
print("[valid_one_epoch] {}\t| Gt {:.2f} Pred {:.4f}\t| mae {:.4f} mse {:.4f} |".format(
fname[0], torch.sum(gt_count).item(), torch.sum(pred_count).item(),
mae, mse
))
return mae
nni.report_intermediate_result()
print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
mae=mae, mse=mse
))
return mae
if __name__ == "__main__":