From d46a027e3f9d7a745b59497d993e1f2b1d44c981 Mon Sep 17 00:00:00 2001
From: Zhengyi Chen
Date: Mon, 4 Mar 2024 04:10:17 +0000
Subject: [PATCH] Updated glue impl, complex loss fn

---
 model/glue.py | 41 +++++++++++++++--------------------------
 train.py      | 41 ++++++++++++++++++++++++++---------------
 2 files changed, 41 insertions(+), 41 deletions(-)

diff --git a/model/glue.py b/model/glue.py
index de01198..47fd915 100644
--- a/model/glue.py
+++ b/model/glue.py
@@ -24,36 +24,25 @@ class SquareCropTransformLayer(nn.Module):
     ) -> (torch.Tensor, torch.Tensor):
         # Here, x_ & kpoints_ already applied affine transform.
         assert len(x_.shape) == 4
-        channels, height, width = x_.shape[1:]
+        batch_size, channels, height, width = x_.shape
         h_split_count = height // self.crop_size
         w_split_count = width // self.crop_size
 
         # Perform identical splits -- note kpoints_ does not have C dimension!
-        ret_x = torch.cat(
-            torch.tensor_split(
-                torch.cat(
-                    torch.tensor_split(
-                        x_,
-                        h_split_count,
-                        dim=2
-                    )
-                ),
-                w_split_count,
-                dim=3
-            )
-        )  # Performance should be acceptable but looks dumb as hell, is there a better way?
-        split_t = torch.cat(
-            torch.tensor_split(
-                torch.cat(
-                    torch.tensor_split(
-                        kpoints_,
-                        h_split_count,
-                        dim=1
-                    )
-                ),
-                w_split_count,
-                dim=2
-            )
+        # NB: view -> permute -> reshape so each crop stays spatially contiguous
+        ret_x = x_.view(
+            batch_size, channels,
+            h_split_count, self.crop_size,
+            w_split_count, self.crop_size,
+        ).permute(0, 2, 4, 1, 3, 5).reshape(
+            -1, channels, self.crop_size, self.crop_size
+        )
+        split_t = kpoints_.view(
+            batch_size,
+            h_split_count, self.crop_size,
+            w_split_count, self.crop_size,
+        ).permute(0, 1, 3, 2, 4).reshape(
+            -1, self.crop_size, self.crop_size
         )
 
         # Sum into gt_count
diff --git a/train.py b/train.py
index 62c1000..db51133 100644
--- a/train.py
+++ b/train.py
@@ -239,22 +239,33 @@ def train_one_epoch(
     model.train()
     # In one epoch, for each training sample
-    for i, (fname, img, kpoint, gt_count) in enumerate(train_loader):
+    for i, (fname, img, kpoint, gt_count_whole) in enumerate(train_loader):
         kpoint = kpoint.type(torch.FloatTensor)
-        gt_count = gt_count.type(torch.FloatTensor)
+        gt_count_whole = gt_count_whole.type(torch.FloatTensor).unsqueeze(1)
+        batch_size = img.size(0)
 
         # fpass
         if device is not None:
             img = img.to(device)
             kpoint = kpoint.to(device)
-            gt_count = gt_count.to(device)
+            gt_count_whole = gt_count_whole.to(device)
         elif torch.cuda.is_available():
             img = img.cuda()
             kpoint = kpoint.cuda()
-            gt_count = gt_count.cuda()
-        out, _ = model(img, kpoint)
+            gt_count_whole = gt_count_whole.cuda()
+        out, gt_count = model(img, kpoint)
 
         # loss
-        loss = criterion(out, gt_count)
+        loss = criterion(out, gt_count)  # wrt. transformer
+        loss += (
+            criterion(  # stn: info retainment
+                gt_count.view(batch_size, -1).sum(dim=1, keepdim=True),
+                gt_count_whole)
+            + F.threshold(  # stn: perspective correction
+                gt_count.view(batch_size, -1).var(dim=1).mean(),
+                threshold=loss.item(),
+                value=loss.item()
+            )
+        )
 
         # free grad from mem
         optimizer.zero_grad()
@@ -285,17 +296,17 @@ def valid_one_epoch(test_loader, model, device, args):
     visi = []
     index = 0
 
-    for i, (fname, img, kpoint, gt_count) in enumerate(test_loader):
+    for i, (fname, img, kpoint, gt_count_whole) in enumerate(test_loader):
         kpoint = kpoint.type(torch.FloatTensor)
-        gt_count = gt_count.type(torch.FloatTensor)
+        gt_count_whole = gt_count_whole.type(torch.FloatTensor)
 
         if device is not None:
             img = img.to(device)
             kpoint = kpoint.to(device)
-            gt_count = gt_count.to(device)
+            gt_count_whole = gt_count_whole.to(device)
         elif torch.cuda.is_available():
             img = img.cuda()
             kpoint = kpoint.cuda()
-            gt_count = gt_count.cuda()
+            gt_count_whole = gt_count_whole.cuda()
 
         # XXX: does this even happen?
         if len(img.shape) == 5:
@@ -304,11 +315,11 @@ def valid_one_epoch(test_loader, model, device, args):
             img = img.unsqueeze(0)
         with torch.no_grad():
-            out, _ = model(img, kpoint)
+            out, gt_count = model(img, kpoint)
             pred_count = torch.squeeze(out, 1)
-            # gt_count = torch.squeeze(gt_count, 1)
+            gt_count = torch.squeeze(gt_count, 1)
 
-        diff = torch.abs(gt_count - torch.sum(pred_count)).item()
+        diff = torch.abs(gt_count_whole - torch.sum(pred_count)).item()
         mae += diff
         mse += diff ** 2
 
     mae = mae * 1.0 / (len(test_loader) * batch_size)
@@ -316,11 +327,11 @@ def valid_one_epoch(test_loader, model, device, args):
     if i % 5 == 0:
         print("[valid_one_epoch] {}\t| Gt {:.2f} Pred {:.4f}\t| mae {:.4f} mse {:.4f} |".format(
-            fname[0], torch.sum(gt_count).item(), torch.sum(pred_count).item(),
+            fname[0], torch.sum(gt_count_whole).item(), torch.sum(pred_count).item(),
             mae, mse
         ))
 
-    nni.report_intermediate_result()
+    nni.report_intermediate_result(mae)
     print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
         mae=mae, mse=mse
     ))
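
Note: a minimal standalone sketch of the crop split that SquareCropTransformLayer
performs, runnable outside the model. The names square_crop_split, x, and kp are
illustrative, not from this repo. It demonstrates why the split goes through a
6-D view plus permute: a direct .view() to (B*h*w, C, cs, cs) would interleave
rows from different crops.

    import torch

    def square_crop_split(x, kpoints, crop_size):
        # x: (B, C, H, W); kpoints: (B, H, W); crop_size must divide H and W.
        b, c, h, w = x.shape
        hs, ws = h // crop_size, w // crop_size
        # The 6-D view splits each spatial axis into (crop index, offset);
        # permute moves the crop indices to the front, and reshape folds them
        # into the batch dimension so every slice is one contiguous crop.
        crops = (
            x.view(b, c, hs, crop_size, ws, crop_size)
             .permute(0, 2, 4, 1, 3, 5)
             .reshape(b * hs * ws, c, crop_size, crop_size)
        )
        counts = (
            kpoints.view(b, hs, crop_size, ws, crop_size)
                   .permute(0, 1, 3, 2, 4)
                   .reshape(b * hs * ws, crop_size, crop_size)
                   .sum(dim=(1, 2))  # per-crop gt_count
        )
        return crops, counts

    # Sanity check: crop 0 of image 0 is the top-left corner, and the
    # per-crop counts partition the whole-image count.
    x = torch.arange(2 * 3 * 4 * 4, dtype=torch.float32).view(2, 3, 4, 4)
    kp = torch.rand(2, 4, 4)
    crops, counts = square_crop_split(x, kp, 2)
    assert torch.equal(crops[0], x[0, :, :2, :2])
    assert torch.allclose(counts.sum(), kp.sum())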
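
Note: the composite loss added to train_one_epoch, isolated as a runnable
sketch. Assumptions are labeled: criterion is taken to be L1Loss and the
shapes are illustrative; in the real code, gt_count is the per-crop count
tensor returned by the model's glue layer.

    import torch
    import torch.nn.functional as F

    criterion = torch.nn.L1Loss()  # assumption; the real criterion comes from config

    B, K = 2, 4                                      # B images, K crops each (illustrative)
    out = torch.rand(B * K, 1, requires_grad=True)   # predicted per-crop counts
    gt_count = torch.rand(B * K, 1)                  # per-crop gt counts (glue layer output)
    gt_count_whole = torch.rand(B, 1)                # whole-image gt counts

    loss = criterion(out, gt_count)  # main term, wrt. transformer output

    # stn: info retainment -- the per-crop counts must still sum to the
    # whole-image count after the affine transform.
    info = criterion(gt_count.view(B, -1).sum(dim=1, keepdim=True), gt_count_whole)

    # stn: perspective correction -- mean per-image variance of crop counts.
    # F.threshold(v, t, t) returns v where v > t and t elsewhere, so this
    # term is floored at the current main-loss value rather than vanishing.
    persp = F.threshold(gt_count.view(B, -1).var(dim=1).mean(),
                        threshold=loss.item(), value=loss.item())

    loss = loss + info + persp
    loss.backward()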