Merge branch 'main' of github.com:rubberhead/mlp-project
Commit 4a03211c83 · 5 changed files with 99 additions and 16 deletions
.gitignore (vendored, 1 line changed)

@@ -1,2 +1,3 @@
 baseline-experiments/
 synchronous/
+npydata/
@@ -1,3 +1,19 @@
+"""
+The TransCrowd paper lists the ShanghaiTech dataset as available from here:
+
+https://drive.google.com/file/d/1CkYppr_IqR1s6wi53l2gKoGqm7LkJ-Lc/view
+
+Alternatively, you can probably download it from here:
+
+https://www.kaggle.com/datasets/tthien/shanghaitech?resource=download
+
+After downloading, execute:
+
+$ unzip <downloaded-zip-file> -d <repo-dir>/synchronous/dataset/
+
+to unzip the dataset correctly prior to running this script.
+"""
+
 import os
 import glob
 import random
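Since the unzip step is easy to get wrong, a quick layout check may help before running the script. The root and part_A_final paths below are the ones this script actually uses (per the next hunk); the check itself is a minimal sketch, not part of the commit:

    import os

    # Sanity-check the unzipped ShanghaiTech layout before preprocessing.
    root = os.path.join(os.getcwd(), "synchronous", "dataset", "ShanghaiTech")
    expected = os.path.join(root, "part_A_final", "train_data", "images")
    if not os.path.isdir(expected):
        raise FileNotFoundError(f"Expected images at {expected}; re-check the unzip target.")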
@@ -10,6 +26,7 @@ import h5py
 CWD = os.getcwd()

 def pre_dataset_sh():
+    dataset_name = "ShanghaiTech"
     root = CWD + "/synchronous/dataset/" + dataset_name + "/"

     part_A_train = os.path.join(root, "part_A_final/train_data", "images")
@@ -53,8 +70,7 @@ def pre_dataset_sh():
         gt_data[:, 0] = gt_data[:, 0] * rate_x
         gt_data[:, 1] = gt_data[:, 1] * rate_y

-        # Compute gt_count from density map (gt_data)
-        # XXX: what does it do exactly?
+        # Compute 0/1 counts from density map
         kpoint = np.zeros((img_data.shape[0], img_data.shape[1]))
         for i in range(len(gt_data)):
             if ( int(gt_data[i][1]) < img_data.shape[0]
@@ -65,15 +81,14 @@ def pre_dataset_sh():
         root_path = img_path.split("IMG_")[0].replace("images", "images_crop")

         # Likewise, we do not crop to patched sequences here...
-        # Skip directly to saving fixed-size data & gt_count.
+        # Skip directly to saving fixed-size data & kpoint.
         img_path = img_path.replace("images", "images_crop")
         cv2.imwrite(img_path, img_data)
-        gt_count = np.sum(kpoint)
         with h5py.File(
             img_path.replace('.jpg', '.h5').replace('images', 'gt_density_map'),
-            'w'
+            mode='w'
         ) as hf:
-            hf["gt_count"] = gt_count
+            hf["kpoint"] = kpoint


 def make_npydata():
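For reference, a downstream loader would read the saved map back like this. A minimal sketch: the .jpg path is hypothetical, and note that .replace("images", "gt_density_map") turns "images_crop" into "gt_density_map_crop", matching the writer above:

    import h5py
    import numpy as np

    # Hypothetical crop path; mirrors the path rewriting used when writing.
    img_path = "synchronous/dataset/ShanghaiTech/part_A_final/train_data/images_crop/IMG_1.jpg"
    h5_path = img_path.replace(".jpg", ".h5").replace("images", "gt_density_map")
    with h5py.File(h5_path, mode="r") as hf:
        kpoint = np.asarray(hf["kpoint"])
    gt_count = kpoint.sum()  # this commit defers the summing to the glue layer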
model/glue.py

@@ -1,5 +1,7 @@
-# Glue layer for transforming whole pictures into 384x384 sequence for encoder
-# input
+# Glue layer for transforming whole pictures into 384x384 sequence for encoder input
+from dataclasses import dataclass
+from itertools import product
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -7,3 +9,59 @@ import torchvision
 from torchvision import transforms

 import numpy as np
+from torchvision.transforms import v2
+
+# The v2 way, apparently. [1]
+class SquareCropTransformLayer(nn.Module):
+    def __init__(self, crop_size: int):
+        super(SquareCropTransformLayer, self).__init__()
+        self.crop_size = crop_size
+
+    def forward(
+        self,
+        x_: torch.Tensor,
+        kpoints_: torch.Tensor
+    ) -> (torch.Tensor, torch.Tensor):
+        # Here, x_ & kpoints_ have already had the affine transform applied.
+        assert len(x_.shape) == 4
+        channels, height, width = x_.shape[1:]
+        h_split_count = height // self.crop_size
+        w_split_count = width // self.crop_size
+
+        # Perform identical splits -- note kpoints_ does not have a C dimension!
+        ret_x = torch.cat(
+            torch.tensor_split(
+                torch.cat(
+                    torch.tensor_split(
+                        x_,
+                        h_split_count,
+                        dim=2
+                    )
+                ),
+                w_split_count,
+                dim=3
+            )
+        )  # Performance should be acceptable, but this looks clumsy -- is there a better way?
+        split_t = torch.cat(
+            torch.tensor_split(
+                torch.cat(
+                    torch.tensor_split(
+                        kpoints_,
+                        h_split_count,
+                        dim=1
+                    )
+                ),
+                w_split_count,
+                dim=2
+            )
+        )
+
+        # Sum each patch's kpoint map into its gt_count
+        ret_gt_count = torch.sum(split_t.view(split_t.size(0), -1), dim=1)
+
+        return ret_x, ret_gt_count
+
+"""
+References:
+[1] https://pytorch.org/vision/stable/auto_examples/transforms/plot_custom_transforms.html
+"""
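The "is there a better way?" comment invites an alternative: the usual idiom is Tensor.unfold. A sketch under the same shape assumptions (x_ is NCHW, kpoints_ is NHW), not a drop-in replacement: unfold enumerates patches in (n, h, w) order, whereas the nested tensor_split/cat enumerates them in (w, h, n) order, so whichever scheme is chosen must be applied to both tensors to keep image patches and counts aligned:

    import torch

    def square_crop_unfold(x_: torch.Tensor, kpoints_: torch.Tensor, crop_size: int):
        s = crop_size
        n, c, h, w = x_.shape
        ret_x = (
            x_.unfold(2, s, s).unfold(3, s, s)   # (N, C, H//s, W//s, s, s)
              .permute(0, 2, 3, 1, 4, 5)         # (N, H//s, W//s, C, s, s)
              .reshape(-1, c, s, s)              # one patch per row of the batch
        )
        ret_gt_count = (
            kpoints_.unfold(1, s, s).unfold(2, s, s)  # (N, H//s, W//s, s, s)
                    .sum(dim=(-2, -1))                # per-patch counts
                    .reshape(-1)
        )
        return ret_x, ret_gt_count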
model/stn.py (14 lines changed)
@@ -62,7 +62,7 @@ class STNet(nn.Module):


     # Spatial transformer network forward function
-    def stn(self, x):
+    def stn(self, x, t):
         xs = self.localization_net(x)
         xs = xs.view(-1, np.prod(self._loc_net_out_shape))  # -> (N, whatever)
         theta = self.fc_loc(xs)
@@ -71,12 +71,18 @@ class STNet(nn.Module):
         grid = F.affine_grid(theta, x.size())
         x = F.grid_sample(x, grid)

-        return x
+        # Apply the same transformation to t, without tracking gradients
+        with torch.no_grad():
+            t = t.view(t.size(0), 1, t.size(1), t.size(2))
+            t = F.grid_sample(t, grid)
+            t = t.squeeze(1)
+
+        return x, t


-    def forward(self, x):
+    def forward(self, x, t):
         # transform the input, do nothing else
-        return self.stn(x)
+        return self.stn(x, t)


 if __name__ == "__main__":
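A quick usage sketch of the new two-argument stn/forward. Shapes and values here are assumptions for illustration; the img_shape constructor argument matches its use in the model file below:

    import torch
    from model.stn import STNet

    stnet = STNet(torch.Size([3, 768, 1152]))  # assumed whole-image shape
    x = torch.randn(2, 3, 768, 1152)           # image batch
    t = torch.rand(2, 768, 1152)               # matching kpoint maps (no channel dimension)
    x_warped, t_warped = stnet(x, t)           # t is resampled with the same affine grid, under no_grad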
@@ -22,6 +22,7 @@ from timm.models.registry import register_model
 from timm.models.layers import trunc_normal_

 from .stn import STNet
+from .glue import SquareCropTransformLayer

 class VisionTransformerGAP(VisionTransformer):
     # [XXX] It might be a bad idea to use vision transformer for small datasets.
@@ -102,10 +103,12 @@ class STNet_VisionTransformerGAP(VisionTransformerGAP):
     def __init__(self, img_shape: torch.Size, *args, **kwargs):
         super(STNet_VisionTransformerGAP, self).__init__(*args, **kwargs)
         self.stnet = STNet(img_shape)
+        self.glue = SquareCropTransformLayer(img_size)

-    def forward(self, x):
-        x = self.stnet(x)
-        return super(STNet_VisionTransformerGAP, self).forward(x)
+    def forward(self, x, t):
+        x, t = self.stnet(x, t)
+        x, t = self.glue(x, t)
+        return super(STNet_VisionTransformerGAP, self).forward(x), t


 @register_model
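One caution: SquareCropTransformLayer(img_size) references a name that is not bound in __init__'s scope as written; presumably the ViT input size (e.g. kwargs["img_size"]) was intended. Taken together, the new forward path threads the kpoint map through the STN and the glue layer alongside the image. A minimal end-to-end sketch, with assumed shapes and an assumed img_size of 384 (the figure in the glue layer's header comment); the constructor call is illustrative, not taken from the repo:

    import torch

    model = STNet_VisionTransformerGAP(
        img_shape=torch.Size([3, 768, 1152]),  # whole-image shape for the STN
        img_size=384,                          # assumed ViT input size, forwarded to the timm base class
    )
    x = torch.randn(2, 3, 768, 1152)           # whole images
    t = torch.rand(2, 768, 1152)               # matching kpoint maps
    pred, gt_count = model(x, t)               # per-patch predictions and per-patch ground-truth counts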