This commit is contained in:
Zhengyi Chen 2024-03-03 23:13:57 +00:00
parent da8287b7e8
commit c74f4c7fb3
3 changed files with 17 additions and 9 deletions

View file

@ -89,7 +89,7 @@ def build_test_loader(data_keys, args):
test_loader = DataLoader(
dataset=test_dataset,
sampler=test_dist_sampler,
batch_size=1
batch_size=4
)
return test_loader
@ -299,16 +299,20 @@ def valid_one_epoch(test_loader, model, device, args):
with torch.no_grad():
out, gt_count = model(img, kpoint)
if args.debug:
print("out: {} | gt_count: {}".format(
out.shape, gt_count.shape
))
count = torch.sum(out).item()
gt_count = torch.sum(gt_count).item()
gt_count = torch.sum(gt_count).item()
mae += abs(gt_count - count)
mse += abs(gt_count - count) ** 2
if i % 15 == 0:
print("[valid_one_epoch] {} Gt {:.2f} Pred {}".format(
fname[0], gt_count, count
))
# if i % 15 == 0:
print("[valid_one_epoch] {} Gt {:.2f} Pred {}".format(
fname[0], gt_count, count
))
mae = mae * 1.0 / (len(test_loader) * batch_size)
mse = np.sqrt(mse / (len(test_loader)) * batch_size)