From ab15419d2f3b1404fea58e9df090077b832e3e18 Mon Sep 17 00:00:00 2001
From: Zhengyi Chen
Date: Sun, 3 Mar 2024 19:40:22 +0000
Subject: [PATCH] Debug

---
 .gitignore          |  3 ++-
 _ShanghaiA-train.sh | 20 ++++++++++++++++++-
 arguments.py        |  5 ++++-
 model/stn.py        |  6 +++---
 train.py            | 44 +++++++++++++++++++++++++++++-----------
 5 files changed, 61 insertions(+), 17 deletions(-)

diff --git a/.gitignore b/.gitignore
index ac4256a..25a3c78 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
 baseline-experiments/
 synchronous/
 npydata/
-**/__pycache__/**
\ No newline at end of file
+**/__pycache__/**
+slurm-*
diff --git a/_ShanghaiA-train.sh b/_ShanghaiA-train.sh
index 0e8a34f..eb29e4e 100644
--- a/_ShanghaiA-train.sh
+++ b/_ShanghaiA-train.sh
@@ -6,5 +6,23 @@
 #SBATCH --mem=24000
 #SBATCH --time=3-00:00:00
 
+export CUDA_HOME=/opt/cuda-9.0.176.1/
+export CUDNN_HOME=/opt/cuDNN-7.0/
+export STUDENT_ID=$(whoami)
+
+export LD_LIBRARY_PATH=${CUDNN_HOME}/lib64:${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
+export LIBRARY_PATH=${CUDNN_HOME}/lib64:${LIBRARY_PATH}
+export CPATH=${CUDNN_HOME}/include:$CPATH
+export PATH=${CUDA_HOME}/bin:${PATH}
+export PYTHON_PATH=$PATH
+
+mkdir -p /disk/scratch/${STUDENT_ID}
+export TMPDIR=/disk/scratch/${STUDENT_ID}/
+export TMP=/disk/scratch/${STUDENT_ID}/
+
+source /home/${STUDENT_ID}/miniconda3/bin/activate mlp-cuda
 python train.py \
-    --model='stn'
\ No newline at end of file
+    --model='stn' \
+    --save_path ./save_file/ShanghaiA \
+    --batch_size 4 \
+    --gpus 0,1,2,3,4,5
\ No newline at end of file
diff --git a/arguments.py b/arguments.py
index 72cf252..ac1bebf 100644
--- a/arguments.py
+++ b/arguments.py
@@ -63,7 +63,7 @@
 parser.add_argument(
     "--epochs", type=int, default=250, help="Number of epochs to train"
 )
 parser.add_argument(
-    "--gpus", type=List[int], default=[0],
+    "--gpus", type=str, default='0',
     help="GPU IDs to be made available for training runtime"
 )
@@ -76,6 +76,9 @@ parser.add_argument(
     "--ddp_world_size", type=int, default=1,
     help="DDP: Number of processes in Pytorch process group"
 )
+parser.add_argument(
+    "--debug", action="store_true", help="Run validation every epoch"
+)
 
 # nni configuration ==========================================================
 parser.add_argument(
diff --git a/model/stn.py b/model/stn.py
index 1b6099f..2d18abe 100644
--- a/model/stn.py
+++ b/model/stn.py
@@ -69,13 +69,13 @@ class STNet(nn.Module):
         theta = self.fc_loc(xs)
         theta = theta.view(-1, 2, 3) # -> (2, 3)
 
-        grid = F.affine_grid(theta, x.size())
-        x = F.grid_sample(x, grid)
+        grid = F.affine_grid(theta, x.size(), align_corners=False)
+        x = F.grid_sample(x, grid, align_corners=False)
 
         # Do the same transformation to t sans training
         with torch.no_grad():
             t = t.view(t.size(0), 1, t.size(1), t.size(2))
-            t = F.grid_sample(t, grid)
+            t = F.grid_sample(t, grid, align_corners=False)
 
         t = t.squeeze(1)
         return x, t
diff --git a/train.py b/train.py
index e3216da..ca337c0 100644
--- a/train.py
+++ b/train.py
@@ -106,13 +106,16 @@ def worker(rank: int, args: Namespace):
     if args.use_ddp and torch.cuda.is_available():
         device = torch.device(rank)
     elif torch.cuda.is_available():
-        device = torch.device(args.gpus)
+        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
+        args.gpus = [int(gpu_id) for gpu_id in args.gpus.split(",")]
+        device = None
     elif torch.backends.mps.is_available():
         device = torch.device("mps")
     else:
         print("[!!!] Using CPU for inference. This will be slow...")
This will be slow...") device = torch.device("cpu") - torch.set_default_device(device) + if device is not None: + torch.set_default_device(device) # Prepare training data train_list, test_list = unpack_npy_data(args) @@ -123,9 +126,9 @@ def worker(rank: int, args: Namespace): # Instantiate model if args.model == "stn": - model = stn_patch16_384_gap(args.pth_tar).to(device) + model = stn_patch16_384_gap(args.pth_tar) else: - model = base_patch16_384_gap(args.pth_tar).to(device) + model = base_patch16_384_gap(args.pth_tar) if args.use_ddp: model = nn.parallel.DistributedDataParallel( @@ -140,8 +143,17 @@ def worker(rank: int, args: Namespace): device_ids=args.gpus ) + if device is not None: + model = model.to(device) + elif torch.cuda.is_available(): + model = model.cuda() + # criterion, optimizer, scheduler - criterion = nn.L1Loss(size_average=False).to(device) + criterion = nn.L1Loss(size_average=False) + if device is not None: + criterion = criterion.to(device) + elif torch.cuda.is_available(): + criterion = criterion.cuda() optimizer = torch.optim.Adam( [{"params": model.parameters(), "lr": args.lr}], lr=args.lr, @@ -184,7 +196,7 @@ def worker(rank: int, args: Namespace): end_train = time.time() # Validate - if epoch % 5 == 0: + if epoch % 5 == 0 or args.debug: prec1 = valid_one_epoch(test_loader, model, device, args) end_valid = time.time() is_best = prec1 < args.best_pred @@ -232,8 +244,12 @@ def train_one_epoch( kpoint = kpoint.type(torch.FloatTensor) print("Training: img {} | kpoint {}".format(img.shape, kpoint.shape)) # fpass - img = img.to(device) - kpoint = kpoint.to(device) + if device is not None: + img = img.to(device) + kpoint = kpoint.to(device) + elif torch.cuda.is_available(): + img = img.cuda() + kpoint = kpoint.cuda() out, gt_count = model(img, kpoint) # gt_count = gt_count.type(torch.FloatTensor).to(device).unsqueeze(1) @@ -266,8 +282,14 @@ def valid_one_epoch(test_loader, model, device, args): visi = [] index = 0 - for i, (fname, img, gt_count) in enumerate(test_loader): - img = img.to(device) + for i, (fname, img, kpoint) in enumerate(test_loader): + if device is not None: + img = img.to(device) + kpoint = kpoint.to(device) + elif torch.cuda.is_available(): + img = img.cuda() + kpoint = kpoint.cuda() + # XXX: what do this do if len(img.shape) == 5: img = img.squeeze(0) @@ -275,12 +297,12 @@ def valid_one_epoch(test_loader, model, device, args): img = img.unsqueeze(0) with torch.no_grad(): - out = model(img) + out, gt_count = model(img, kpoint) count = torch.sum(out).item() gt_count = torch.sum(gt_count).item() - mae += abs(gt_count - count) - mse += abs(gt_count - count) ** 2 + mae += abs(kpoint - count) + mse += abs(kpoint - count) ** 2 if i % 15 == 0: print("[valid_one_epoch] {} Gt {:.2f} Pred {}".format(