diff --git a/heimdall/transformers/multi_head_pooled_attention.py b/heimdall/transformers/multi_head_pooled_attention.py
index 13a4dda..9134d3d 100644
--- a/heimdall/transformers/multi_head_pooled_attention.py
+++ b/heimdall/transformers/multi_head_pooled_attention.py
@@ -81,8 +81,9 @@ def __init__(
         self,
         input_dim: int,
         head_dim: int = 512,
-        n_head: int = 8,
+        n_head: int = 4,
         has_cls_token: bool = True,
+        pool: bool = True,
     ) -> None:
         super().__init__()

@@ -95,10 +96,23 @@ def __init__(
         self.K = torch.nn.Linear(input_dim, head_dim * n_head)
         self.V = torch.nn.Linear(input_dim, head_dim * n_head)

-        self.PQ = AttentionPooler(head_dim, head_dim, has_cls_token=has_cls_token)
-        self.PK = AttentionPooler(head_dim, head_dim, has_cls_token=has_cls_token)
-        self.PV = AttentionPooler(head_dim, head_dim, has_cls_token=has_cls_token)
-        self.PX = AttentionPooler(input_dim, input_dim, has_cls_token=has_cls_token)
+        patch_size = 2 if pool else 1
+
+        self.PQ = AttentionPooler(
+            head_dim, head_dim, patch_size=patch_size, has_cls_token=has_cls_token
+        )
+        self.PK = AttentionPooler(
+            head_dim, head_dim, patch_size=patch_size, has_cls_token=has_cls_token
+        )
+        self.PV = AttentionPooler(
+            head_dim, head_dim, patch_size=patch_size, has_cls_token=has_cls_token
+        )
+        self.PX = torch.nn.Conv2d(
+            input_dim,
+            head_dim,
+            kernel_size=patch_size,
+            stride=patch_size,
+        )

         self.output_layer_norm = torch.nn.LayerNorm(input_dim)

diff --git a/heimdall/transformers/vision/encoder.py b/heimdall/transformers/vision/encoder.py
index 6e726f7..3ad67e4 100644
--- a/heimdall/transformers/vision/encoder.py
+++ b/heimdall/transformers/vision/encoder.py
@@ -7,12 +7,12 @@


 class MultiScaleEncoderLayer(torch.nn.Module):
-    def __init__(self, input_dim: int, head_dim: int, n_heads: int) -> None:
+    def __init__(self, input_dim: int, head_dim: int, n_heads: int, pool: bool) -> None:
         super().__init__()

         self.ln1 = LayerNorm(input_dim)
         self.ln2 = LayerNorm(input_dim)
-        self.attn = MultiHeadPooledAttention(input_dim, head_dim, n_heads)
+        self.attn = MultiHeadPooledAttention(input_dim, head_dim, n_heads, pool=pool)
         self.mlp = MLP(input_dim)

     def forward(self, x: torch.Tensor, thw_shape: Tuple[int, int, int]) -> torch.Tensor:
@@ -27,12 +27,13 @@ def __init__(
     ) -> None:
         super().__init__()

-        self.layers = torch.nn.ModuleList(
-            [
-                MultiScaleEncoderLayer(input_dim, head_dim, n_heads)
-                for _ in range(n_layers)
-            ]
-        )
+        self.layers = torch.nn.ModuleList([])
+
+        for i in range(n_layers):
+            pool = i == 0
+            self.layers.append(
+                MultiScaleEncoderLayer(input_dim, head_dim, n_heads, pool=pool)
+            )

     def forward(self, x: torch.Tensor, thw_shape: Tuple[int, int, int]) -> torch.Tensor:
         for layer in self.layers:
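
Below the patch, a minimal standalone sketch of the pooling schedule these changes introduce: only the first encoder layer pools (patch_size = 2), while later layers keep the spatial resolution (patch_size = 1). The PoolStub module, channel count, and input shape are illustrative assumptions and not part of the heimdall API; the stub only mirrors the Conv2d-based PX path added in this diff.

import torch

# Toy stand-in for the Conv2d pooling path added in this diff: kernel and
# stride follow `patch_size = 2 if pool else 1`, so only pooled layers
# reduce the spatial resolution. (PoolStub is illustrative, not heimdall code.)
class PoolStub(torch.nn.Module):
    def __init__(self, input_dim: int, head_dim: int, pool: bool) -> None:
        super().__init__()
        patch_size = 2 if pool else 1
        self.PX = torch.nn.Conv2d(
            input_dim, head_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.PX(x)


x = torch.randn(1, 96, 14, 14)  # (batch, channels, H, W); shape chosen for illustration

# Mirrors the new MultiScaleEncoder construction loop: pool only in layer 0 (i == 0).
layers = [PoolStub(96, 96, pool=(i == 0)) for i in range(3)]
for i, layer in enumerate(layers):
    x = layer(x)
    print(i, tuple(x.shape))  # layer 0 halves H and W to 7x7; layers 1-2 leave them unchanged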