# Glue layer for transforming whole pictures into 384x384 sequences for encoder input.

import torch
import torch.nn as nn
from torchvision.transforms import v2  # The v2 way, apparently. [1]


class SquareCropTransformLayer(nn.Module):
    def __init__(self, crop_size: int):
        super().__init__()
        self.crop_size = crop_size

    def forward(
        self, x_: torch.Tensor, kpoints_: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # x_ and kpoints_ are assumed to have already been through the affine transform.
        assert x_.dim() == 4
        batch_size, channels, height, width = x_.shape
        assert height % self.crop_size == 0 and width % self.crop_size == 0

        h_split_count = height // self.crop_size
        w_split_count = width // self.crop_size

        # Split the image into a grid of square crops. A direct .view() onto the
        # flattened crop shape would reinterpret memory row-by-row and not yield
        # spatial crops, so factor each spatial dim into (grid, crop) first and
        # bring the grid dims forward before flattening.
        ret_x = (
            x_.view(
                batch_size,
                channels,
                h_split_count,
                self.crop_size,
                w_split_count,
                self.crop_size,
            )
            .permute(0, 2, 4, 1, 3, 5)
            .reshape(-1, channels, self.crop_size, self.crop_size)
        )

        # Perform the identical split on the keypoint map -- note that
        # kpoints_ has no C dimension!
        split_t = (
            kpoints_.view(
                batch_size,
                h_split_count,
                self.crop_size,
                w_split_count,
                self.crop_size,
            )
            .permute(0, 1, 3, 2, 4)
            .reshape(-1, self.crop_size, self.crop_size)
        )

        # Sum each crop's keypoint map into a per-crop ground-truth count.
        ret_gt_count = (
            torch.sum(split_t.view(split_t.size(0), -1), dim=1).unsqueeze(1)
        )

        return ret_x, ret_gt_count


"""
References:
[1] https://pytorch.org/vision/stable/auto_examples/transforms/plot_custom_transforms.html
"""
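
# --- Usage sketch (an illustrative addition, not from the original source) ---
# A minimal shape check, assuming a 768x768 input that splits into a 2x2 grid
# of 384x384 crops. All tensor contents here are random placeholders.
if __name__ == "__main__":
    layer = SquareCropTransformLayer(crop_size=384)
    x = torch.randn(2, 3, 768, 768)    # (B, C, H, W) image batch
    kpoints = torch.rand(2, 768, 768)  # keypoint map, no channel dimension
    crops, gt_count = layer(x, kpoints)
    print(crops.shape)     # torch.Size([8, 3, 384, 384])
    print(gt_count.shape)  # torch.Size([8, 1])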