FIX: vaidation code
This commit is contained in:
parent
c74f4c7fb3
commit
ee50e84946
1 changed files with 16 additions and 20 deletions
36
train.py
36
train.py
|
|
@ -89,7 +89,7 @@ def build_test_loader(data_keys, args):
|
|||
test_loader = DataLoader(
|
||||
dataset=test_dataset,
|
||||
sampler=test_dist_sampler,
|
||||
batch_size=4
|
||||
batch_size=1
|
||||
)
|
||||
return test_loader
|
||||
|
||||
|
|
@ -299,30 +299,26 @@ def valid_one_epoch(test_loader, model, device, args):
|
|||
|
||||
with torch.no_grad():
|
||||
out, gt_count = model(img, kpoint)
|
||||
if args.debug:
|
||||
print("out: {} | gt_count: {}".format(
|
||||
out.shape, gt_count.shape
|
||||
))
|
||||
count = torch.sum(out).item()
|
||||
gt_count = torch.sum(gt_count).item()
|
||||
|
||||
mae += abs(gt_count - count)
|
||||
mse += abs(gt_count - count) ** 2
|
||||
|
||||
# if i % 15 == 0:
|
||||
print("[valid_one_epoch] {} Gt {:.2f} Pred {}".format(
|
||||
fname[0], gt_count, count
|
||||
))
|
||||
pred_count = torch.squeeze(out, 1)
|
||||
gt_count = torch.squeeze(gt_count, 1)
|
||||
|
||||
diff = torch.sum(torch.abs(gt_count - pred_count)).item()
|
||||
mae += diff
|
||||
mse += diff ** 2
|
||||
mae = mae * 1.0 / (len(test_loader) * batch_size)
|
||||
mse = np.sqrt(mse / (len(test_loader)) * batch_size)
|
||||
|
||||
nni.report_intermediate_result(mae)
|
||||
print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
|
||||
mae=mae, mse=mse
|
||||
))
|
||||
if i % 5 == 0:
|
||||
print("[valid_one_epoch] {}\t| Gt {:.2f} Pred {:.4f}\t| mae {:.4f} mse {:.4f} |".format(
|
||||
fname[0], torch.sum(gt_count).item(), torch.sum(pred_count).item(),
|
||||
mae, mse
|
||||
))
|
||||
|
||||
return mae
|
||||
nni.report_intermediate_result()
|
||||
print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
|
||||
mae=mae, mse=mse
|
||||
))
|
||||
return mae
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue