TEST: use autocast for mixed-precision training

Zhengyi Chen 2024-03-04 18:36:40 +00:00
parent 94867bd8bf
commit 2d31162c58

@@ -243,32 +243,36 @@ def train_one_epoch(
         kpoint = kpoint.type(torch.FloatTensor)
         gt_count_whole = gt_count_whole.type(torch.FloatTensor).unsqueeze(1)
         batch_size = img.size(0)
-        # fpass
+        # send to device
         if device is not None:
             img = img.to(device)
             kpoint = kpoint.to(device)
             gt_count_whole = gt_count_whole.to(device)
+            device_type = device.type
         elif torch.cuda.is_available():
             img = img.cuda()
             kpoint = kpoint.cuda()
             gt_count_whole = gt_count_whole.cuda()
-        out, gt_count = model(img, kpoint)
-        # loss
-        loss = criterion(out, gt_count)  # wrt. transformer
-        loss += (
-            F.mse_loss(  # stn: info retainment
-                gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
-                gt_count_whole)
-            + F.threshold(  # stn: perspective correction
-                gt_count.view(batch_size, -1).var(dim=1).mean(),
-                threshold=loss.item(),
-                value=loss.item()
-            )
-        )
+            device_type = "cuda"
+        with torch.autocast(device_type):
+            # fpass
+            out, gt_count = model(img, kpoint)
+            # loss
+            loss = criterion(out, gt_count)  # wrt. transformer
+            loss += (
+                F.mse_loss(  # stn: info retainment
+                    gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
+                    gt_count_whole)
+                + F.threshold(  # stn: perspective correction
+                    gt_count.view(batch_size, -1).var(dim=1).mean(),
+                    threshold=loss.item(),
+                    value=loss.item()
+                )
+            )
         # free grad from mem
-        optimizer.zero_grad()
+        optimizer.zero_grad(set_to_none=True)
         # bpass
         loss.backward()
@@ -276,10 +280,6 @@ def train_one_epoch(
         # optimizer
         optimizer.step()
-        # periodic message
-        # if i % args.print_freq == 0:
-        #     print("Epoch {}: {}/{}".format(epoch, i, len(train_loader)))
         if args.debug:
             break
@@ -324,9 +324,10 @@ def valid_one_epoch(test_loader, model, device, args):
         mse += diff ** 2
         if i % 5 == 0:
-            print("[valid_one_epoch] {} | Gt {:.2f} Pred {:.4f} | mae {:.4f} mse {:.4f} |".format(
-                fname[0], torch.sum(gt_count_whole).item(), torch.sum(pred_count).item(),
-                mae, mse
+            print("[valid_one_epoch] {} | Gt {:.2f} Pred {:.4f} |".format(
+                fname[0],
+                torch.sum(gt_count_whole).item(),
+                torch.sum(pred_count).item()
             ))
     mae = mae * 1.0 / (len(test_loader) * batch_size)
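The accumulators above are sums over individual images, so mae is finalized by dividing by the image count; the matching mse finalization is not shown in this hunk, but by convention the squared-error sum is reported as an RMSE on the same scale. A sketch of that finalization, assuming mae/mse are the loop accumulators:

n_images = len(test_loader) * batch_size   # total images scored
mae = mae / n_images                       # mean absolute error
rmse = (mse / n_images) ** 0.5             # RMSE, on the same scale as mae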