From 2d31162c58ba9aa73db0991fe7d8b0086d738ca9 Mon Sep 17 00:00:00 2001
From: Zhengyi Chen
Date: Mon, 4 Mar 2024 18:36:40 +0000
Subject: [PATCH] TEST: use autocast for mixed-precision training

---
 train.py | 43 ++++++++++++++++++++++---------------------
 1 file changed, 22 insertions(+), 21 deletions(-)

diff --git a/train.py b/train.py
index 8732497..320ec6f 100644
--- a/train.py
+++ b/train.py
@@ -243,32 +243,36 @@ def train_one_epoch(
         kpoint = kpoint.type(torch.FloatTensor)
         gt_count_whole = gt_count_whole.type(torch.FloatTensor).unsqueeze(1)
         batch_size = img.size(0)
-        # fpass
+        # send to device
         if device is not None:
             img = img.to(device)
             kpoint = kpoint.to(device)
             gt_count_whole = gt_count_whole.to(device)
+            device_type = device.type
         elif torch.cuda.is_available():
             img = img.cuda()
             kpoint = kpoint.cuda()
             gt_count_whole = gt_count_whole.cuda()
-        out, gt_count = model(img, kpoint)
+            device_type = "cuda"
 
-        # loss
-        loss = criterion(out, gt_count)  # wrt. transformer
-        loss += (
-            F.mse_loss(  # stn: info retainment
-                gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
-                gt_count_whole)
-            + F.threshold(  # stn: perspective correction
-                gt_count.view(batch_size, -1).var(dim=1).mean(),
-                threshold=loss.item(),
-                value=loss.item()
+        with torch.autocast(device_type):
+            # fpass
+            out, gt_count = model(img, kpoint)
+            # loss
+            loss = criterion(out, gt_count)  # wrt. transformer
+            loss += (
+                F.mse_loss(  # stn: info retainment
+                    gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
+                    gt_count_whole)
+                + F.threshold(  # stn: perspective correction
+                    gt_count.view(batch_size, -1).var(dim=1).mean(),
+                    threshold=loss.item(),
+                    value=loss.item()
+                )
             )
-        )
 
         # free grad from mem
-        optimizer.zero_grad()
+        optimizer.zero_grad(set_to_none=True)
 
         # bpass
         loss.backward()
@@ -276,10 +280,6 @@ def train_one_epoch(
         # optimizer
         optimizer.step()
 
-        # periodic message
-        # if i % args.print_freq == 0:
-        #     print("Epoch {}: {}/{}".format(epoch, i, len(train_loader)))
-
         if args.debug:
             break
 
@@ -324,9 +324,10 @@ def valid_one_epoch(test_loader, model, device, args):
         mse += diff ** 2
 
         if i % 5 == 0:
-            print("[valid_one_epoch] {} | Gt {:.2f} Pred {:.4f} | mae {:.4f} mse {:.4f} |".format(
-                fname[0], torch.sum(gt_count_whole).item(), torch.sum(pred_count).item(),
-                mae, mse
+            print("[valid_one_epoch] {} | Gt {:.2f} Pred {:.4f} |".format(
+                fname[0],
+                torch.sum(gt_count_whole).item(),
+                torch.sum(pred_count).item()
             ))
 
     mae = mae * 1.0 / (len(test_loader) * batch_size)
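
Note on the pattern this patch moves toward: torch.autocast only runs the wrapped forward pass and loss computation in reduced precision; it does not scale gradients. On CUDA with float16, the usual PyTorch recipe pairs autocast with torch.cuda.amp.GradScaler. Below is a minimal sketch of that standard pattern, not part of this patch; model, criterion, optimizer, and loader are placeholders, not the exact objects in train.py.

# Sketch only: standard PyTorch AMP pattern (autocast + GradScaler).
# Assumes a CUDA device; model, criterion, optimizer, loader are placeholders.
import torch

scaler = torch.cuda.amp.GradScaler()

for img, target in loader:
    img, target = img.cuda(), target.cuda()
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        out = model(img)               # forward pass runs in mixed precision
        loss = criterion(out, target)
    scaler.scale(loss).backward()      # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)             # unscale gradients, then call optimizer.step()
    scaler.update()                    # adjust the scale factor for the next iteration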