Debug

parent a9dd8dee04
commit ab15419d2f

5 changed files with 63 additions and 19 deletions

.gitignore (vendored): 1 change

@@ -2,3 +2,4 @@ baseline-experiments/
 synchronous/
 npydata/
 **/__pycache__/**
+slurm-*

@@ -6,5 +6,23 @@
 #SBATCH --mem=24000
 #SBATCH --time=3-00:00:00
 
+export CUDA_HOME=/opt/cuda-9.0.176.1/
+export CUDNN_HOME=/opt/cuDNN-7.0/
+export STUDENT_ID=$(whoami)
+
+export LD_LIBRARY_PATH=${CUDNN_HOME}/lib64:${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
+export LIBRARY_PATH=${CUDNN_HOME}/lib64:${LIBRARY_PATH}
+export CPATH=${CUDNN_HOME}/include:$CPATH
+export PATH=${CUDA_HOME}/bin:${PATH}
+export PYTHON_PATH=$PATH
+
+mkdir -p /disk/scratch/${STUDENT_ID}
+export TMPDIR=/disk/scratch/${STUDENT_ID}/
+export TMP=/disk/scratch/${STUDENT_ID}/
+
+source /home/${STUDENT_ID}/miniconda3/bin/activate mlp-cuda
 python train.py \
-    --model='stn'
+    --model='stn' \
+    --save_path ./save_file/ShanghaiA \
+    --batch_size 4 \
+    --gpus 0,1,2,3,4,5 \
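
Note: the new export lines point TMPDIR and TMP at node-local scratch before activating the conda environment, so anything the job writes through the system temp directory lands on fast local disk rather than the shared filesystem. A minimal Python check of that behaviour (illustrative only, not part of the repository):

    import os
    import tempfile

    # TMPDIR is exported by the batch script, e.g. /disk/scratch/<user>/
    print(os.environ.get("TMPDIR"))
    # tempfile resolves the temp directory via TMPDIR when it is set
    print(tempfile.gettempdir())
    with tempfile.NamedTemporaryFile() as tmp:
        print(tmp.name)  # created under the scratch directory
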
@@ -63,7 +63,7 @@ parser.add_argument(
     "--epochs", type=int, default=250, help="Number of epochs to train"
 )
 parser.add_argument(
-    "--gpus", type=List[int], default=[0],
+    "--gpus", type=str, default='0',
     help="GPU IDs to be made available for training runtime"
 )
 
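Note: with --gpus now a plain comma-separated string, the IDs are split and cast to int later in train.py. A self-contained sketch of that round trip (the parser below is a stand-in, not the repository's full argument list):

    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--gpus", type=str, default="0",
                        help="GPU IDs to be made available for training runtime")
    args = parser.parse_args(["--gpus", "0,1,2,3"])
    gpu_ids = [int(gpu_id) for gpu_id in args.gpus.split(",")]
    print(gpu_ids)  # [0, 1, 2, 3]
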
@@ -76,6 +76,9 @@ parser.add_argument(
     "--ddp_world_size", type=int, default=1,
     help="DDP: Number of processes in Pytorch process group"
 )
+parser.add_argument(
+    "--debug", type=bool, default=False
+)
 
 # nni configuration ==========================================================
 parser.add_argument(
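Note: argparse's type=bool is a known pitfall, since any non-empty string (including "False") is truthy, so --debug False would still evaluate to True. A common alternative, shown here as a suggestion rather than what the repository does, is a store_true flag:

    from argparse import ArgumentParser

    parser = ArgumentParser()
    # Present -> True, absent -> False; avoids the bool("False") == True surprise
    parser.add_argument("--debug", action="store_true",
                        help="Run validation every epoch for quick debugging")
    print(parser.parse_args([]).debug)            # False
    print(parser.parse_args(["--debug"]).debug)   # True
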
@@ -69,13 +69,13 @@ class STNet(nn.Module):
         theta = self.fc_loc(xs)
         theta = theta.view(-1, 2, 3)  # -> (2, 3)
 
-        grid = F.affine_grid(theta, x.size())
-        x = F.grid_sample(x, grid)
+        grid = F.affine_grid(theta, x.size(), align_corners=False)
+        x = F.grid_sample(x, grid, align_corners=False)
 
         # Do the same transformation to t sans training
         with torch.no_grad():
             t = t.view(t.size(0), 1, t.size(1), t.size(2))
-            t = F.grid_sample(t, grid)
+            t = F.grid_sample(t, grid, align_corners=False)
             t = t.squeeze(1)
 
         return x, t
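
Note: passing align_corners=False explicitly pins the sampling convention (and silences the PyTorch warning about the unspecified default). A quick, self-contained sanity check that an identity theta leaves the input untouched under this convention; purely illustrative, not repository code:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 8, 8)
    theta = torch.tensor([[[1.0, 0.0, 0.0],
                           [0.0, 1.0, 0.0]]])  # identity affine transform
    grid = F.affine_grid(theta, x.size(), align_corners=False)
    y = F.grid_sample(x, grid, align_corners=False)
    print(torch.allclose(x, y, atol=1e-6))  # expected: True
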
train.py: 40 changes

@@ -106,12 +106,15 @@ def worker(rank: int, args: Namespace):
     if args.use_ddp and torch.cuda.is_available():
         device = torch.device(rank)
     elif torch.cuda.is_available():
-        device = torch.device(args.gpus)
+        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
+        args.gpus = [int(gpu_id) for gpu_id in args.gpus.split(",")]
+        device = None
     elif torch.backends.mps.is_available():
         device = torch.device("mps")
     else:
         print("[!!!] Using CPU for inference. This will be slow...")
         device = torch.device("cpu")
-    torch.set_default_device(device)
+    if device is not None:
+        torch.set_default_device(device)
 
     # Prepare training data
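
Note: the multi-GPU branch now records the requested IDs in CUDA_VISIBLE_DEVICES, converts args.gpus to a list of ints, and leaves device unset so later code falls back to .cuda(). One caveat worth flagging: CUDA_VISIBLE_DEVICES is only honoured if it is set before the CUDA context is initialised. A standalone sketch of the selection order (function name and signature are illustrative):

    import os
    from typing import Optional

    import torch

    def select_device(gpus: str, use_ddp: bool, rank: int = 0) -> Optional[torch.device]:
        if use_ddp and torch.cuda.is_available():
            return torch.device(rank)          # one process per GPU rank
        if torch.cuda.is_available():
            os.environ["CUDA_VISIBLE_DEVICES"] = gpus
            return None                        # callers fall back to .cuda()
        if torch.backends.mps.is_available():
            return torch.device("mps")         # Apple silicon
        return torch.device("cpu")
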
@@ -123,9 +126,9 @@ def worker(rank: int, args: Namespace):
 
     # Instantiate model
     if args.model == "stn":
-        model = stn_patch16_384_gap(args.pth_tar).to(device)
+        model = stn_patch16_384_gap(args.pth_tar)
     else:
-        model = base_patch16_384_gap(args.pth_tar).to(device)
+        model = base_patch16_384_gap(args.pth_tar)
 
     if args.use_ddp:
         model = nn.parallel.DistributedDataParallel(
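
Note: under DDP each worker process wraps its own replica of the model. A minimal, generic sketch of that setup, assuming an NCCL backend and one GPU per rank; the rendezvous address/port and helper names are illustrative, not taken from train.py:

    import os

    import torch
    import torch.distributed as dist
    import torch.nn as nn

    def setup_ddp(rank: int, world_size: int) -> None:
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group("nccl", rank=rank, world_size=world_size)

    def wrap_model(model: nn.Module, rank: int) -> nn.Module:
        model = model.to(rank)  # move the replica to this rank's GPU first
        return nn.parallel.DistributedDataParallel(model, device_ids=[rank])
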
@@ -140,8 +143,17 @@ def worker(rank: int, args: Namespace):
             device_ids=args.gpus
         )
 
+    if device is not None:
+        model = model.to(device)
+    elif torch.cuda.is_available():
+        model = model.cuda()
+
     # criterion, optimizer, scheduler
-    criterion = nn.L1Loss(size_average=False).to(device)
+    criterion = nn.L1Loss(size_average=False)
+    if device is not None:
+        criterion = criterion.to(device)
+    elif torch.cuda.is_available():
+        criterion = criterion.cuda()
     optimizer = torch.optim.Adam(
         [{"params": model.parameters(), "lr": args.lr}],
         lr=args.lr,
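
Note: size_average=False has been deprecated in PyTorch for some time; the equivalent modern spelling is reduction="sum". A small sketch of the same build-then-move pattern with that argument (an editorial suggestion, not a change made by this commit):

    from typing import Optional

    import torch
    import torch.nn as nn

    def build_criterion(device: Optional[torch.device]) -> nn.Module:
        criterion = nn.L1Loss(reduction="sum")  # same behaviour as size_average=False
        if device is not None:
            criterion = criterion.to(device)
        elif torch.cuda.is_available():
            criterion = criterion.cuda()
        return criterion
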
@@ -184,7 +196,7 @@ def worker(rank: int, args: Namespace):
         end_train = time.time()
 
         # Validate
-        if epoch % 5 == 0:
+        if epoch % 5 == 0 or args.debug:
             prec1 = valid_one_epoch(test_loader, model, device, args)
             end_valid = time.time()
             is_best = prec1 < args.best_pred
@@ -232,8 +244,12 @@ def train_one_epoch(
         kpoint = kpoint.type(torch.FloatTensor)
         print("Training: img {} | kpoint {}".format(img.shape, kpoint.shape))
         # fpass
-        img = img.to(device)
-        kpoint = kpoint.to(device)
+        if device is not None:
+            img = img.to(device)
+            kpoint = kpoint.to(device)
+        elif torch.cuda.is_available():
+            img = img.cuda()
+            kpoint = kpoint.cuda()
         out, gt_count = model(img, kpoint)
         # gt_count = gt_count.type(torch.FloatTensor).to(device).unsqueeze(1)
 
@@ -266,8 +282,14 @@ def valid_one_epoch(test_loader, model, device, args):
     visi = []
     index = 0
 
-    for i, (fname, img, gt_count) in enumerate(test_loader):
+    for i, (fname, img, kpoint) in enumerate(test_loader):
+        if device is not None:
+            img = img.to(device)
+            kpoint = kpoint.to(device)
+        elif torch.cuda.is_available():
+            img = img.cuda()
+            kpoint = kpoint.cuda()
 
         # XXX: what do this do
         if len(img.shape) == 5:
             img = img.squeeze(0)
@@ -275,12 +297,12 @@ def valid_one_epoch(test_loader, model, device, args):
             img = img.unsqueeze(0)
 
         with torch.no_grad():
-            out = model(img)
+            out, gt_count = model(img, kpoint)
             count = torch.sum(out).item()
 
         gt_count = torch.sum(gt_count).item()
-        mae += abs(gt_count - count)
-        mse += abs(gt_count - count) ** 2
+        mae += abs(kpoint - count)
+        mse += abs(kpoint - count) ** 2
 
         if i % 15 == 0:
             print("[valid_one_epoch] {} Gt {:.2f} Pred {}".format(
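
Note: the loop accumulates per-image absolute and squared errors; in the usual crowd-counting convention these are then averaged over the test set, with the squared term square-rooted to report (R)MSE. A small sketch of that final step (the normalisation is an assumption about the rest of valid_one_epoch, which is not shown in this diff):

    import math

    def finalize_metrics(mae: float, mse: float, num_samples: int) -> tuple:
        mae = mae / num_samples
        mse = math.sqrt(mse / num_samples)
        return mae, mse

    # e.g. finalize_metrics(mae, mse, len(test_loader))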