From c74f4c7fb30d80ea319735820ec831878dfa8332 Mon Sep 17 00:00:00 2001 From: Zhengyi Chen Date: Sun, 3 Mar 2024 23:13:57 +0000 Subject: [PATCH] Sync --- _ShanghaiA-train.sh | 9 +++++++-- model/stn.py | 1 - train.py | 16 ++++++++++------ 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/_ShanghaiA-train.sh b/_ShanghaiA-train.sh index c885ade..a1af788 100644 --- a/_ShanghaiA-train.sh +++ b/_ShanghaiA-train.sh @@ -6,6 +6,8 @@ #SBATCH --mem=24000 #SBATCH --time=3-00:00:00 +set -e + export CUDA_HOME=/opt/cuda-9.0.176.1/ export CUDNN_HOME=/opt/cuDNN-7.0/ export STUDENT_ID=$(whoami) @@ -21,8 +23,11 @@ export TMPDIR=/disk/scratch/${STUDENT_ID}/ export TMP=/disk/scratch/${STUDENT_ID}/ source /home/${STUDENT_ID}/miniconda3/bin/activate mlp-cuda + python train.py \ - --model='stn' \ + --debug True \ + --model 'stn' \ --save_path ./save_file/ShanghaiA \ --batch_size 4 \ - --gpus 0,1,2,3,4,5 + --gpus 0,1,2,3,4,5 \ + --print_freq 100 diff --git a/model/stn.py b/model/stn.py index 2d18abe..d98650d 100644 --- a/model/stn.py +++ b/model/stn.py @@ -24,7 +24,6 @@ class STNet(nn.Module): _dummy_size_ = input_size # shape checking - print("STN: dummy_size {}".format(_dummy_size_)) _dummy_x_ = torch.zeros(_dummy_size_) # (3.1) Spatial transformer localization-network diff --git a/train.py b/train.py index b79dc6e..230ecc5 100644 --- a/train.py +++ b/train.py @@ -89,7 +89,7 @@ def build_test_loader(data_keys, args): test_loader = DataLoader( dataset=test_dataset, sampler=test_dist_sampler, - batch_size=1 + batch_size=4 ) return test_loader @@ -299,16 +299,20 @@ def valid_one_epoch(test_loader, model, device, args): with torch.no_grad(): out, gt_count = model(img, kpoint) + if args.debug: + print("out: {} | gt_count: {}".format( + out.shape, gt_count.shape + )) count = torch.sum(out).item() + gt_count = torch.sum(gt_count).item() - gt_count = torch.sum(gt_count).item() mae += abs(gt_count - count) mse += abs(gt_count - count) ** 2 - if i % 15 == 0: - print("[valid_one_epoch] {} Gt {:.2f} Pred {}".format( - fname[0], gt_count, count - )) + # if i % 15 == 0: + print("[valid_one_epoch] {} Gt {:.2f} Pred {}".format( + fname[0], gt_count, count + )) mae = mae * 1.0 / (len(test_loader) * batch_size) mse = np.sqrt(mse / (len(test_loader)) * batch_size)