Updated glue impl, complex loss fn

Zhengyi Chen 2024-03-04 04:10:17 +00:00
parent 83fcc43f0b
commit d46a027e3f
2 changed files with 37 additions and 41 deletions


@@ -24,36 +24,21 @@ class SquareCropTransformLayer(nn.Module):
     ) -> (torch.Tensor, torch.Tensor):
         # Here, x_ & kpoints_ already applied affine transform.
         assert len(x_.shape) == 4
-        channels, height, width = x_.shape[1:]
+        batch_size, channels, height, width = x_.shape
         h_split_count = height // self.crop_size
         w_split_count = width // self.crop_size
         # Perform identical splits -- note kpoints_ does not have C dimension!
-        ret_x = torch.cat(
-            torch.tensor_split(
-                torch.cat(
-                    torch.tensor_split(
-                        x_,
-                        h_split_count,
-                        dim=2
-                    )
-                ),
-                w_split_count,
-                dim=3
-            )
-        )  # Performance should be acceptable but looks dumb as hell, is there a better way?
-        split_t = torch.cat(
-            torch.tensor_split(
-                torch.cat(
-                    torch.tensor_split(
-                        kpoints_,
-                        h_split_count,
-                        dim=1
-                    )
-                ),
-                w_split_count,
-                dim=2
-            )
-        )
+        ret_x = x_.view(
+            batch_size * h_split_count * w_split_count,
+            channels,
+            self.crop_size,
+            self.crop_size,
+        )
+        split_t = kpoints_.view(
+            batch_size * h_split_count * w_split_count,
+            self.crop_size,
+            self.crop_size,
+        )
         # Sum into gt_count


@@ -239,22 +239,33 @@ def train_one_epoch(
     model.train()
     # In one epoch, for each training sample
-    for i, (fname, img, kpoint, gt_count) in enumerate(train_loader):
+    for i, (fname, img, kpoint, gt_count_whole) in enumerate(train_loader):
         kpoint = kpoint.type(torch.FloatTensor)
-        gt_count = gt_count.type(torch.FloatTensor)
+        gt_count_whole = gt_count_whole.type(torch.FloatTensor).unsqueeze(1)
+        batch_size = img.size(0)
         # fpass
         if device is not None:
             img = img.to(device)
             kpoint = kpoint.to(device)
-            gt_count = gt_count.to(device)
+            gt_count_whole = gt_count_whole.to(device)
         elif torch.cuda.is_available():
             img = img.cuda()
             kpoint = kpoint.cuda()
-            gt_count = gt_count.cuda()
+            gt_count_whole = gt_count_whole.cuda()
-        out, _ = model(img, kpoint)
+        out, gt_count = model(img, kpoint)
         # loss
-        loss = criterion(out, gt_count)
+        loss = criterion(out, gt_count)  # wrt. transformer
+        loss += (
+            criterion(  # stn: info retainment
+                gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
+                gt_count_whole)
+            + F.threshold(  # stn: perspective correction
+                gt_count.view(batch_size, -1).var(dim=1).mean(),
+                threshold=loss.item(),
+                value=loss.item()
+            )
+        )
         # free grad from mem
         optimizer.zero_grad()
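The three terms of the new loss can be exercised in isolation. A minimal sketch, assuming an MSE criterion and made-up shapes; crops_per_image, the random tensors, and their magnitudes are placeholders rather than values from this repo:

import torch
import torch.nn.functional as F

criterion = torch.nn.MSELoss()
batch_size, crops_per_image = 4, 16

out = torch.rand(batch_size * crops_per_image, 1)              # predicted per-crop counts
gt_count = torch.rand(batch_size * crops_per_image, 1)         # per-crop targets from the glue layer
gt_count_whole = torch.rand(batch_size, 1) * crops_per_image   # whole-image targets

loss = criterion(out, gt_count)                                # wrt. transformer
per_image = gt_count.view(batch_size, -1)
loss = loss + (
    criterion(per_image.sum(dim=1, keepdim=True), gt_count_whole)   # stn: info retainment
    + F.threshold(per_image.var(dim=1).mean(),                      # stn: perspective correction
                  threshold=loss.item(), value=loss.item())
)

F.threshold(input, threshold, value) keeps an element only when it exceeds the threshold and otherwise substitutes the constant value, so the variance term is floored at the current loss value and contributes no gradient whenever it falls below it.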
@@ -285,17 +296,17 @@ def valid_one_epoch(test_loader, model, device, args):
     visi = []
     index = 0
-    for i, (fname, img, kpoint, gt_count) in enumerate(test_loader):
+    for i, (fname, img, kpoint, gt_count_whole) in enumerate(test_loader):
         kpoint = kpoint.type(torch.FloatTensor)
-        gt_count = gt_count.type(torch.FloatTensor)
+        gt_count_whole = gt_count_whole.type(torch.FloatTensor)
         if device is not None:
             img = img.to(device)
             kpoint = kpoint.to(device)
-            gt_count = gt_count.to(device)
+            gt_count_whole = gt_count_whole.to(device)
         elif torch.cuda.is_available():
             img = img.cuda()
             kpoint = kpoint.cuda()
-            gt_count = gt_count.cuda()
+            gt_count_whole = gt_count_whole.cuda()
         # XXX: do this even happen?
         if len(img.shape) == 5:
@@ -304,11 +315,11 @@ def valid_one_epoch(test_loader, model, device, args):
             img = img.unsqueeze(0)
         with torch.no_grad():
-            out, _ = model(img, kpoint)
+            out, gt_count = model(img, kpoint)
             pred_count = torch.squeeze(out, 1)
-            # gt_count = torch.squeeze(gt_count, 1)
+            gt_count = torch.squeeze(gt_count, 1)
-        diff = torch.abs(gt_count - torch.sum(pred_count)).item()
+        diff = torch.abs(gt_count_whole - torch.sum(pred_count)).item()
         mae += diff
         mse += diff ** 2
         mae = mae * 1.0 / (len(test_loader) * batch_size)
@@ -316,11 +327,11 @@ def valid_one_epoch(test_loader, model, device, args):
         if i % 5 == 0:
             print("[valid_one_epoch] {}\t| Gt {:.2f} Pred {:.4f}\t| mae {:.4f} mse {:.4f} |".format(
-                fname[0], torch.sum(gt_count).item(), torch.sum(pred_count).item(),
+                fname[0], torch.sum(gt_count_whole).item(), torch.sum(pred_count).item(),
                 mae, mse
             ))
-    nni.report_intermediate_result()
+    nni.report_intermediate_result(mae)
     print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
         mae=mae, mse=mse
     ))
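nni.report_intermediate_result now receives the per-epoch MAE. For context, a minimal sketch of the NNI trial reporting protocol around such a loop; run_epoch, the fixed epoch count, and the metric values are placeholders and not part of this repo:

import random
import nni

def run_epoch(epoch):
    # Placeholder standing in for train_one_epoch(...) + valid_one_epoch(...);
    # returns the validation MAE for this epoch.
    return random.uniform(0.0, 100.0)

params = nni.get_next_parameter()        # hyperparameters chosen by the tuner for this trial (unused here)
best_mae = float("inf")
for epoch in range(10):
    mae = run_epoch(epoch)
    nni.report_intermediate_result(mae)  # per-epoch metric, consumed by assessors / early stopping
    best_mae = min(best_mae, mae)
nni.report_final_result(best_mae)        # single final metric per trial, consumed by the tuner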