Updated glue impl, complex loss fn
This commit is contained in:
parent
83fcc43f0b
commit
d46a027e3f
2 changed files with 37 additions and 41 deletions
|
|
@ -24,36 +24,21 @@ class SquareCropTransformLayer(nn.Module):
|
|||
) -> (torch.Tensor, torch.Tensor):
|
||||
# Here, x_ & kpoints_ already applied affine transform.
|
||||
assert len(x_.shape) == 4
|
||||
channels, height, width = x_.shape[1:]
|
||||
batch_size, channels, height, width = x_.shape
|
||||
h_split_count = height // self.crop_size
|
||||
w_split_count = width // self.crop_size
|
||||
|
||||
# Perform identical splits -- note kpoints_ does not have C dimension!
|
||||
ret_x = torch.cat(
|
||||
torch.tensor_split(
|
||||
torch.cat(
|
||||
torch.tensor_split(
|
||||
x_,
|
||||
h_split_count,
|
||||
dim=2
|
||||
)
|
||||
),
|
||||
w_split_count,
|
||||
dim=3
|
||||
)
|
||||
) # Performance should be acceptable but looks dumb as hell, is there a better way?
|
||||
split_t = torch.cat(
|
||||
torch.tensor_split(
|
||||
torch.cat(
|
||||
torch.tensor_split(
|
||||
kpoints_,
|
||||
h_split_count,
|
||||
dim=1
|
||||
)
|
||||
),
|
||||
w_split_count,
|
||||
dim=2
|
||||
)
|
||||
ret_x = x_.view(
|
||||
batch_size * h_split_count * w_split_count,
|
||||
channels,
|
||||
self.crop_size,
|
||||
self.crop_size,
|
||||
)
|
||||
split_t = kpoints_.view(
|
||||
batch_size * h_split_count * w_split_count,
|
||||
self.crop_size,
|
||||
self.crop_size,
|
||||
)
|
||||
|
||||
# Sum into gt_count
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue