This commit is contained in:
Zhengyi Chen 2024-03-03 23:13:57 +00:00
parent da8287b7e8
commit c74f4c7fb3
3 changed files with 17 additions and 9 deletions

View file

@ -6,6 +6,8 @@
#SBATCH --mem=24000 #SBATCH --mem=24000
#SBATCH --time=3-00:00:00 #SBATCH --time=3-00:00:00
set -e
export CUDA_HOME=/opt/cuda-9.0.176.1/ export CUDA_HOME=/opt/cuda-9.0.176.1/
export CUDNN_HOME=/opt/cuDNN-7.0/ export CUDNN_HOME=/opt/cuDNN-7.0/
export STUDENT_ID=$(whoami) export STUDENT_ID=$(whoami)
@ -21,8 +23,11 @@ export TMPDIR=/disk/scratch/${STUDENT_ID}/
export TMP=/disk/scratch/${STUDENT_ID}/ export TMP=/disk/scratch/${STUDENT_ID}/
source /home/${STUDENT_ID}/miniconda3/bin/activate mlp-cuda source /home/${STUDENT_ID}/miniconda3/bin/activate mlp-cuda
python train.py \ python train.py \
--model='stn' \ --debug True \
--model 'stn' \
--save_path ./save_file/ShanghaiA \ --save_path ./save_file/ShanghaiA \
--batch_size 4 \ --batch_size 4 \
--gpus 0,1,2,3,4,5 --gpus 0,1,2,3,4,5 \
--print_freq 100

View file

@ -24,7 +24,6 @@ class STNet(nn.Module):
_dummy_size_ = input_size _dummy_size_ = input_size
# shape checking # shape checking
print("STN: dummy_size {}".format(_dummy_size_))
_dummy_x_ = torch.zeros(_dummy_size_) _dummy_x_ = torch.zeros(_dummy_size_)
# (3.1) Spatial transformer localization-network # (3.1) Spatial transformer localization-network

View file

@ -89,7 +89,7 @@ def build_test_loader(data_keys, args):
test_loader = DataLoader( test_loader = DataLoader(
dataset=test_dataset, dataset=test_dataset,
sampler=test_dist_sampler, sampler=test_dist_sampler,
batch_size=1 batch_size=4
) )
return test_loader return test_loader
@ -299,16 +299,20 @@ def valid_one_epoch(test_loader, model, device, args):
with torch.no_grad(): with torch.no_grad():
out, gt_count = model(img, kpoint) out, gt_count = model(img, kpoint)
if args.debug:
print("out: {} | gt_count: {}".format(
out.shape, gt_count.shape
))
count = torch.sum(out).item() count = torch.sum(out).item()
gt_count = torch.sum(gt_count).item()
gt_count = torch.sum(gt_count).item()
mae += abs(gt_count - count) mae += abs(gt_count - count)
mse += abs(gt_count - count) ** 2 mse += abs(gt_count - count) ** 2
if i % 15 == 0: # if i % 15 == 0:
print("[valid_one_epoch] {} Gt {:.2f} Pred {}".format( print("[valid_one_epoch] {} Gt {:.2f} Pred {}".format(
fname[0], gt_count, count fname[0], gt_count, count
)) ))
mae = mae * 1.0 / (len(test_loader) * batch_size) mae = mae * 1.0 / (len(test_loader) * batch_size)
mse = np.sqrt(mse / (len(test_loader)) * batch_size) mse = np.sqrt(mse / (len(test_loader)) * batch_size)