diff --git a/model/reverse_perspective.py b/model/reverse_perspective.py
index d807c89..3bdcaaf 100644
--- a/model/reverse_perspective.py
+++ b/model/reverse_perspective.py
@@ -24,7 +24,7 @@ class PerspectiveEstimator(nn.Module):
     Perspective estimator submodule of the wider reverse-perspective network.
 
     Input: Pre-processed, uniformly-sized image data
-    Output: Perspective factor
+    Output: Perspective factor :math:`\\in \\mathbb{R}`
 
     **Note**
     --------
diff --git a/model/transcrowd_gap.py b/model/transcrowd_gap.py
index 48fc520..5624b9f 100644
--- a/model/transcrowd_gap.py
+++ b/model/transcrowd_gap.py
@@ -21,6 +21,13 @@ from timm.models.registry import register_model
 from timm.models.layers import trunc_normal_
 
 class VisionTransformer_GAP(VisionTransformer):
+    # [XXX] Using a vision transformer on small datasets may be a bad idea.
+    # ref: ViT paper -- "transformers lack some of the inductive biases inherent
+    # to CNNs, such as translation equivariance and locality".
+    # Convolution is translation-equivariant by construction (it is linear and
+    # shift-equivariant), which is exactly one of those biases.
+    # tl;dr: CNNs might perform better on small datasets; performance impact untested here.
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         num_patches = self.patch_embed.num_patches
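
As context for the translation-equivariance point in the new comment, the sketch below (not part of the patch; plain torch, arbitrary layer sizes and seed) checks that a convolution commutes with a shift of its input, which is the inductive bias the quoted ViT passage refers to. Circular padding is used so the identity holds exactly rather than only away from the borders.

import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy 3x3 conv with circular padding, so a circular shift of the input
# commutes exactly with the convolution.
conv = nn.Conv2d(1, 1, kernel_size=3, padding=1, padding_mode="circular", bias=False)

x = torch.randn(1, 1, 16, 16)
shift = 3

# Compare conv(shift(x)) against shift(conv(x)).
lhs = conv(torch.roll(x, shifts=shift, dims=-1))
rhs = torch.roll(conv(x), shifts=shift, dims=-1)

print(torch.allclose(lhs, rhs, atol=1e-6))  # True: convolution is shift-equivariant

A ViT's patch embedding generally fails this check for shifts that are not multiples of the patch size (and positional embeddings break it further), which is roughly the gap the comment is flagging.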