TEST: use autocast for mixed-precision training
parent 94867bd8bf
commit 2d31162c58

1 changed file with 22 additions and 21 deletions

train.py (+22 −21)
@@ -243,32 +243,36 @@ def train_one_epoch(
         kpoint = kpoint.type(torch.FloatTensor)
         gt_count_whole = gt_count_whole.type(torch.FloatTensor).unsqueeze(1)
         batch_size = img.size(0)
-        # fpass
+        # send to device
         if device is not None:
             img = img.to(device)
             kpoint = kpoint.to(device)
             gt_count_whole = gt_count_whole.to(device)
+            device_type = device.type
         elif torch.cuda.is_available():
             img = img.cuda()
             kpoint = kpoint.cuda()
             gt_count_whole = gt_count_whole.cuda()
-        out, gt_count = model(img, kpoint)
+            device_type = "cuda"
 
-        # loss
-        loss = criterion(out, gt_count)  # wrt. transformer
-        loss += (
-            F.mse_loss(  # stn: info retainment
-                gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
-                gt_count_whole)
-            + F.threshold(  # stn: perspective correction
-                gt_count.view(batch_size, -1).var(dim=1).mean(),
-                threshold=loss.item(),
-                value=loss.item()
-            )
-        )
+        with torch.autocast(device_type):
+            # fpass
+            out, gt_count = model(img, kpoint)
+            # loss
+            loss = criterion(out, gt_count)  # wrt. transformer
+            loss += (
+                F.mse_loss(  # stn: info retainment
+                    gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
+                    gt_count_whole)
+                + F.threshold(  # stn: perspective correction
+                    gt_count.view(batch_size, -1).var(dim=1).mean(),
+                    threshold=loss.item(),
+                    value=loss.item()
+                )
+            )
 
         # free grad from mem
-        optimizer.zero_grad()
+        optimizer.zero_grad(set_to_none=True)
 
         # bpass
         loss.backward()
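Note on the perspective-correction term above: F.threshold(input, threshold, value) returns input wherever input > threshold and value elsewhere. Since both threshold and value are the detached scalar loss.item(), the variance penalty contributes a gradient only when the per-image count variance exceeds the current loss; otherwise it adds a constant with no gradient. A minimal sketch of the semantics (the numbers are made up):

    import torch
    import torch.nn.functional as F

    current = 0.5               # stand-in for loss.item()
    var_hi = torch.tensor(0.7)  # variance above the current loss
    var_lo = torch.tensor(0.3)  # variance below the current loss

    print(F.threshold(var_hi, current, current))  # tensor(0.7000): passes through
    print(F.threshold(var_lo, current, current))  # tensor(0.5000): replaced by the floor value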
@@ -276,10 +280,6 @@ def train_one_epoch(
         # optimizer
         optimizer.step()
 
-        # periodic message
-        # if i % args.print_freq == 0:
-        #     print("Epoch {}: {}/{}".format(epoch, i, len(train_loader)))
-
         if args.debug:
             break
 
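Note: torch.autocast(device_type) puts only the forward pass and loss computation in reduced precision (float16 on CUDA, bfloat16 on CPU by default); loss.backward() correctly stays outside the block. With float16, small gradients can underflow, so autocast is commonly paired with loss scaling. A minimal sketch under that assumption — the model, criterion, and optimizer here are hypothetical stand-ins, not the repo's:

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 1).cuda()         # hypothetical stand-in for the real model
    criterion = nn.MSELoss()               # hypothetical stand-in for the real criterion
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    scaler = torch.amp.GradScaler("cuda")  # torch.cuda.amp.GradScaler() on older PyTorch

    x = torch.randn(4, 8, device="cuda")
    target = torch.randn(4, 1, device="cuda")

    optimizer.zero_grad(set_to_none=True)  # as in the diff: frees grads rather than zero-filling
    with torch.autocast("cuda"):           # forward + loss in mixed precision
        loss = criterion(model(x), target)
    scaler.scale(loss).backward()          # backward on the scaled loss, outside autocast
    scaler.step(optimizer)                 # unscales grads; skips the step on inf/NaN
    scaler.update()

With a scaler in place, the bare loss.backward() / optimizer.step() pair in the diff would become the scaler calls shown here.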
@@ -324,9 +324,10 @@ def valid_one_epoch(test_loader, model, device, args):
         mse += diff ** 2
 
         if i % 5 == 0:
-            print("[valid_one_epoch] {} | Gt {:.2f} Pred {:.4f} | mae {:.4f} mse {:.4f} |".format(
-                fname[0], torch.sum(gt_count_whole).item(), torch.sum(pred_count).item(),
-                mae, mse
+            print("[valid_one_epoch] {} | Gt {:.2f} Pred {:.4f} |".format(
+                fname[0],
+                torch.sum(gt_count_whole).item(),
+                torch.sum(pred_count).item()
             ))
 
     mae = mae * 1.0 / (len(test_loader) * batch_size)
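Note on the tail of valid_one_epoch: the hunk normalizes mae by len(test_loader) * batch_size; crowd-counting evaluations conventionally report MSE as the root of the mean squared error, so the mse accumulator would typically get the same normalization plus a square root — an assumption here, since that line is cut off by the hunk. A toy illustration with made-up counts:

    import math

    preds = [102.0, 95.5, 88.0]  # made-up predicted counts
    gts = [100.0, 97.0, 90.0]    # made-up ground-truth counts

    n = len(preds)
    mae = sum(abs(p - g) for p, g in zip(preds, gts)) / n
    mse = math.sqrt(sum((p - g) ** 2 for p, g in zip(preds, gts)) / n)
    print(f"mae {mae:.4f} mse {mse:.4f}")  # mae 1.8333 mse 1.8484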