Yay, works in DP via CPU
commit fc941ebaf7 (parent ab15419d2f)
6 changed files with 18 additions and 9 deletions
train.py (14 changes)

@@ -22,7 +22,6 @@ from checkpoint import save_checkpoint
logger = logging.getLogger("train")


def setup_process_group(
    rank: int,
    world_size: int,
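
The hunk above only shows the opening lines of setup_process_group; its body is not part of this diff. For orientation, a CPU-friendly process-group setup in PyTorch usually looks roughly like the sketch below. The gloo backend, the address/port values, and the cleanup helper are assumptions for illustration, not code from this repository.

# Hypothetical sketch of a CPU-capable process-group setup; the real body of
# setup_process_group is not shown in this commit.
import os

import torch.distributed as dist


def setup_process_group(rank: int, world_size: int) -> None:
    # gloo runs on CPU; nccl would require GPUs. Address and port are placeholders.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)


def cleanup_process_group() -> None:
    # Hypothetical counterpart for tearing the group down after training.
    dist.destroy_process_group()

Under these assumptions, setup_process_group(rank=0, world_size=1) initialises a single-process group that still works on a CPU-only machine.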

@@ -242,7 +241,6 @@ def train_one_epoch(
    # In one epoch, for each training sample
    for i, (fname, img, kpoint) in enumerate(train_loader):
        kpoint = kpoint.type(torch.FloatTensor)
        print("Training: img {} | kpoint {}".format(img.shape, kpoint.shape))
        # fpass
        if device is not None:
            img = img.to(device)

@@ -251,7 +249,6 @@ def train_one_epoch(
            img = img.cuda()
            kpoint = kpoint.cuda()
        out, gt_count = model(img, kpoint)
        # gt_count = gt_count.type(torch.FloatTensor).to(device).unsqueeze(1)

        # loss
        loss = criterion(out, gt_count)
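
The two hunks above branch between .to(device) and .cuda() when moving the batch, and the commit message refers to DataParallel (DP) running via CPU. A minimal, self-contained sketch of that device-selection pattern follows; the toy Sequential model and the random input are placeholders, not this repository's network or data.

# Sketch of CPU-first device selection with optional DataParallel.
# The tiny model and random batch are stand-ins, not this repository's code.
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 1, 1))
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)  # split the batch across visible GPUs
model = model.to(device)            # on a CPU-only machine this is the CPU path

img = torch.randn(2, 3, 64, 64).to(device)  # stand-in for the loader's img tensor
out = model(img)                            # the same forward call works on either device
print(out.shape)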

@@ -269,6 +266,9 @@ def train_one_epoch(
        if i % args.print_freq == 0:
            print("Epoch {}: {}/{}".format(epoch, i, len(train_loader)))

        if args.debug:
            break

    scheduler.step()

@@ -283,6 +283,7 @@ def valid_one_epoch(test_loader, model, device, args):
    index = 0

    for i, (fname, img, kpoint) in enumerate(test_loader):
        kpoint = kpoint.type(torch.FloatTensor)
        if device is not None:
            img = img.to(device)
            kpoint = kpoint.to(device)

@@ -301,8 +302,8 @@ def valid_one_epoch(test_loader, model, device, args):
        count = torch.sum(out).item()

        gt_count = torch.sum(gt_count).item()
-       mae += abs(kpoint - count)
-       mse += abs(kpoint - count) ** 2
+       mae += abs(gt_count - count)
+       mse += abs(gt_count - count) ** 2

        if i % 15 == 0:
            print("[valid_one_epoch] {} Gt {:.2f} Pred {}".format(
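
The hunk above switches the error terms from the raw kpoint tensor to the summed gt_count, so both sides of the subtraction are per-image scalar counts. A small sketch of the resulting metric loop follows; the random stand-in tensors and the final MAE/RMSE normalisation are assumptions added for illustration, not part of this diff.

# Sketch of per-image count-error accumulation as implied by the corrected lines above.
# predictions / ground_truths are placeholder tensors, not this repository's data.
import math

import torch

predictions = [torch.rand(1, 1, 32, 32) for _ in range(4)]                 # stand-in model outputs
ground_truths = [torch.randint(0, 5, (32, 32)).float() for _ in range(4)]  # stand-in kpoint maps

mae, mse = 0.0, 0.0
for out, kpoint in zip(predictions, ground_truths):
    count = torch.sum(out).item()        # predicted count for this image
    gt_count = torch.sum(kpoint).item()  # ground-truth count, as in the corrected lines
    mae += abs(gt_count - count)         # scalar minus scalar, unlike the old kpoint - count
    mse += abs(gt_count - count) ** 2

n = len(predictions)
print("MAE {:.2f}  RMSE {:.2f}".format(mae / n, math.sqrt(mse / n)))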

@@ -324,6 +325,9 @@ if __name__ == "__main__":
    # tuner_params = nni.get_next_parameter()
    # logger.debug("Generated hyperparameters: {}", tuner_params)
    # combined_params = nni.utils.merge_parameter(ret_args, tuner_params)
    if args.debug:
        os.nice(15)

    combined_params = args
    logger.debug("Parameters: {}", combined_params)
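
The __main__ hunk gates os.nice(15) behind args.debug, which lowers the process priority while debugging. A hedged sketch of how such a flag might be wired up follows; the argparse setup and its defaults are assumptions, since this diff does not show how args is constructed.

# Hypothetical sketch of the --debug / --print_freq wiring implied by the diff;
# the real argument parser is not part of this commit.
import argparse
import os

parser = argparse.ArgumentParser(description="train")
parser.add_argument("--debug", action="store_true",
                    help="lower process priority and break out of loops early")
parser.add_argument("--print_freq", type=int, default=50)
args = parser.parse_args([])   # empty list so the sketch runs without CLI input

if args.debug:
    os.nice(15)                # be nice to other processes while debugging (Unix only)

combined_params = args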