diff --git a/train.py b/train.py index 230ecc5..b85ef42 100644 --- a/train.py +++ b/train.py @@ -89,7 +89,7 @@ def build_test_loader(data_keys, args): test_loader = DataLoader( dataset=test_dataset, sampler=test_dist_sampler, - batch_size=4 + batch_size=1 ) return test_loader @@ -299,30 +299,26 @@ def valid_one_epoch(test_loader, model, device, args): with torch.no_grad(): out, gt_count = model(img, kpoint) - if args.debug: - print("out: {} | gt_count: {}".format( - out.shape, gt_count.shape - )) - count = torch.sum(out).item() - gt_count = torch.sum(gt_count).item() - - mae += abs(gt_count - count) - mse += abs(gt_count - count) ** 2 - - # if i % 15 == 0: - print("[valid_one_epoch] {} Gt {:.2f} Pred {}".format( - fname[0], gt_count, count - )) + pred_count = torch.squeeze(out, 1) + gt_count = torch.squeeze(gt_count, 1) + diff = torch.sum(torch.abs(gt_count - pred_count)).item() + mae += diff + mse += diff ** 2 mae = mae * 1.0 / (len(test_loader) * batch_size) mse = np.sqrt(mse / (len(test_loader)) * batch_size) - nni.report_intermediate_result(mae) - print("* MAE {mae:.3f} | MSE {mse:.3f} *".format( - mae=mae, mse=mse - )) + if i % 5 == 0: + print("[valid_one_epoch] {}\t| Gt {:.2f} Pred {:.4f}\t| mae {:.4f} mse {:.4f} |".format( + fname[0], torch.sum(gt_count).item(), torch.sum(pred_count).item(), + mae, mse + )) - return mae + nni.report_intermediate_result() + print("* MAE {mae:.3f} | MSE {mse:.3f} *".format( + mae=mae, mse=mse + )) + return mae if __name__ == "__main__":