Sync
This commit is contained in:
parent
da8287b7e8
commit
c74f4c7fb3
3 changed files with 17 additions and 9 deletions
|
|
@ -6,6 +6,8 @@
|
||||||
#SBATCH --mem=24000
|
#SBATCH --mem=24000
|
||||||
#SBATCH --time=3-00:00:00
|
#SBATCH --time=3-00:00:00
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
export CUDA_HOME=/opt/cuda-9.0.176.1/
|
export CUDA_HOME=/opt/cuda-9.0.176.1/
|
||||||
export CUDNN_HOME=/opt/cuDNN-7.0/
|
export CUDNN_HOME=/opt/cuDNN-7.0/
|
||||||
export STUDENT_ID=$(whoami)
|
export STUDENT_ID=$(whoami)
|
||||||
|
|
@ -21,8 +23,11 @@ export TMPDIR=/disk/scratch/${STUDENT_ID}/
|
||||||
export TMP=/disk/scratch/${STUDENT_ID}/
|
export TMP=/disk/scratch/${STUDENT_ID}/
|
||||||
|
|
||||||
source /home/${STUDENT_ID}/miniconda3/bin/activate mlp-cuda
|
source /home/${STUDENT_ID}/miniconda3/bin/activate mlp-cuda
|
||||||
|
|
||||||
python train.py \
|
python train.py \
|
||||||
--model='stn' \
|
--debug True \
|
||||||
|
--model 'stn' \
|
||||||
--save_path ./save_file/ShanghaiA \
|
--save_path ./save_file/ShanghaiA \
|
||||||
--batch_size 4 \
|
--batch_size 4 \
|
||||||
--gpus 0,1,2,3,4,5
|
--gpus 0,1,2,3,4,5 \
|
||||||
|
--print_freq 100
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,6 @@ class STNet(nn.Module):
|
||||||
_dummy_size_ = input_size
|
_dummy_size_ = input_size
|
||||||
|
|
||||||
# shape checking
|
# shape checking
|
||||||
print("STN: dummy_size {}".format(_dummy_size_))
|
|
||||||
_dummy_x_ = torch.zeros(_dummy_size_)
|
_dummy_x_ = torch.zeros(_dummy_size_)
|
||||||
|
|
||||||
# (3.1) Spatial transformer localization-network
|
# (3.1) Spatial transformer localization-network
|
||||||
|
|
|
||||||
10
train.py
10
train.py
|
|
@ -89,7 +89,7 @@ def build_test_loader(data_keys, args):
|
||||||
test_loader = DataLoader(
|
test_loader = DataLoader(
|
||||||
dataset=test_dataset,
|
dataset=test_dataset,
|
||||||
sampler=test_dist_sampler,
|
sampler=test_dist_sampler,
|
||||||
batch_size=1
|
batch_size=4
|
||||||
)
|
)
|
||||||
return test_loader
|
return test_loader
|
||||||
|
|
||||||
|
|
@ -299,13 +299,17 @@ def valid_one_epoch(test_loader, model, device, args):
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
out, gt_count = model(img, kpoint)
|
out, gt_count = model(img, kpoint)
|
||||||
|
if args.debug:
|
||||||
|
print("out: {} | gt_count: {}".format(
|
||||||
|
out.shape, gt_count.shape
|
||||||
|
))
|
||||||
count = torch.sum(out).item()
|
count = torch.sum(out).item()
|
||||||
|
|
||||||
gt_count = torch.sum(gt_count).item()
|
gt_count = torch.sum(gt_count).item()
|
||||||
|
|
||||||
mae += abs(gt_count - count)
|
mae += abs(gt_count - count)
|
||||||
mse += abs(gt_count - count) ** 2
|
mse += abs(gt_count - count) ** 2
|
||||||
|
|
||||||
if i % 15 == 0:
|
# if i % 15 == 0:
|
||||||
print("[valid_one_epoch] {} Gt {:.2f} Pred {}".format(
|
print("[valid_one_epoch] {} Gt {:.2f} Pred {}".format(
|
||||||
fname[0], gt_count, count
|
fname[0], gt_count, count
|
||||||
))
|
))
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue