Loss revamp & Renamed model to network
This commit is contained in:
parent
0d35d607fe
commit
9d2a30a226
7 changed files with 44 additions and 32 deletions
59
network/csrnet.py
Normal file
59
network/csrnet.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
# Stolen from https://github.com/leeyeehoo/CSRNet-pytorch.git
|
||||
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
from torchvision import models
|
||||
from utils import save_net,load_net
|
||||
|
||||
class CSRNet(nn.Module):
|
||||
def __init__(self, load_weights=False):
|
||||
super(CSRNet, self).__init__()
|
||||
|
||||
# Ref. 2018 paper
|
||||
self.seen = 0
|
||||
self.frontend_feat = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512]
|
||||
self.backend_feat = [512, 512, 512, 256, 128, 64] # 4-parallel, 1, 2, 2-then-4, 4 dilation rates
|
||||
|
||||
self.frontend = make_layers(self.frontend_feat)
|
||||
self.backend = make_layers(self.backend_feat,in_channels = 512,dilation = True)
|
||||
self.output_layer = nn.Conv2d(64, 1, kernel_size=1)
|
||||
if not load_weights:
|
||||
mod = models.vgg16(pretrained = True)
|
||||
self._initialize_weights()
|
||||
for i in range(len(self.frontend.state_dict().items())):
|
||||
self.frontend.state_dict().items()[i][1].data[:] = mod.state_dict().items()[i][1].data[:]
|
||||
|
||||
def forward(self,x):
|
||||
x = self.frontend(x)
|
||||
x = self.backend(x)
|
||||
x = self.output_layer(x)
|
||||
return x
|
||||
|
||||
def _initialize_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.normal_(m.weight, std=0.01)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
|
||||
def make_layers(cfg, in_channels = 3, batch_norm=False, dilation=False):
|
||||
if dilation:
|
||||
d_rate = 2
|
||||
else:
|
||||
d_rate = 1
|
||||
layers = []
|
||||
for v in cfg:
|
||||
if v == 'M':
|
||||
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
|
||||
else:
|
||||
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=d_rate, dilation=d_rate)
|
||||
if batch_norm:
|
||||
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
|
||||
else:
|
||||
layers += [conv2d, nn.ReLU(inplace=True)]
|
||||
in_channels = v
|
||||
return nn.Sequential(*layers)
|
||||
55
network/glue.py
Normal file
55
network/glue.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
# Glue layer for transforming whole pictures into 384x384 sequence for encoder input
|
||||
from dataclasses import dataclass
|
||||
from itertools import product
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
from torchvision import transforms
|
||||
|
||||
import numpy as np
|
||||
from torchvision.transforms import v2
|
||||
|
||||
# The v2 way, apparantly. [1]
|
||||
class SquareCropTransformLayer(nn.Module):
|
||||
def __init__(self, crop_size: int):
|
||||
super(SquareCropTransformLayer, self).__init__()
|
||||
self.crop_size = crop_size
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x_: torch.Tensor,
|
||||
kpoints_: torch.Tensor
|
||||
) -> (torch.Tensor, torch.Tensor):
|
||||
# Here, x_ & kpoints_ already applied affine transform.
|
||||
assert len(x_.shape) == 4
|
||||
batch_size, channels, height, width = x_.shape
|
||||
h_split_count = height // self.crop_size
|
||||
w_split_count = width // self.crop_size
|
||||
|
||||
# Perform identical splits -- note kpoints_ does not have C dimension!
|
||||
ret_x = x_.view(
|
||||
batch_size * h_split_count * w_split_count,
|
||||
channels,
|
||||
self.crop_size,
|
||||
self.crop_size,
|
||||
)
|
||||
split_t = kpoints_.view(
|
||||
batch_size * h_split_count * w_split_count,
|
||||
self.crop_size,
|
||||
self.crop_size,
|
||||
)
|
||||
|
||||
# Sum into gt_count
|
||||
ret_gt_count = (torch
|
||||
.sum(split_t.view(split_t.size(0), -1), dim=1)
|
||||
.unsqueeze(1)
|
||||
)
|
||||
|
||||
return ret_x, ret_gt_count
|
||||
|
||||
"""
|
||||
References:
|
||||
[1] https://pytorch.org/vision/stable/auto_examples/transforms/plot_custom_transforms.html
|
||||
"""
|
||||
162
network/reverse_perspective.py
Normal file
162
network/reverse_perspective.py
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
r"""Reverse Perspective Network Architectural Layers.
|
||||
|
||||
The *Reverse Perspective Network* [#]_ is a general approach to input
|
||||
pre-processing for instance segmentation / density map generation tasks.
|
||||
Roughly speaking, it models the input image into a elliptic coordinate system
|
||||
and tries to learn a foci length modifier parameter to perform perspective
|
||||
transformation on input images.
|
||||
|
||||
.. [#] Yang, Y., Li, G., Wu, Z., Su, L., Huang, Q., & Sebe, N. (2020).
|
||||
Reverse perspective network for perspective-aware object counting.
|
||||
In Proceedings of the IEEE/CVF conference on computer vision and pattern
|
||||
recognition (pp. 4374-4383).
|
||||
"""
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class PerspectiveEstimator(nn.Module):
|
||||
"""
|
||||
Perspective estimator submodule of the wider reverse-perspective network.
|
||||
|
||||
Input: Pre-processed, uniformly-sized image data
|
||||
Output: Perspective factor :math:`\\in \\mathbb{R}`
|
||||
|
||||
**Note**
|
||||
--------
|
||||
Loss input needs to be computed from beyond the **entire** rev-perspective
|
||||
network. Needs to therefore compute:
|
||||
- Effective pixel of each row after transformation.
|
||||
- Feature density (count) along row, summed over column.
|
||||
|
||||
Loss is computed as a variance over row feature densities. Ref. paper 3.2.
|
||||
After all, it is reasonable to say that you see more when you look at
|
||||
faraway places.
|
||||
|
||||
The paper utilizes a unsupervised loss -- "row feature density" refers to
|
||||
the density of features computed from ?
|
||||
|
||||
:param input_shape: (N, C, H, W)
|
||||
:param conv_kernel_shape: Oriented as (H, W)
|
||||
:param conv_dilation: equidistance dilation factor along H, W
|
||||
:param pool_capacity: K-number of classes for each (H, W) to be pooled into
|
||||
:param epsilon: Hyperparameter.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
input_shape: Tuple[int, int, int, int],
|
||||
conv_kernel_shape: Tuple[int, int],
|
||||
conv_dilation: int, # We will do equidistance dilation along H, W for now
|
||||
pool_capacity: int,
|
||||
conv_padding: int = 0,
|
||||
conv_padding_mode: str = 'zeros',
|
||||
conv_stride: int = 1,
|
||||
epsilon: float = 1e-5,
|
||||
*args, **kwargs
|
||||
) -> None:
|
||||
# N.B. input_shape has size (N, C_in, H_in, W_in)
|
||||
(_, _, height, width) = input_shape
|
||||
|
||||
# Sanity checking
|
||||
# [TODO] Maybe this is unnecessary, maybe we can automatically suggest new params,
|
||||
# but right now let's just do this...
|
||||
(_conv_height, _conv_width) = (
|
||||
np.floor(
|
||||
(height + 2 * conv_padding - conv_dilation * (conv_kernel_shape[0] - 1) - 1)
|
||||
/ conv_stride
|
||||
+ 1
|
||||
),
|
||||
np.floor(
|
||||
(width + 2 * conv_padding - conv_dilation * (conv_kernel_shape[1] - 1) - 1)
|
||||
/ conv_stride
|
||||
+ 1
|
||||
)
|
||||
)
|
||||
assert(height == _conv_height and width == _conv_width)
|
||||
|
||||
super.__init__(self, *args, **kwargs)
|
||||
self.epsilon = epsilon
|
||||
self.input_shape = input_shape
|
||||
self.layer_dict = nn.ModuleDict({
|
||||
'revpers_dilated_conv0': nn.Conv2d(
|
||||
in_channels=self.input_shape[1], out_channels=1,
|
||||
kernel_size=conv_kernel_shape,
|
||||
padding=conv_padding,
|
||||
padding_mode = conv_padding_mode,
|
||||
stride=conv_stride,
|
||||
dilation=conv_dilation,
|
||||
), # (N, 1, H, W)
|
||||
'revpers_avg_pool0': nn.AdaptiveAvgPool2d(
|
||||
output_size=(pool_capacity, 1)
|
||||
), # (N, 1, K, 1)
|
||||
# [?] Do we need to explicitly translate to (N, K) here?
|
||||
'revpers_fc0': nn.Linear(
|
||||
in_features=pool_capacity,
|
||||
out_features=1,
|
||||
),
|
||||
})
|
||||
|
||||
def forward(self, x):
|
||||
out = x
|
||||
|
||||
# Forward through layers -- there are no activations etc. in-between
|
||||
for (_, layer) in self.layer_dict:
|
||||
out = layer.forward(out)
|
||||
|
||||
# Normalize in (0, 1]
|
||||
F.relu(out, inplace=True)
|
||||
out = torch.exp(-out) + self.epsilon
|
||||
|
||||
return out
|
||||
|
||||
# def unsupervised_loss(predictions, targets):
|
||||
|
||||
# [TODO] We need a modified loss -- one that takes advantage of attention instead
|
||||
# of feature map. I feel like they should work likewise but who knows
|
||||
# [XXX] no forget it, we are pre-training rev-perspective as told by the 2020 paper
|
||||
# i.e., via using CSRNet.
|
||||
# Not sure which part is the feature map derived. Maybe after the front-end?
|
||||
# In any case we can always just use the CSR output (inferred density map) as feature map --
|
||||
# through which we compute, for each image:
|
||||
# criterion = Variance([output.sum(axis=W) * effective_pixel_per_row])
|
||||
# In other cases we sum over channels i.e., each feature map i.e., over each filter output
|
||||
# Not sure what channel means in this case...
|
||||
def warped_output_loss(csrnet_pred):
|
||||
N, H, W = csrnet_pred.shape()
|
||||
|
||||
|
||||
def transform_coordinates(
|
||||
img: torch.Tensor, # (C, W, H)
|
||||
factor: float,
|
||||
in_place: bool = True
|
||||
):
|
||||
dev_of_img = img.device
|
||||
|
||||
# Normalize X coords to [0, pi]
|
||||
min_x = torch.Tensor([0., 0., 0.]).to(dev_of_img)
|
||||
max_x = torch.Tensor([0., np.pi, 0.]).to(dev_of_img)
|
||||
min_xdim = torch.min(img, dim=1, keepdim=True)[0]
|
||||
max_xdim = torch.max(img, dim=1, keepdim=True)[0]
|
||||
(img.sub_(min_xdim)
|
||||
.div_(max_xdim - min_xdim)
|
||||
.mul_(max_x - min_x)
|
||||
.add_(min_x))
|
||||
|
||||
# Normalize Y coords to [0, 1]
|
||||
min_y = torch.Tensor([0., 0., 0.]).to(dev_of_img)
|
||||
max_y = torch.Tensor([0., 1., 0.]).to(dev_of_img)
|
||||
min_ydim = torch.min(img, dim=2, keepdim=True)[0]
|
||||
max_ydim = torch.max(img, dim=2, keepdim=True)[0]
|
||||
(img.sub_(min_ydim)
|
||||
.div_(max_ydim - min_ydim)
|
||||
.mul_(max_y - min_y)
|
||||
.add_(min_y))
|
||||
|
||||
# Do elliptical transformation
|
||||
tmp = img.clone().detach()
|
||||
|
||||
pass
|
||||
0
network/revpers_csrnet.py
Normal file
0
network/revpers_csrnet.py
Normal file
232
network/stn.py
Normal file
232
network/stn.py
Normal file
|
|
@ -0,0 +1,232 @@
|
|||
# stole from pytorch tutorial
|
||||
# "Great artists steal" -- they say,
|
||||
# but thieves also steal so you know :P
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
import torchvision
|
||||
from torchvision import datasets, transforms
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
class STNet(nn.Module):
|
||||
def __init__(self, input_size: torch.Size):
|
||||
super(STNet, self).__init__()
|
||||
|
||||
# Sanity check
|
||||
assert 5 > len(input_size) > 2 # single or batch ([N, ]C, H, W)
|
||||
if len(input_size) == 3:
|
||||
channels, height, width = input_size
|
||||
_dummy_size_ = torch.Size([1]) + input_size
|
||||
else:
|
||||
channels, height, width = input_size[1:]
|
||||
_dummy_size_ = input_size
|
||||
|
||||
# shape checking
|
||||
_dummy_x_ = torch.zeros(_dummy_size_)
|
||||
|
||||
# (3.1) Spatial transformer localization-network
|
||||
self.localization_net = nn.Sequential( # (N, C, H, W)
|
||||
nn.Conv2d(channels, 8, kernel_size=7), # (N, 8, H-6, W-6)
|
||||
nn.MaxPool2d(2, stride=2), # (N, 8, (H-6)/2, (W-6)/2)
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(8, 10, kernel_size=5), # (N, 10, ((H-6)/2)-4, ((W-6)/2)-4)
|
||||
nn.MaxPool2d(2, stride=2), # (N, 10, (((H-6)/2)-4)/2, (((H-6)/2)-4)/2)
|
||||
nn.ReLU(True)
|
||||
) # -> (N, 10, (((H-6)/2)-4)/2, (((H-6)/2)-4)/2)
|
||||
_dummy_x_ = self.localization_net(_dummy_x_)
|
||||
self._loc_net_out_shape = _dummy_x_.shape[1:]
|
||||
|
||||
# (3.2) Regressor for the 3 * 2 affine matrix
|
||||
self.fc_loc = nn.Sequential(
|
||||
# XXX: Dimensionality reduction across channels or not?
|
||||
nn.Linear(np.prod(self._loc_net_out_shape), 32),
|
||||
nn.ReLU(True),
|
||||
nn.Linear(32, 3 * 2)
|
||||
) # -> (6,)
|
||||
|
||||
# Initialize the weights/bias with identity transformation
|
||||
self.fc_loc[2].weight.data.zero_()
|
||||
self.fc_loc[2].bias.data.copy_(
|
||||
torch.tensor(
|
||||
[1, 0, 0,
|
||||
0, 1, 0],
|
||||
dtype=torch.float
|
||||
)
|
||||
)
|
||||
_dummy_x_ = self.fc_loc(
|
||||
_dummy_x_.view(-1, np.prod(self._loc_net_out_shape))
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
# Spatial transformer network forward function
|
||||
def stn(self, x, t):
|
||||
xs = self.localization_net(x)
|
||||
xs = xs.view(-1, np.prod(self._loc_net_out_shape)) # -> (N, whatever)
|
||||
theta = self.fc_loc(xs)
|
||||
theta = theta.view(-1, 2, 3) # -> (2, 3)
|
||||
|
||||
grid = F.affine_grid(theta, x.size(), align_corners=False)
|
||||
x = F.grid_sample(x, grid, align_corners=False)
|
||||
|
||||
# Do the same transformation to t sans training
|
||||
with torch.no_grad():
|
||||
t = t.view(t.size(0), 1, t.size(1), t.size(2))
|
||||
t = F.grid_sample(t, grid, align_corners=False)
|
||||
t = t.squeeze(1)
|
||||
|
||||
return x, t
|
||||
|
||||
|
||||
def forward(self, x, t):
|
||||
# print("STN: {} | {}".format(x.shape, t.shape))
|
||||
# transform the input, do nothing else
|
||||
return self.stn(x, t)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
class STNetDebug(STNet):
|
||||
def __init__(self, input_size: torch.Size):
|
||||
super(STNetDebug, self).__init__(input_size)
|
||||
|
||||
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
|
||||
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
|
||||
self.conv2_drop = nn.Dropout2d()
|
||||
self.fc1 = nn.Linear(320, 50)
|
||||
self.fc2 = nn.Linear(50, 10)
|
||||
|
||||
def forward(self, x):
|
||||
# Transform the input
|
||||
x = self.stn(x)
|
||||
|
||||
# Perform usual forward pass for MNIST
|
||||
x = F.relu(F.max_pool2d(self.conv1(x), 2))
|
||||
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
|
||||
x = x.view(-1, 320)
|
||||
x = F.relu(self.fc1(x))
|
||||
x = F.dropout(x, training=self.training)
|
||||
x = self.fc2(x)
|
||||
return F.log_softmax(x, dim=1)
|
||||
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
dataset=datasets.MNIST(
|
||||
root="./synchronous/",
|
||||
train=True,
|
||||
download=True,
|
||||
transform=transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((.1307, ), (.3081, ))
|
||||
]),
|
||||
),
|
||||
batch_size=64,
|
||||
shuffle=True,
|
||||
num_workers=4,
|
||||
)
|
||||
valid_loader = torch.utils.data.DataLoader(
|
||||
dataset=datasets.MNIST(
|
||||
root = "./synchronous/",
|
||||
train=False,
|
||||
transform=transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((.1307, ), (.3081, ))
|
||||
]),
|
||||
),
|
||||
batch_size=64,
|
||||
shuffle=True,
|
||||
num_workers=4,
|
||||
)
|
||||
shape_of_input = next(iter(train_loader))[0].shape
|
||||
model = STNetDebug(shape_of_input).to(device)
|
||||
optimizer = optim.SGD(model.parameters(), lr=.01)
|
||||
def train(epoch):
|
||||
model.train()
|
||||
for i, (x_, t_) in enumerate(train_loader):
|
||||
# XXX: x_.shape == (N, C, H, W)
|
||||
# Inference
|
||||
x_, t_ = x_.to(device), t_.to(device)
|
||||
optimizer.zero_grad()
|
||||
y_ = model(x_)
|
||||
# Backprop
|
||||
l_ = F.nll_loss(y_, t_)
|
||||
l_.backward()
|
||||
optimizer.step()
|
||||
if i % 500 == 0:
|
||||
print("Epoch {}: [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
|
||||
epoch,
|
||||
i * len(x_),
|
||||
len(train_loader.dataset),
|
||||
100. * i / len(train_loader),
|
||||
l_.item()
|
||||
))
|
||||
|
||||
def valid():
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
valid_loss = 0
|
||||
correct = 0
|
||||
for x_, t_ in valid_loader:
|
||||
x_, t_ = x_.to(device), t_.to(device)
|
||||
y_ = model(x_)
|
||||
# Sum batch loss
|
||||
valid_loss += F.nll_loss(y_, t_, size_average=False).item()
|
||||
pred = y_.max(1, keepdim=True)[1]
|
||||
correct += pred.eq(t_.view_as(pred)).sum().item()
|
||||
|
||||
valid_loss /= len(valid_loader.dataset)
|
||||
print("\nValid set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n"
|
||||
.format(
|
||||
valid_loss, correct, len(valid_loader.dataset),
|
||||
100. * correct / len(valid_loader.dataset)
|
||||
)
|
||||
)
|
||||
|
||||
def convert_image_np(inp):
|
||||
"""Convert a Tensor to numpy image."""
|
||||
inp = inp.numpy().transpose((1, 2, 0))
|
||||
mean = np.array([0.485, 0.456, 0.406])
|
||||
std = np.array([0.229, 0.224, 0.225])
|
||||
inp = std * inp + mean
|
||||
inp = np.clip(inp, 0, 1)
|
||||
return inp
|
||||
|
||||
# We want to visualize the output of the spatial transformers layer
|
||||
# after the training, we visualize a batch of input images and
|
||||
# the corresponding transformed batch using STN.
|
||||
def visualize_stn():
|
||||
with torch.no_grad():
|
||||
# Get a batch of training data
|
||||
data = next(iter(valid_loader))[0].to(device)
|
||||
|
||||
input_tensor = data.cpu()
|
||||
transformed_input_tensor = model.stn(data).cpu()
|
||||
|
||||
in_grid = convert_image_np(
|
||||
torchvision.utils.make_grid(input_tensor))
|
||||
|
||||
out_grid = convert_image_np(
|
||||
torchvision.utils.make_grid(transformed_input_tensor))
|
||||
|
||||
# Plot the results side-by-side
|
||||
f, axarr = plt.subplots(1, 2)
|
||||
axarr[0].imshow(in_grid)
|
||||
axarr[0].set_title('Dataset Images')
|
||||
|
||||
axarr[1].imshow(out_grid)
|
||||
axarr[1].set_title('Transformed Images')
|
||||
|
||||
for epoch in range(10):
|
||||
train(epoch + 1)
|
||||
valid()
|
||||
|
||||
# Visualize the STN transformation on some input batch
|
||||
visualize_stn()
|
||||
|
||||
plt.ioff()
|
||||
plt.show()
|
||||
|
||||
|
||||
|
||||
138
network/transcrowd_gap.py
Normal file
138
network/transcrowd_gap.py
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
r"""Transformer-encoder-regressor, adapted for reverse-perspective network.
|
||||
|
||||
This model is identical to the *TransCrowd* [#]_ model, which we use for the
|
||||
subproblem of actually counting the heads in each *transformed* raw image.
|
||||
|
||||
.. [#] Liang, D., Chen, X., Xu, W., Zhou, Y., & Bai, X. (2022).
|
||||
Transcrowd: weakly-supervised crowd counting with transformers.
|
||||
Science China Information Sciences, 65(6), 160104.
|
||||
"""
|
||||
from typing import Optional
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
# The original paper uses timm to create and import/export custom models,
|
||||
# so we follow suit
|
||||
from timm.models.vision_transformer import VisionTransformer, _cfg
|
||||
from timm.models.registry import register_model
|
||||
from timm.models.layers import trunc_normal_
|
||||
|
||||
from .stn import STNet
|
||||
from .glue import SquareCropTransformLayer
|
||||
|
||||
class VisionTransformerGAP(VisionTransformer):
|
||||
def __init__(self, img_size: int, *args, **kwargs):
|
||||
super().__init__(img_size=img_size, *args, **kwargs)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
|
||||
# That {p_1, p_2, ..., p_N} pos embedding
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.zeros(1, num_patches + 1, self.embed_dim)
|
||||
)
|
||||
# Fill self.pos_embed with N(0, 1) truncated to |0.2 * std|.
|
||||
trunc_normal_(self.pos_embed, std=.02)
|
||||
|
||||
# The "regression head"
|
||||
self.output1 = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Linear(in_features=6912 * 4, out_features=128),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(p=0.5),
|
||||
nn.Linear(in_features=128, out_features=1),
|
||||
)
|
||||
self.output1.apply(self._init_weights)
|
||||
|
||||
# glue layer -- since we delay image cropping here
|
||||
self.glue = SquareCropTransformLayer(img_size)
|
||||
|
||||
def forward_features(self, x):
|
||||
B = x.shape[0]
|
||||
|
||||
# 3.2 Patch embed
|
||||
x = self.patch_embed(x)
|
||||
|
||||
# ViT: Classification token
|
||||
# This idea originated from BERT.
|
||||
# Essentially, because we are performing encoding without decoding, we
|
||||
# cannot fix the output dimensionality -- which the classification
|
||||
# problem absolutely needs. Instead, we use the classification token as
|
||||
# the sole input which the transformer would need to learn to encode
|
||||
# whatever it learnt from input into that token.
|
||||
# Source: https://datascience.stackexchange.com/a/110637
|
||||
# That said, I don't think this is useful for GAP...
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1) # [[cls_token, x_i, ...]...]
|
||||
|
||||
x = x + self.pos_embed
|
||||
x = self.pos_drop(x) # [XXX] Drop some patches out -- or not?
|
||||
|
||||
# 3.3 Transformer-encoder
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
|
||||
# Normalize
|
||||
x = self.norm(x)
|
||||
|
||||
# Remove the classification token
|
||||
x = x[:, 1:]
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, x, t):
|
||||
with torch.no_grad():
|
||||
x, t = self.glue(x, t)
|
||||
x = self.forward_features(x) # Compute encoding
|
||||
x = F.adaptive_avg_pool1d(x, (48))
|
||||
x = x.view(x.shape[0], -1) # Move data for regression head
|
||||
# Resized to ???
|
||||
x = self.output1(x) # Regression head
|
||||
return x, t
|
||||
|
||||
|
||||
class STNet_VisionTransformerGAP(VisionTransformerGAP):
|
||||
def __init__(self, img_shape: torch.Size, img_size: int, *args, **kwargs):
|
||||
super(STNet_VisionTransformerGAP, self).__init__(img_size, *args, **kwargs)
|
||||
self.stnet = STNet(img_shape)
|
||||
|
||||
def forward(self, x, t):
|
||||
x, t = self.stnet(x, t)
|
||||
return super(STNet_VisionTransformerGAP, self).forward(x, t)
|
||||
|
||||
|
||||
@register_model
|
||||
def base_patch16_384_gap(pth_tar: Optional[str] = None, **kwargs):
|
||||
model = VisionTransformerGAP(
|
||||
img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
|
||||
mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||
**kwargs
|
||||
)
|
||||
model.default_cfg = _cfg()
|
||||
|
||||
if pth_tar is not None:
|
||||
checkpoint = torch.load(pth_tar)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
print("Loaded pre-trained pth.tar from \'{}\'".format(pth_tar))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def stn_patch16_384_gap(pth_tar: Optional[str] = None, **kwargs):
|
||||
model = STNet_VisionTransformerGAP(
|
||||
img_shape=torch.Size((3, 1152, 768)),
|
||||
img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
|
||||
mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||
**kwargs
|
||||
)
|
||||
model.default_cfg = _cfg()
|
||||
|
||||
if pth_tar is not None:
|
||||
checkpoint = torch.load(pth_tar)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
print("Loaded pre-trained pth.tar from \'{}\'".format(pth_tar))
|
||||
|
||||
return model
|
||||
Loading…
Add table
Add a link
Reference in a new issue