
Commit

Adding: lost updates
KanishkNavale committed Apr 11, 2024
1 parent 9657ff1 commit 2733494
Showing 2 changed files with 28 additions and 13 deletions.
heimdall/transformers/multi_head_pooled_attention.py — 19 additions, 5 deletions
@@ -81,8 +81,9 @@ def __init__(
         self,
         input_dim: int,
         head_dim: int = 512,
-        n_head: int = 8,
+        n_head: int = 4,
         has_cls_token: bool = True,
+        pool: bool = True,
     ) -> None:
         super().__init__()

@@ -95,10 +96,23 @@ def __init__(
         self.K = torch.nn.Linear(input_dim, head_dim * n_head)
         self.V = torch.nn.Linear(input_dim, head_dim * n_head)

-        self.PQ = AttentionPooler(head_dim, head_dim, has_cls_token=has_cls_token)
-        self.PK = AttentionPooler(head_dim, head_dim, has_cls_token=has_cls_token)
-        self.PV = AttentionPooler(head_dim, head_dim, has_cls_token=has_cls_token)
-        self.PX = AttentionPooler(input_dim, input_dim, has_cls_token=has_cls_token)
+        patch_size = 2 if pool else 1
+
+        self.PQ = AttentionPooler(
+            head_dim, head_dim, patch_size=patch_size, has_cls_token=has_cls_token
+        )
+        self.PK = AttentionPooler(
+            head_dim, head_dim, patch_size=patch_size, has_cls_token=has_cls_token
+        )
+        self.PV = AttentionPooler(
+            head_dim, head_dim, patch_size=patch_size, has_cls_token=has_cls_token
+        )
+        self.PX = torch.nn.Conv2d(
+            input_dim,
+            head_dim,
+            kernel_size=patch_size,
+            stride=patch_size,
+        )

         self.output_layer_norm = torch.nn.LayerNorm(input_dim)

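For context, the new pool flag selects patch_size = 2 (downsample) or 1 (no-op), and the skip/input projection self.PX now uses a strided Conv2d instead of an AttentionPooler. The sketch below only illustrates what a kernel_size = stride = patch_size convolution does to a flattened token grid; the function name, dimensions, and token layout are illustrative assumptions, and the CLS-token handling (has_cls_token) of the real module is omitted.

import torch

# Minimal sketch (hypothetical, not the heimdall module): a strided Conv2d with
# kernel_size == stride == patch_size projects input_dim -> head_dim and halves
# the spatial grid when patch_size=2, or keeps it unchanged when patch_size=1.
def pool_tokens(x: torch.Tensor, h: int, w: int, proj: torch.nn.Conv2d) -> torch.Tensor:
    b, n, c = x.shape                        # tokens laid out as (B, H*W, C)
    grid = x.transpose(1, 2).reshape(b, c, h, w)
    pooled = proj(grid)                      # (B, head_dim, H/ps, W/ps)
    return pooled.flatten(2).transpose(1, 2)

input_dim, head_dim, patch_size = 96, 64, 2  # illustrative dimensions
proj = torch.nn.Conv2d(input_dim, head_dim, kernel_size=patch_size, stride=patch_size)

x = torch.randn(2, 14 * 14, input_dim)       # 14x14 token grid
y = pool_tokens(x, 14, 14, proj)
print(y.shape)                               # torch.Size([2, 49, 64]), i.e. a 7x7 grid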
heimdall/transformers/vision/encoder.py — 9 additions, 8 deletions
@@ -7,12 +7,12 @@


 class MultiScaleEncoderLayer(torch.nn.Module):
-    def __init__(self, input_dim: int, head_dim: int, n_heads: int) -> None:
+    def __init__(self, input_dim: int, head_dim: int, n_heads: int, pool: bool) -> None:
         super().__init__()

         self.ln1 = LayerNorm(input_dim)
         self.ln2 = LayerNorm(input_dim)
-        self.attn = MultiHeadPooledAttention(input_dim, head_dim, n_heads)
+        self.attn = MultiHeadPooledAttention(input_dim, head_dim, n_heads, pool=pool)
         self.mlp = MLP(input_dim)

     def forward(self, x: torch.Tensor, thw_shape: Tuple[int, int, int]) -> torch.Tensor:
@@ -27,12 +27,13 @@ def __init__(
     ) -> None:
         super().__init__()

-        self.layers = torch.nn.ModuleList(
-            [
-                MultiScaleEncoderLayer(input_dim, head_dim, n_heads)
-                for _ in range(n_layers)
-            ]
-        )
+        self.layers = torch.nn.ModuleList([])
+
+        for i in range(n_layers):
+            pool = i == 0
+            self.layers.append(
+                MultiScaleEncoderLayer(input_dim, head_dim, n_heads, pool=pool)
+            )

     def forward(self, x: torch.Tensor, thw_shape: Tuple[int, int, int]) -> torch.Tensor:
         for layer in self.layers:
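The encoder now builds its layers in a loop so that only the first layer pools (pool = i == 0). The toy sketch below illustrates that pattern with a stand-in layer (ToyLayer is hypothetical, not heimdall's MultiScaleEncoderLayer): the token count is halved exactly once, by the first layer, and preserved by the rest.

import torch

# Stand-in layer (hypothetical): only illustrates the pool-on-first-layer pattern.
class ToyLayer(torch.nn.Module):
    def __init__(self, dim: int, pool: bool) -> None:
        super().__init__()
        self.pool = torch.nn.AvgPool1d(2) if pool else torch.nn.Identity()
        self.mix = torch.nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(x.transpose(1, 2)).transpose(1, 2)  # halve token count iff pooling
        return self.mix(x)

n_layers, dim = 4, 32
layers = torch.nn.ModuleList([])
for i in range(n_layers):
    pool = i == 0                       # same rule as the encoder: pool only in layer 0
    layers.append(ToyLayer(dim, pool=pool))

x = torch.randn(2, 16, dim)             # 16 tokens in
for layer in layers:
    x = layer(x)
print(x.shape)                          # torch.Size([2, 8, 32]) -> halved exactly once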
