Updated glue impl, complex loss fn

This commit is contained in:
Zhengyi Chen 2024-03-04 04:10:17 +00:00
parent 83fcc43f0b
commit d46a027e3f
2 changed files with 37 additions and 41 deletions

View file

@ -24,36 +24,21 @@ class SquareCropTransformLayer(nn.Module):
) -> (torch.Tensor, torch.Tensor):
# Here, x_ & kpoints_ already applied affine transform.
assert len(x_.shape) == 4
channels, height, width = x_.shape[1:]
batch_size, channels, height, width = x_.shape
h_split_count = height // self.crop_size
w_split_count = width // self.crop_size
# Perform identical splits -- note kpoints_ does not have C dimension!
ret_x = torch.cat(
torch.tensor_split(
torch.cat(
torch.tensor_split(
x_,
h_split_count,
dim=2
)
),
w_split_count,
dim=3
)
) # Performance should be acceptable but looks dumb as hell, is there a better way?
split_t = torch.cat(
torch.tensor_split(
torch.cat(
torch.tensor_split(
kpoints_,
h_split_count,
dim=1
)
),
w_split_count,
dim=2
)
ret_x = x_.view(
batch_size * h_split_count * w_split_count,
channels,
self.crop_size,
self.crop_size,
)
split_t = kpoints_.view(
batch_size * h_split_count * w_split_count,
self.crop_size,
self.crop_size,
)
# Sum into gt_count