This commit is contained in:
Zhengyi Chen 2024-03-03 23:13:57 +00:00
parent da8287b7e8
commit c74f4c7fb3
3 changed files with 17 additions and 9 deletions

View file

@ -6,6 +6,8 @@
#SBATCH --mem=24000
#SBATCH --time=3-00:00:00
set -e
export CUDA_HOME=/opt/cuda-9.0.176.1/
export CUDNN_HOME=/opt/cuDNN-7.0/
export STUDENT_ID=$(whoami)
@ -21,8 +23,11 @@ export TMPDIR=/disk/scratch/${STUDENT_ID}/
export TMP=/disk/scratch/${STUDENT_ID}/
source /home/${STUDENT_ID}/miniconda3/bin/activate mlp-cuda
python train.py \
--model='stn' \
--debug True \
--model 'stn' \
--save_path ./save_file/ShanghaiA \
--batch_size 4 \
--gpus 0,1,2,3,4,5
--gpus 0,1,2,3,4,5 \
--print_freq 100

View file

@ -24,7 +24,6 @@ class STNet(nn.Module):
_dummy_size_ = input_size
# shape checking
print("STN: dummy_size {}".format(_dummy_size_))
_dummy_x_ = torch.zeros(_dummy_size_)
# (3.1) Spatial transformer localization-network

View file

@ -89,7 +89,7 @@ def build_test_loader(data_keys, args):
test_loader = DataLoader(
dataset=test_dataset,
sampler=test_dist_sampler,
batch_size=1
batch_size=4
)
return test_loader
@ -299,16 +299,20 @@ def valid_one_epoch(test_loader, model, device, args):
with torch.no_grad():
out, gt_count = model(img, kpoint)
if args.debug:
print("out: {} | gt_count: {}".format(
out.shape, gt_count.shape
))
count = torch.sum(out).item()
gt_count = torch.sum(gt_count).item()
gt_count = torch.sum(gt_count).item()
mae += abs(gt_count - count)
mse += abs(gt_count - count) ** 2
if i % 15 == 0:
print("[valid_one_epoch] {} Gt {:.2f} Pred {}".format(
fname[0], gt_count, count
))
# if i % 15 == 0:
print("[valid_one_epoch] {} Gt {:.2f} Pred {}".format(
fname[0], gt_count, count
))
mae = mae * 1.0 / (len(test_loader) * batch_size)
mse = np.sqrt(mse / (len(test_loader)) * batch_size)