# Glue layer for transforming whole pictures into a sequence of 384x384 crops for encoder input
from dataclasses import dataclass
from itertools import product

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import numpy as np
from torchvision.transforms import v2  # The v2 way, apparently. [1]


class SquareCropTransformLayer(nn.Module):
    def __init__(self, crop_size: int):
        super().__init__()
        self.crop_size = crop_size

    def forward(
        self, x_: torch.Tensor, kpoints_: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Here, x_ and kpoints_ have already had the affine transform applied.
        # x_ is (B, C, H, W); kpoints_ is (B, H, W).
        assert len(x_.shape) == 4
        channels, height, width = x_.shape[1:]
        h_split_count = height // self.crop_size
        w_split_count = width // self.crop_size

        # Split the image into a grid of crop_size x crop_size tiles, stacking the
        # tiles along the batch dimension: (B, C, H, W) -> (B * h * w, C, crop, crop).
        ret_x = torch.cat(
            torch.tensor_split(
                torch.cat(torch.tensor_split(x_, h_split_count, dim=2)),
                w_split_count,
                dim=3,
            )
        )

        # Perform the identical split on kpoints_ -- note it has no C dimension, so the
        # spatial dims shift down by one. Performance should be acceptable, but is there
        # a cleaner way to express this?
        split_t = torch.cat(
            torch.tensor_split(
                torch.cat(torch.tensor_split(kpoints_, h_split_count, dim=1)),
                w_split_count,
                dim=2,
            )
        )

        # Sum each tile's keypoint map into a per-tile ground-truth count.
        ret_gt_count = torch.sum(split_t.view(split_t.size(0), -1), dim=1)
        return ret_x, ret_gt_count


"""
References:
[1] https://pytorch.org/vision/stable/auto_examples/transforms/plot_custom_transforms.html
"""
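
# --- Minimal usage sketch (not from the original module; added for illustration). ---
# Assumes a 768x768 input split into a 2x2 grid of 384x384 crops, plus a dense
# keypoint map of the same spatial size; the batch size of 2 and the tensor
# contents below are arbitrary placeholders.
if __name__ == "__main__":
    layer = SquareCropTransformLayer(crop_size=384)
    images = torch.randn(2, 3, 768, 768)   # (B, C, H, W)
    kpoints = torch.zeros(2, 768, 768)     # (B, H, W) keypoint indicator / density map
    kpoints[:, 100, 100] = 1.0             # a single keypoint landing in the top-left tile
    crops, gt_counts = layer(images, kpoints)
    print(crops.shape)      # torch.Size([8, 3, 384, 384]) -- B * 2 * 2 tiles
    print(gt_counts.shape)  # torch.Size([8]) -- one ground-truth count per tile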