Added STNet-prepended TransCrowd GAP
parent c2a1bb46ef
commit 49a913a328
2 changed files with 16 additions and 2 deletions
@@ -1,4 +1,6 @@
 # stole from pytorch tutorial
+# "Great artists steal" -- they say,
+# but thieves also steal so you know :P
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -36,7 +38,7 @@ class STNet(nn.Module):
 
         # (3.2) Regressor for the 3 * 2 affine matrix
         self.fc_loc = nn.Sequential(
-            # XXX: Should
+            # XXX: Dimensionality reduction across channels or not?
             nn.Linear(np.prod(self._loc_net_out_shape), 32),
             nn.ReLU(True),
             nn.Linear(32, 3 * 2)
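fc_loc regresses the 6 entries of the 2 * 3 affine matrix; in the spatial-transformer recipe this file credits (the PyTorch tutorial), those parameters are fed to affine_grid and grid_sample to warp the input. A minimal sketch of that step, assuming a helper named stn_warp and a localization feature tensor xs shaped like self._loc_net_out_shape (STNet's actual forward is not part of this diff):

import torch
import torch.nn.functional as F

def stn_warp(fc_loc, x, xs):
    # fc_loc: the nn.Sequential regressor above
    # xs: localization-network features, (N, *loc_net_out_shape)
    # x: the image batch to warp, (N, C, H, W)
    theta = fc_loc(xs.flatten(1))    # (N, 6) affine parameters
    theta = theta.view(-1, 2, 3)     # (N, 2, 3) affine matrices
    grid = F.affine_grid(theta, x.size(), align_corners=False)
    return F.grid_sample(x, grid, align_corners=False)

As in that tutorial, the final Linear(32, 3 * 2) is typically initialized to the identity transform (zero weights, bias [1, 0, 0, 0, 1, 0]) so training starts from an unwarped image.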
@@ -20,7 +20,9 @@ from timm.models.vision_transformer import VisionTransformer, _cfg
 from timm.models.registry import register_model
 from timm.models.layers import trunc_normal_
 
-class VisionTransformerGAPwithFeatureMap(VisionTransformer):
+from .stn import STNet
+
+class VisionTransformerGAP(VisionTransformer):
     # [XXX] It might be a bad idea to use vision transformer for small datasets.
     # ref: ViT paper -- "transformers lack some of the inductive biases inherent
     # to CNNs, such as translation equivariance and locality".
@@ -93,3 +95,13 @@ class VisionTransformerGAPwithFeatureMap(VisionTransformer):
         # Resized to ???
         x = self.output1(x)  # Regression head
         return x
+
+
+class STNet_VisionTransformerGAP(VisionTransformerGAP):
+    def __init__(self, img_shape: torch.Size, *args, **kwargs):
+        super(STNet_VisionTransformerGAP, self).__init__(*args, **kwargs)
+        self.stnet = STNet(img_shape)
+
+    def forward(self, x):
+        x = self.stnet(x)
+        return super(STNet_VisionTransformerGAP, self).forward(x)
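STNet_VisionTransformerGAP simply warps the batch through the spatial transformer and then delegates to the parent VisionTransformerGAP forward (which ends in the self.output1 regression head shown above). A quick, hedged smoke test; the constructor keywords assume VisionTransformerGAP forwards standard timm VisionTransformer kwargs, and the 384 x 384 crop size is an assumption, not something fixed by this commit:

import torch

model = STNet_VisionTransformerGAP(
    img_shape=torch.Size([3, 384, 384]),  # handed to STNet; the exact shape convention is defined by STNet
    img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
)
x = torch.randn(2, 3, 384, 384)           # dummy batch of two crops
out = model(x)                            # STN warp, then the ViT-GAP regression head
print(out.shape)                          # one regressed value per image, if output1 maps features to a scalar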