Updated glue impl, complex loss fn

parent: 83fcc43f0b
commit: d46a027e3f
2 changed files with 37 additions and 41 deletions
@@ -24,36 +24,21 @@ class SquareCropTransformLayer(nn.Module):
     ) -> (torch.Tensor, torch.Tensor):
         # Here, x_ & kpoints_ already applied affine transform.
         assert len(x_.shape) == 4
-        channels, height, width = x_.shape[1:]
+        batch_size, channels, height, width = x_.shape
         h_split_count = height // self.crop_size
         w_split_count = width // self.crop_size

         # Perform identical splits -- note kpoints_ does not have C dimension!
-        ret_x = torch.cat(
-            torch.tensor_split(
-                torch.cat(
-                    torch.tensor_split(
-                        x_,
-                        h_split_count,
-                        dim=2
-                    )
-                ),
-                w_split_count,
-                dim=3
-            )
-        )  # Performance should be acceptable but looks dumb as hell, is there a better way?
-        split_t = torch.cat(
-            torch.tensor_split(
-                torch.cat(
-                    torch.tensor_split(
-                        kpoints_,
-                        h_split_count,
-                        dim=1
-                    )
-                ),
-                w_split_count,
-                dim=2
-            )
-        )
+        ret_x = x_.view(
+            batch_size * h_split_count * w_split_count,
+            channels,
+            self.crop_size,
+            self.crop_size,
+        )
+        split_t = kpoints_.view(
+            batch_size * h_split_count * w_split_count,
+            self.crop_size,
+            self.crop_size,
+        )

         # Sum into gt_count
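Note on the new glue (reviewer commentary, not part of the commit): Tensor.view only reinterprets contiguous memory, so x_.view(batch_size * h_split_count * w_split_count, channels, crop_size, crop_size) is not, by itself, equivalent to the removed tensor_split/cat chain; the rows of each "crop" are read across neighbouring tiles and channels unless the layout has already been rearranged upstream. If the intent is plain non-overlapping crop_size tiles, an unfold-based reshape stays vectorised and preserves the spatial grouping. A minimal sketch, with an illustrative helper name that is not in the repo:

import torch

def square_crops(x_: torch.Tensor, kpoints_: torch.Tensor, crop_size: int):
    # x_: (B, C, H, W), kpoints_: (B, H, W); H and W are assumed to be
    # exact multiples of crop_size, as the layer asserts upstream.
    b, c, h, w = x_.shape
    hs, ws = h // crop_size, w // crop_size

    # unfold carves non-overlapping crop_size windows along H and W,
    # giving (B, C, hs, ws, crop, crop); permute moves the grid dims
    # next to the batch dim before they are flattened together.
    crops = (
        x_.unfold(2, crop_size, crop_size)
           .unfold(3, crop_size, crop_size)
           .permute(0, 2, 3, 1, 4, 5)
           .reshape(b * hs * ws, c, crop_size, crop_size)
    )

    # kpoints_ has no channel dim, so no permute is needed.
    kcrops = (
        kpoints_.unfold(1, crop_size, crop_size)
                .unfold(2, crop_size, crop_size)
                .reshape(b * hs * ws, crop_size, crop_size)
    )
    return crops, kcrops

Both unfold calls and the permute return views; only the final reshape materialises a copy, so the cost is one copy per batch.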
train.py (41 lines changed)
@@ -239,22 +239,33 @@ def train_one_epoch(
     model.train()

     # In one epoch, for each training sample
-    for i, (fname, img, kpoint, gt_count) in enumerate(train_loader):
+    for i, (fname, img, kpoint, gt_count_whole) in enumerate(train_loader):
         kpoint = kpoint.type(torch.FloatTensor)
-        gt_count = gt_count.type(torch.FloatTensor)
+        gt_count_whole = gt_count_whole.type(torch.FloatTensor).unsqueeze(1)
+        batch_size = img.size(0)
         # fpass
         if device is not None:
             img = img.to(device)
             kpoint = kpoint.to(device)
-            gt_count = gt_count.to(device)
+            gt_count_whole = gt_count_whole.to(device)
         elif torch.cuda.is_available():
             img = img.cuda()
             kpoint = kpoint.cuda()
-            gt_count = gt_count.cuda()
+            gt_count_whole = gt_count_whole.cuda()
-        out, _ = model(img, kpoint)
+        out, gt_count = model(img, kpoint)

         # loss
-        loss = criterion(out, gt_count)
+        loss = criterion(out, gt_count)  # wrt. transformer
+        loss += (
+            criterion(  # stn: info retainment
+                gt_count.view(batch_size, -1).sum(axis=1, keepdim=True),
+                gt_count_whole)
+            + F.threshold(  # stn: perspective correction
+                gt_count.view(batch_size, -1).var(dim=1).mean(),
+                threshold=loss.item(),
+                value=loss.item()
+            )
+        )

         # free grad from mem
         optimizer.zero_grad()
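Reading of the new objective (reviewer commentary, not part of the commit): the first term is the ordinary counting loss on the per-crop outputs; the "info retainment" term asks the per-crop ground-truth counts returned by the model's transform layer to sum back to the whole-image count from the dataloader; the "perspective correction" term penalises the variance of counts across crops, gated by F.threshold so it only passes through (and carries gradient) once it exceeds the current main loss, and otherwise collapses to the detached constant loss.item(). Note that loss.item() here still refers to the first term alone, because the right-hand side of loss += (...) is evaluated before the in-place add. A self-contained restatement as a sketch; the function name and signature are illustrative, not from the repo:

import torch
import torch.nn.functional as F

def composite_loss(criterion, out, gt_count, gt_count_whole, batch_size):
    # Main term ("wrt. transformer"): per-crop prediction vs. per-crop GT.
    main = criterion(out, gt_count)

    # Info retainment: the per-crop GT counts emitted by the transform layer
    # should sum back to the whole-image count from the dataloader.
    per_image_gt = gt_count.view(batch_size, -1).sum(dim=1, keepdim=True)
    retain = criterion(per_image_gt, gt_count_whole)

    # Perspective correction: penalise the variance of counts across crops,
    # but only once it exceeds the current main loss; below that threshold
    # F.threshold returns the detached constant main.item(), so the term
    # contributes no gradient.
    crop_var = gt_count.view(batch_size, -1).var(dim=1).mean()
    persp = F.threshold(crop_var, threshold=main.item(), value=main.item())

    return main + retain + persp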
@@ -285,17 +296,17 @@ def valid_one_epoch(test_loader, model, device, args):
     visi = []
     index = 0

-    for i, (fname, img, kpoint, gt_count) in enumerate(test_loader):
+    for i, (fname, img, kpoint, gt_count_whole) in enumerate(test_loader):
         kpoint = kpoint.type(torch.FloatTensor)
-        gt_count = gt_count.type(torch.FloatTensor)
+        gt_count_whole = gt_count_whole.type(torch.FloatTensor)
         if device is not None:
             img = img.to(device)
             kpoint = kpoint.to(device)
-            gt_count = gt_count.to(device)
+            gt_count_whole = gt_count_whole.to(device)
         elif torch.cuda.is_available():
             img = img.cuda()
             kpoint = kpoint.cuda()
-            gt_count = gt_count.cuda()
+            gt_count_whole = gt_count_whole.cuda()

         # XXX: do this even happen?
         if len(img.shape) == 5:
@@ -304,11 +315,11 @@ def valid_one_epoch(test_loader, model, device, args):
             img = img.unsqueeze(0)

         with torch.no_grad():
-            out, _ = model(img, kpoint)
+            out, gt_count = model(img, kpoint)
             pred_count = torch.squeeze(out, 1)
-            # gt_count = torch.squeeze(gt_count, 1)
+            gt_count = torch.squeeze(gt_count, 1)

-            diff = torch.abs(gt_count - torch.sum(pred_count)).item()
+            diff = torch.abs(gt_count_whole - torch.sum(pred_count)).item()
             mae += diff
             mse += diff ** 2
     mae = mae * 1.0 / (len(test_loader) * batch_size)
@@ -316,11 +327,11 @@ def valid_one_epoch(test_loader, model, device, args):

     if i % 5 == 0:
         print("[valid_one_epoch] {}\t| Gt {:.2f} Pred {:.4f}\t| mae {:.4f} mse {:.4f} |".format(
-            fname[0], torch.sum(gt_count).item(), torch.sum(pred_count).item(),
+            fname[0], torch.sum(gt_count_whole).item(), torch.sum(pred_count).item(),
             mae, mse
         ))

-    nni.report_intermediate_result()
+    nni.report_intermediate_result(mae)
     print("* MAE {mae:.3f} | MSE {mse:.3f} *".format(
         mae=mae, mse=mse
     ))
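On the nni change (reviewer commentary, not part of the commit): report_intermediate_result takes the metric as a required argument, so the earlier bare call would have raised a TypeError; passing the epoch MAE gives the assessor something to act on. A small sketch of how intermediate and final reporting usually pair up in an NNI trial; the helper names and the final report are assumptions, not part of this diff:

import nni

def report_epoch_metric(mae: float) -> None:
    # Per-epoch metric; an NNI assessor can use the sequence of these
    # values to early-stop unpromising trials.
    nni.report_intermediate_result(mae)

def report_trial_metric(best_mae: float) -> None:
    # Reported once per trial; this is the scalar the NNI tuner optimises
    # when proposing the next hyper-parameter configuration.
    nni.report_final_result(best_mae)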