Extra changes; also not sure if this works
This commit is contained in:
parent
f059924b75
commit
c88b938680
2 changed files with 16 additions and 7 deletions
14
model/stn.py
14
model/stn.py
|
|
@@ -62,7 +62,7 @@ class STNet(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
# Spatial transformer network forward function
|
# Spatial transformer network forward function
|
||||||
def stn(self, x):
|
def stn(self, x, t):
|
||||||
xs = self.localization_net(x)
|
xs = self.localization_net(x)
|
||||||
xs = xs.view(-1, np.prod(self._loc_net_out_shape)) # -> (N, whatever)
|
xs = xs.view(-1, np.prod(self._loc_net_out_shape)) # -> (N, whatever)
|
||||||
theta = self.fc_loc(xs)
|
theta = self.fc_loc(xs)
|
||||||
|
|
@@ -71,12 +71,18 @@ class STNet(nn.Module):
|
||||||
grid = F.affine_grid(theta, x.size())
|
grid = F.affine_grid(theta, x.size())
|
||||||
x = F.grid_sample(x, grid)
|
x = F.grid_sample(x, grid)
|
||||||
|
|
||||||
return x
|
# Apply the same sampling grid to t, without tracking gradients
|
||||||
|
with torch.no_grad():
|
||||||
|
t = t.view(t.size(0), 1, t.size(1), t.size(2))
|
||||||
|
t = F.grid_sample(t, grid)
|
||||||
|
t = t.squeeze(1)
|
||||||
|
|
||||||
|
return x, t
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x, t):
|
||||||
# transform the input, do nothing else
|
# transform the input, do nothing else
|
||||||
return self.stn(x)
|
return self.stn(x, t)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@@ -22,6 +22,7 @@ from timm.models.registry import register_model
|
||||||
from timm.models.layers import trunc_normal_
|
from timm.models.layers import trunc_normal_
|
||||||
|
|
||||||
from .stn import STNet
|
from .stn import STNet
|
||||||
|
from .glue import SquareCropTransformLayer
|
||||||
|
|
||||||
class VisionTransformerGAP(VisionTransformer):
|
class VisionTransformerGAP(VisionTransformer):
|
||||||
# [XXX] It might be a bad idea to use vision transformer for small datasets.
|
# [XXX] It might be a bad idea to use vision transformer for small datasets.
|
||||||
|
|
@@ -102,10 +103,12 @@ class STNet_VisionTransformerGAP(VisionTransformerGAP):
|
||||||
def __init__(self, img_shape: torch.Size, *args, **kwargs):
|
def __init__(self, img_shape: torch.Size, *args, **kwargs):
|
||||||
super(STNet_VisionTransformerGAP, self).__init__(*args, **kwargs)
|
super(STNet_VisionTransformerGAP, self).__init__(*args, **kwargs)
|
||||||
self.stnet = STNet(img_shape)
|
self.stnet = STNet(img_shape)
|
||||||
|
self.glue = SquareCropTransformLayer(img_size)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x, t):
|
||||||
x = self.stnet(x)
|
x, t = self.stnet(x, t)
|
||||||
return super(STNet_VisionTransformerGAP, self).forward(x)
|
x, t = self.glue(x, t)
|
||||||
|
return super(STNet_VisionTransformerGAP, self).forward(x), t
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue