This commit is contained in:
Zhengyi Chen 2024-03-03 19:40:22 +00:00
parent a9dd8dee04
commit ab15419d2f
5 changed files with 63 additions and 19 deletions

1
.gitignore vendored
View file

@ -2,3 +2,4 @@ baseline-experiments/
synchronous/ synchronous/
npydata/ npydata/
**/__pycache__/** **/__pycache__/**
slurm-*

View file

@ -6,5 +6,23 @@
#SBATCH --mem=24000 #SBATCH --mem=24000
#SBATCH --time=3-00:00:00 #SBATCH --time=3-00:00:00
export CUDA_HOME=/opt/cuda-9.0.176.1/
export CUDNN_HOME=/opt/cuDNN-7.0/
export STUDENT_ID=$(whoami)
export LD_LIBRARY_PATH=${CUDNN_HOME}/lib64:${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
export LIBRARY_PATH=${CUDNN_HOME}/lib64:${LIBRARY_PATH}
export CPATH=${CUDNN_HOME}/include:$CPATH
export PATH=${CUDA_HOME}/bin:${PATH}
export PYTHON_PATH=$PATH
mkdir -p /disk/scratch/${STUDENT_ID}
export TMPDIR=/disk/scratch/${STUDENT_ID}/
export TMP=/disk/scratch/${STUDENT_ID}/
source /home/${STUDENT_ID}/miniconda3/bin/activate mlp-cuda
python train.py \ python train.py \
--model='stn' --model='stn' \
--save_path ./save_file/ShanghaiA \
--batch_size 4 \
--gpus 0,1,2,3,4,5 \

View file

@ -63,7 +63,7 @@ parser.add_argument(
"--epochs", type=int, default=250, help="Number of epochs to train" "--epochs", type=int, default=250, help="Number of epochs to train"
) )
parser.add_argument( parser.add_argument(
"--gpus", type=List[int], default=[0], "--gpus", type=str, default='0',
help="GPU IDs to be made available for training runtime" help="GPU IDs to be made available for training runtime"
) )
@ -76,6 +76,9 @@ parser.add_argument(
"--ddp_world_size", type=int, default=1, "--ddp_world_size", type=int, default=1,
help="DDP: Number of processes in Pytorch process group" help="DDP: Number of processes in Pytorch process group"
) )
parse.add_argument(
"--debug", type=bool, default=False
)
# nni configuration ========================================================== # nni configuration ==========================================================
parser.add_argument( parser.add_argument(

View file

@ -69,13 +69,13 @@ class STNet(nn.Module):
theta = self.fc_loc(xs) theta = self.fc_loc(xs)
theta = theta.view(-1, 2, 3) # -> (2, 3) theta = theta.view(-1, 2, 3) # -> (2, 3)
grid = F.affine_grid(theta, x.size()) grid = F.affine_grid(theta, x.size(), align_corners=False)
x = F.grid_sample(x, grid) x = F.grid_sample(x, grid, align_corners=False)
# Do the same transformation to t sans training # Do the same transformation to t sans training
with torch.no_grad(): with torch.no_grad():
t = t.view(t.size(0), 1, t.size(1), t.size(2)) t = t.view(t.size(0), 1, t.size(1), t.size(2))
t = F.grid_sample(t, grid) t = F.grid_sample(t, grid, align_corners=False)
t = t.squeeze(1) t = t.squeeze(1)
return x, t return x, t

View file

@ -106,12 +106,15 @@ def worker(rank: int, args: Namespace):
if args.use_ddp and torch.cuda.is_available(): if args.use_ddp and torch.cuda.is_available():
device = torch.device(rank) device = torch.device(rank)
elif torch.cuda.is_available(): elif torch.cuda.is_available():
device = torch.device(args.gpus) os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
args.gpus = [int(gpu_id) for gpu_id in args.gpus.split(",")]
device = None
elif torch.backends.mps.is_available(): elif torch.backends.mps.is_available():
device = torch.device("mps") device = torch.device("mps")
else: else:
print("[!!!] Using CPU for inference. This will be slow...") print("[!!!] Using CPU for inference. This will be slow...")
device = torch.device("cpu") device = torch.device("cpu")
if device is not None:
torch.set_default_device(device) torch.set_default_device(device)
# Prepare training data # Prepare training data
@ -123,9 +126,9 @@ def worker(rank: int, args: Namespace):
# Instantiate model # Instantiate model
if args.model == "stn": if args.model == "stn":
model = stn_patch16_384_gap(args.pth_tar).to(device) model = stn_patch16_384_gap(args.pth_tar)
else: else:
model = base_patch16_384_gap(args.pth_tar).to(device) model = base_patch16_384_gap(args.pth_tar)
if args.use_ddp: if args.use_ddp:
model = nn.parallel.DistributedDataParallel( model = nn.parallel.DistributedDataParallel(
@ -140,8 +143,17 @@ def worker(rank: int, args: Namespace):
device_ids=args.gpus device_ids=args.gpus
) )
if device is not None:
model = model.to(device)
elif torch.cuda.is_available():
model = model.cuda()
# criterion, optimizer, scheduler # criterion, optimizer, scheduler
criterion = nn.L1Loss(size_average=False).to(device) criterion = nn.L1Loss(size_average=False)
if device is not None:
criterion = criterion.to(device)
elif torch.cuda.is_available():
criterion = criterion.cuda()
optimizer = torch.optim.Adam( optimizer = torch.optim.Adam(
[{"params": model.parameters(), "lr": args.lr}], [{"params": model.parameters(), "lr": args.lr}],
lr=args.lr, lr=args.lr,
@ -184,7 +196,7 @@ def worker(rank: int, args: Namespace):
end_train = time.time() end_train = time.time()
# Validate # Validate
if epoch % 5 == 0: if epoch % 5 == 0 or args.debug:
prec1 = valid_one_epoch(test_loader, model, device, args) prec1 = valid_one_epoch(test_loader, model, device, args)
end_valid = time.time() end_valid = time.time()
is_best = prec1 < args.best_pred is_best = prec1 < args.best_pred
@ -232,8 +244,12 @@ def train_one_epoch(
kpoint = kpoint.type(torch.FloatTensor) kpoint = kpoint.type(torch.FloatTensor)
print("Training: img {} | kpoint {}".format(img.shape, kpoint.shape)) print("Training: img {} | kpoint {}".format(img.shape, kpoint.shape))
# fpass # fpass
if device is not None:
img = img.to(device) img = img.to(device)
kpoint = kpoint.to(device) kpoint = kpoint.to(device)
elif torch.cuda.is_available():
img = img.cuda()
kpoint = kpoint.cuda()
out, gt_count = model(img, kpoint) out, gt_count = model(img, kpoint)
# gt_count = gt_count.type(torch.FloatTensor).to(device).unsqueeze(1) # gt_count = gt_count.type(torch.FloatTensor).to(device).unsqueeze(1)
@ -266,8 +282,14 @@ def valid_one_epoch(test_loader, model, device, args):
visi = [] visi = []
index = 0 index = 0
for i, (fname, img, gt_count) in enumerate(test_loader): for i, (fname, img, kpoint) in enumerate(test_loader):
if device is not None:
img = img.to(device) img = img.to(device)
kpoint = kpoint.to(device)
elif torch.cuda.is_available():
img = img.cuda()
kpoint = kpoint.cuda()
# XXX: what do this do # XXX: what do this do
if len(img.shape) == 5: if len(img.shape) == 5:
img = img.squeeze(0) img = img.squeeze(0)
@ -275,12 +297,12 @@ def valid_one_epoch(test_loader, model, device, args):
img = img.unsqueeze(0) img = img.unsqueeze(0)
with torch.no_grad(): with torch.no_grad():
out = model(img) out, gt_count = model(img, kpoint)
count = torch.sum(out).item() count = torch.sum(out).item()
gt_count = torch.sum(gt_count).item() gt_count = torch.sum(gt_count).item()
mae += abs(gt_count - count) mae += abs(kpoint - count)
mse += abs(gt_count - count) ** 2 mse += abs(kpoint - count) ** 2
if i % 15 == 0: if i % 15 == 0:
print("[valid_one_epoch] {} Gt {:.2f} Pred {}".format( print("[valid_one_epoch] {} Gt {:.2f} Pred {}".format(