Yay, works in DP (DataParallel) via CPU

This commit is contained in:
Zhengyi Chen 2024-03-03 21:59:58 +00:00
parent ab15419d2f
commit fc941ebaf7
6 changed files with 18 additions and 9 deletions

2
.gitignore vendored
View file

@ -3,3 +3,5 @@ synchronous/
npydata/ npydata/
**/__pycache__/** **/__pycache__/**
slurm-* slurm-*
save/
.vscode/

View file

@ -22,7 +22,8 @@ export TMP=/disk/scratch/${STUDENT_ID}/
source /home/${STUDENT_ID}/miniconda3/bin/activate mlp-cuda source /home/${STUDENT_ID}/miniconda3/bin/activate mlp-cuda
python train.py \ python train.py \
--debug=True \
--model='stn' \ --model='stn' \
--save_path ./save_file/ShanghaiA \ --save_path ./save_file/ShanghaiA \
--batch_size 4 \ --batch_size 4 \
--gpus 0,1,2,3,4,5 \ # --gpus 0,1,2,3,4,5 \

View file

@ -76,7 +76,7 @@ parser.add_argument(
"--ddp_world_size", type=int, default=1, "--ddp_world_size", type=int, default=1,
help="DDP: Number of processes in Pytorch process group" help="DDP: Number of processes in Pytorch process group"
) )
parse.add_argument( parser.add_argument(
"--debug", type=bool, default=False "--debug", type=bool, default=False
) )

View file

@ -57,7 +57,10 @@ class SquareCropTransformLayer(nn.Module):
) )
# Sum into gt_count # Sum into gt_count
ret_gt_count = torch.sum(split_t.view(split_t.size(0), -1), dim=1) ret_gt_count = (torch
.sum(split_t.view(split_t.size(0), -1), dim=1)
.unsqueeze(1)
)
return ret_x, ret_gt_count return ret_x, ret_gt_count

View file

@ -85,7 +85,6 @@ class VisionTransformerGAP(VisionTransformer):
def forward(self, x, t): def forward(self, x, t):
with torch.no_grad(): with torch.no_grad():
x, t = self.glue(x, t) x, t = self.glue(x, t)
print(f"Glue: {x.shape} | {t.shape}")
x = self.forward_features(x) # Compute encoding x = self.forward_features(x) # Compute encoding
x = F.adaptive_avg_pool1d(x, (48)) x = F.adaptive_avg_pool1d(x, (48))
x = x.view(x.shape[0], -1) # Move data for regression head x = x.view(x.shape[0], -1) # Move data for regression head

View file

@ -22,7 +22,6 @@ from checkpoint import save_checkpoint
logger = logging.getLogger("train") logger = logging.getLogger("train")
def setup_process_group( def setup_process_group(
rank: int, rank: int,
world_size: int, world_size: int,
@ -242,7 +241,6 @@ def train_one_epoch(
# In one epoch, for each training sample # In one epoch, for each training sample
for i, (fname, img, kpoint) in enumerate(train_loader): for i, (fname, img, kpoint) in enumerate(train_loader):
kpoint = kpoint.type(torch.FloatTensor) kpoint = kpoint.type(torch.FloatTensor)
print("Training: img {} | kpoint {}".format(img.shape, kpoint.shape))
# fpass # fpass
if device is not None: if device is not None:
img = img.to(device) img = img.to(device)
@ -251,7 +249,6 @@ def train_one_epoch(
img = img.cuda() img = img.cuda()
kpoint = kpoint.cuda() kpoint = kpoint.cuda()
out, gt_count = model(img, kpoint) out, gt_count = model(img, kpoint)
# gt_count = gt_count.type(torch.FloatTensor).to(device).unsqueeze(1)
# loss # loss
loss = criterion(out, gt_count) loss = criterion(out, gt_count)
@ -269,6 +266,9 @@ def train_one_epoch(
if i % args.print_freq == 0: if i % args.print_freq == 0:
print("Epoch {}: {}/{}".format(epoch, i, len(train_loader))) print("Epoch {}: {}/{}".format(epoch, i, len(train_loader)))
if args.debug:
break
scheduler.step() scheduler.step()
@ -283,6 +283,7 @@ def valid_one_epoch(test_loader, model, device, args):
index = 0 index = 0
for i, (fname, img, kpoint) in enumerate(test_loader): for i, (fname, img, kpoint) in enumerate(test_loader):
kpoint = kpoint.type(torch.FloatTensor)
if device is not None: if device is not None:
img = img.to(device) img = img.to(device)
kpoint = kpoint.to(device) kpoint = kpoint.to(device)
@ -301,8 +302,8 @@ def valid_one_epoch(test_loader, model, device, args):
count = torch.sum(out).item() count = torch.sum(out).item()
gt_count = torch.sum(gt_count).item() gt_count = torch.sum(gt_count).item()
mae += abs(kpoint - count) mae += abs(gt_count - count)
mse += abs(kpoint - count) ** 2 mse += abs(gt_count - count) ** 2
if i % 15 == 0: if i % 15 == 0:
print("[valid_one_epoch] {} Gt {:.2f} Pred {}".format( print("[valid_one_epoch] {} Gt {:.2f} Pred {}".format(
@ -324,6 +325,9 @@ if __name__ == "__main__":
# tuner_params = nni.get_next_parameter() # tuner_params = nni.get_next_parameter()
# logger.debug("Generated hyperparameters: {}", tuner_params) # logger.debug("Generated hyperparameters: {}", tuner_params)
# combined_params = nni.utils.merge_parameter(ret_args, tuner_params) # combined_params = nni.utils.merge_parameter(ret_args, tuner_params)
if args.debug:
os.nice(15)
combined_params = args combined_params = args
logger.debug("Parameters: {}", combined_params) logger.debug("Parameters: {}", combined_params)