Sync
This commit is contained in:
parent
da8287b7e8
commit
c74f4c7fb3
3 changed files with 17 additions and 9 deletions
|
|
@ -6,6 +6,8 @@
|
|||
#SBATCH --mem=24000
|
||||
#SBATCH --time=3-00:00:00
|
||||
|
||||
set -e
|
||||
|
||||
export CUDA_HOME=/opt/cuda-9.0.176.1/
|
||||
export CUDNN_HOME=/opt/cuDNN-7.0/
|
||||
export STUDENT_ID=$(whoami)
|
||||
|
|
@ -21,8 +23,11 @@ export TMPDIR=/disk/scratch/${STUDENT_ID}/
|
|||
export TMP=/disk/scratch/${STUDENT_ID}/
|
||||
|
||||
source /home/${STUDENT_ID}/miniconda3/bin/activate mlp-cuda
|
||||
|
||||
python train.py \
|
||||
--model='stn' \
|
||||
--debug True \
|
||||
--model 'stn' \
|
||||
--save_path ./save_file/ShanghaiA \
|
||||
--batch_size 4 \
|
||||
--gpus 0,1,2,3,4,5
|
||||
--gpus 0,1,2,3,4,5 \
|
||||
--print_freq 100
|
||||
|
|
|
|||
|
|
@ -24,7 +24,6 @@ class STNet(nn.Module):
|
|||
_dummy_size_ = input_size
|
||||
|
||||
# shape checking
|
||||
print("STN: dummy_size {}".format(_dummy_size_))
|
||||
_dummy_x_ = torch.zeros(_dummy_size_)
|
||||
|
||||
# (3.1) Spatial transformer localization-network
|
||||
|
|
|
|||
10
train.py
10
train.py
|
|
@ -89,7 +89,7 @@ def build_test_loader(data_keys, args):
|
|||
test_loader = DataLoader(
|
||||
dataset=test_dataset,
|
||||
sampler=test_dist_sampler,
|
||||
batch_size=1
|
||||
batch_size=4
|
||||
)
|
||||
return test_loader
|
||||
|
||||
|
|
@ -299,13 +299,17 @@ def valid_one_epoch(test_loader, model, device, args):
|
|||
|
||||
with torch.no_grad():
|
||||
out, gt_count = model(img, kpoint)
|
||||
if args.debug:
|
||||
print("out: {} | gt_count: {}".format(
|
||||
out.shape, gt_count.shape
|
||||
))
|
||||
count = torch.sum(out).item()
|
||||
|
||||
gt_count = torch.sum(gt_count).item()
|
||||
|
||||
mae += abs(gt_count - count)
|
||||
mse += abs(gt_count - count) ** 2
|
||||
|
||||
if i % 15 == 0:
|
||||
# if i % 15 == 0:
|
||||
print("[valid_one_epoch] {} Gt {:.2f} Pred {}".format(
|
||||
fname[0], gt_count, count
|
||||
))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue