Added STNet-prepended TransCrowd GAP

This commit is contained in:
Zhengyi Chen 2024-02-27 18:57:13 +00:00
parent c2a1bb46ef
commit 49a913a328
2 changed files with 16 additions and 2 deletions

View file

@ -1,4 +1,6 @@
# stole from pytorch tutorial
# "Great artists steal" -- they say,
# but thieves also steal so you know :P
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -36,7 +38,7 @@ class STNet(nn.Module):
# (3.2) Regressor for the 3 * 2 affine matrix
self.fc_loc = nn.Sequential(
# XXX: Should
# XXX: Dimensionality reduction across channels or not?
nn.Linear(np.prod(self._loc_net_out_shape), 32),
nn.ReLU(True),
nn.Linear(32, 3 * 2)