From f059924b75dfbf07e50fe48627dc79f2724c09a7 Mon Sep 17 00:00:00 2001
From: Zhengyi Chen
Date: Sat, 2 Mar 2024 23:16:47 +0000
Subject: [PATCH 1/3] Crude impl of glue layer, not sure anything works

---
 .gitignore      |  1 +
 make_dataset.py | 11 ++++-----
 model/glue.py   | 62 +++++++++++++++++++++++++++++++++++++++++++++++--
 3 files changed, 66 insertions(+), 8 deletions(-)

diff --git a/.gitignore b/.gitignore
index acf9ecf..6f51f96 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
 baseline-experiments/
 synchronous/
+npydata/
\ No newline at end of file

diff --git a/make_dataset.py b/make_dataset.py
index 2d1fc6b..ac9ee40 100644
--- a/make_dataset.py
+++ b/make_dataset.py
@@ -10,6 +10,7 @@ import h5py
 CWD = os.getcwd()
 
 def pre_dataset_sh():
+    dataset_name = "ShanghaiTech"
     root = CWD + "/synchronous/dataset/" + dataset_name + "/"
 
     part_A_train = os.path.join(root, "part_A_final/train_data", "images")
@@ -53,8 +54,7 @@ def pre_dataset_sh():
         gt_data[:, 0] = gt_data[:, 0] * rate_x
         gt_data[:, 1] = gt_data[:, 1] * rate_y
 
-        # Compute gt_count from density map (gt_data)
-        # XXX: what does it do exactly?
+        # Build a 0/1 head-point map (kpoint) from the scaled point annotations
         kpoint = np.zeros((img_data.shape[0], img_data.shape[1]))
         for i in range(len(gt_data)):
             if (    int(gt_data[i][1]) < img_data.shape[0]
@@ -65,15 +65,14 @@ def pre_dataset_sh():
         root_path = img_path.split("IMG_")[0].replace("images", "images_crop")
 
         # Likewise, we do not crop to patched sequences here...
-        # Skip directly to saving fixed-size data & gt_count.
+        # Skip directly to saving fixed-size data & kpoint.
         img_path = img_path.replace("images", "images_crop")
         cv2.imwrite(img_path, img_data)
-        gt_count = np.sum(kpoint)
         with h5py.File(
             img_path.replace('.jpg', '.h5').replace('images', 'gt_density_map'),
-            'w'
+            mode='w'
         ) as hf:
-            hf["gt_count"] = gt_count
+            hf["kpoint"] = kpoint
 
 
 def make_npydata():

diff --git a/model/glue.py b/model/glue.py
index afdbb21..2750683 100644
--- a/model/glue.py
+++ b/model/glue.py
@@ -1,5 +1,7 @@
-# Glue layer for transforming whole pictures into 384x384 sequence for encoder
-# input
+# Glue layer for transforming whole pictures into a 384x384 crop sequence for encoder input
+from dataclasses import dataclass
+from itertools import product
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -7,3 +9,59 @@ import torchvision
 from torchvision import transforms
 import numpy as np
+from torchvision.transforms import v2
+
+# The v2 way, apparently. [1]
+class SquareCropTransformLayer(nn.Module):
+    def __init__(self, crop_size: int):
+        super(SquareCropTransformLayer, self).__init__()
+        self.crop_size = crop_size
+
+    def forward(
+        self,
+        x_: torch.Tensor,
+        kpoints_: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        # Here, x_ & kpoints_ have already been affine-transformed upstream.
+        assert len(x_.shape) == 4
+        channels, height, width = x_.shape[1:]
+        h_split_count = height // self.crop_size
+        w_split_count = width // self.crop_size
+
+        # Perform identical splits -- note kpoints_ does not have a C dimension!
+        ret_x = torch.cat(
+            torch.tensor_split(
+                torch.cat(
+                    torch.tensor_split(
+                        x_,
+                        h_split_count,
+                        dim=2
+                    )
+                ),
+                w_split_count,
+                dim=3
+            )
+        )  # Performance should be acceptable but looks dumb as hell, is there a better way?
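+        # (What the nested split/cat does: tensor_split along H yields
+        # h_split_count chunks of shape (N, C, crop, W), which cat stacks on
+        # dim 0; the second split/cat repeats this along W, leaving
+        # (N * h_split_count * w_split_count, C, crop, crop) tiles. For a
+        # flatter alternative, see the view/permute sketch after PATCH 3/3.)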
+        split_t = torch.cat(
+            torch.tensor_split(
+                torch.cat(
+                    torch.tensor_split(
+                        kpoints_,
+                        h_split_count,
+                        dim=1
+                    )
+                ),
+                w_split_count,
+                dim=2
+            )
+        )
+
+        # Sum each crop's 0/1 head points into its gt_count
+        ret_gt_count = torch.sum(split_t.view(split_t.size(0), -1), dim=1)
+
+        return ret_x, ret_gt_count
+
+"""
+References:
+[1] https://pytorch.org/vision/stable/auto_examples/transforms/plot_custom_transforms.html
+"""
\ No newline at end of file

From c88b93868070559888f02bfcc5a335123942bce9 Mon Sep 17 00:00:00 2001
From: Zhengyi Chen
Date: Sat, 2 Mar 2024 23:32:19 +0000
Subject: [PATCH 2/3] Extra changes, also not sure if works

---
 model/stn.py            | 14 ++++++++++----
 model/transcrowd_gap.py |  9 ++++++---
 2 files changed, 16 insertions(+), 7 deletions(-)

diff --git a/model/stn.py b/model/stn.py
index e9cd307..53fa8ff 100644
--- a/model/stn.py
+++ b/model/stn.py
@@ -62,7 +62,7 @@ class STNet(nn.Module):
 
     # Spatial transformer network forward function
-    def stn(self, x):
+    def stn(self, x, t):
         xs = self.localization_net(x)
         xs = xs.view(-1, np.prod(self._loc_net_out_shape))  # -> (N, whatever)
         theta = self.fc_loc(xs)
@@ -71,12 +71,18 @@ class STNet(nn.Module):
 
         grid = F.affine_grid(theta, x.size())
         x = F.grid_sample(x, grid)
 
-        return x
+        # Apply the same transformation to t, sans training
+        with torch.no_grad():
+            t = t.view(t.size(0), 1, t.size(1), t.size(2))
+            t = F.grid_sample(t, grid)
+            t = t.squeeze(1)
+
+        return x, t
 
-    def forward(self, x):
+    def forward(self, x, t):
         # transform the input, do nothing else
-        return self.stn(x)
+        return self.stn(x, t)
 
 
 if __name__ == "__main__":

diff --git a/model/transcrowd_gap.py b/model/transcrowd_gap.py
index ebd807f..a733b79 100644
--- a/model/transcrowd_gap.py
+++ b/model/transcrowd_gap.py
@@ -22,6 +22,7 @@ from timm.models.registry import register_model
 from timm.models.layers import trunc_normal_
 
 from .stn import STNet
+from .glue import SquareCropTransformLayer
 
 class VisionTransformerGAP(VisionTransformer):
     # [XXX] It might be a bad idea to use vision transformer for small datasets.
@@ -102,10 +103,12 @@ class STNet_VisionTransformerGAP(VisionTransformerGAP):
     def __init__(self, img_shape: torch.Size, *args, **kwargs):
         super(STNet_VisionTransformerGAP, self).__init__(*args, **kwargs)
         self.stnet = STNet(img_shape)
+        self.glue = SquareCropTransformLayer(kwargs.get("img_size", 384))  # crop to the ViT input size; 384 assumed per glue.py
 
-    def forward(self, x):
-        x = self.stnet(x)
-        return super(STNet_VisionTransformerGAP, self).forward(x)
+    def forward(self, x, t):
+        x, t = self.stnet(x, t)
+        x, t = self.glue(x, t)
+        return super(STNet_VisionTransformerGAP, self).forward(x), t


@register_model

From 5d77c1da4ea1e6addc8a28cb5ae7a772865c7135 Mon Sep 17 00:00:00 2001
From: rubberhead
Date: Sat, 2 Mar 2024 23:40:18 +0000
Subject: [PATCH 3/3] Added comment

---
 make_dataset.py | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)

diff --git a/make_dataset.py b/make_dataset.py
index ac9ee40..5e3380c 100644
--- a/make_dataset.py
+++ b/make_dataset.py
@@ -1,3 +1,19 @@
+"""
+The TransCrowd paper lists the ShanghaiTech dataset as available from:
+
+    https://drive.google.com/file/d/1CkYppr_IqR1s6wi53l2gKoGqm7LkJ-Lc/view
+
+Alternatively, you can probably download it from:
+
+    https://www.kaggle.com/datasets/tthien/shanghaitech?resource=download
+
+After downloading, execute:
+
+    $ unzip <downloaded zip> -d <repo root>/synchronous/dataset/
+
+to unpack the dataset correctly prior to running this script.
+"""
+
 import os
 import glob
 import random
@@ -130,4 +146,4 @@ def make_npydata():
 
 if __name__ == "__main__":
     # Download manually...
     pre_dataset_sh()  # XXX: preliminary
-    make_npydata()
\ No newline at end of file
+    make_npydata()
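---

A note on the nested tensor_split/cat in model/glue.py (PATCH 1/3): below is
a minimal sketch of an equivalent tiling via view/permute, which avoids the
intermediate tuple allocations the in-code comment worries about. The helper
name `tile_batch` is hypothetical (not part of these patches); like the glue
layer, it assumes H and W are exact multiples of the crop size.

    import torch

    def tile_batch(x: torch.Tensor, crop: int) -> torch.Tensor:
        # Split (N, C, H, W) into (N * (H//crop) * (W//crop), C, crop, crop)
        # non-overlapping tiles; assumes H % crop == 0 and W % crop == 0.
        n, c, h, w = x.shape
        hc, wc = h // crop, w // crop
        x = x.view(n, c, hc, crop, wc, crop)          # factor H and W
        x = x.permute(0, 2, 4, 1, 3, 5).contiguous()  # -> (N, hc, wc, C, crop, crop)
        return x.view(n * hc * wc, c, crop, crop)

    # kpoints_ lacks a channel dimension; unsqueeze so the same helper applies:
    #   tiles_k = tile_batch(kpoints_.unsqueeze(1), crop).squeeze(1)
    #   gt_count = tiles_k.view(tiles_k.size(0), -1).sum(dim=1)

Note that the tile ordering differs from the split/cat version (image-major
here vs. grid-position-major there); since each tile's gt_count is summed
independently, either ordering works as long as x_ and kpoints_ go through
the same one.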