Implemented AtomAttentionPairBias
ardagoreci committed May 28, 2024
1 parent efade4c commit 6d5fe3b
Showing 2 changed files with 74 additions and 39 deletions.
98 changes: 71 additions & 27 deletions src/models/components/atom_attention.py
@@ -4,11 +4,13 @@
subset of the nearby 128 atoms (nearby in the sequence space). This gives the network the capacity to learn general
rules about local atom constellations, independently of the coarse-grained tokenization where each standard residue
is represented with a single token only."""

import torch
from torch import nn
import numpy as np
from torch.nn import functional as F
from src.models.components.primitives import AdaLN, Linear
from src.models.components.transition import ConditionedTransitionBlock
from src.utils.tensor_utils import partition_tensor


@@ -98,7 +100,7 @@ def extract_local_biases(bias_tensor, partition_increment=32, partition_length=1
    return output_tensor
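
# Illustrative sketch, not part of this commit: a naive, dense version of the
# sequence-local windowing described in the module docstring, in which each block of
# n_queries atoms attends to the n_keys atoms centred on that block. The function name,
# the exact centring, and the edge handling are assumptions; extract_local_biases above
# is presumably the memory-efficient counterpart that operates on the pair-bias tensor.
def _naive_local_bias_sketch(pair_bias, n_queries=32, n_keys=128):
    bs, n_atoms, _, channels = pair_bias.shape
    assert n_atoms % n_queries == 0 and n_atoms >= n_keys
    n_blocks = n_atoms // n_queries
    out = pair_bias.new_zeros(bs, n_blocks, n_queries, n_keys, channels)
    for block in range(n_blocks):
        centre = block * n_queries + n_queries // 2
        start = max(0, min(centre - n_keys // 2, n_atoms - n_keys))  # clamp the key window at the edges
        out[:, block] = pair_bias[:, block * n_queries:(block + 1) * n_queries, start:start + n_keys]
    return out  # (bs, n_blocks, n_queries, n_keys, channels)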


class AttentionPairBias(nn.Module):
class AtomAttentionPairBias(nn.Module):
    """Implements the sequence-local atom attention with pair bias.
    This is implemented separately from the attention module that performs full self-attention,
    since sequence-local atom attention requires a memory-efficient implementation.
@@ -108,17 +110,13 @@ def __init__(
        self,
        embed_dim,
        num_heads=8,
        dropout=0.0, bias=True,
        add_bias_kv=False, add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=False,
        device=None,
        dtype=None,
        dropout=0.0,
        n_queries: int = 32,
        n_keys: int = 128,
        c_atom: int = 128,
        c_pair: int = 16,
        device=None,
        dtype=None,
    ):
        """Initialize the AtomAttentionPairBias module.
        Args:
@@ -129,19 +127,6 @@
                (i.e. each head will have dimension embed_dim // num_heads).
            dropout:
                Dropout probability on attn_output_weights. Default: 0.0 (no dropout).
            bias:
                If specified, adds bias to input / output projection layers. Default: True.
            add_bias_kv:
                If specified, adds bias to the key and value sequences at dim=0. Default: False.
            add_zero_attn:
                If specified, adds a new batch of zeros to the key and value sequences at dim=1. Default: False.
            kdim:
                Total number of features for keys. Default: None (uses kdim=embed_dim).
            vdim:
                Total number of features for values. Default: None (uses vdim=embed_dim).
            batch_first:
                If True, then the input and output tensors are provided as (batch, seq, feature).
                Default: False (seq, batch, feature).
            n_queries:
                The size of the atom window. Defaults to 32.
            n_keys:
@@ -156,12 +141,6 @@
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.bias = bias
        self.add_bias_kv = add_bias_kv
        self.add_zero_attn = add_zero_attn
        self.kdim = kdim
        self.vdim = vdim
        self.batch_first = batch_first
        self.device = device
        self.dtype = dtype
        self.n_queries = n_queries
@@ -201,6 +180,7 @@ def forward(self, atom_single_repr, atom_single_proj, atom_pair_repr, mask=None)
                tensor of shape (bs, n_atoms)
        Returns:
            tensor of shape (bs, n_atoms, embed_dim) after sequence-local atom attention
        TODO: implement masking
        """
        # Input projections
        a = self.ada_ln(atom_single_repr, atom_single_proj)  # AdaLN(a, s)
@@ -241,6 +221,70 @@ def forward(self, atom_single_repr, atom_single_proj, atom_pair_repr, mask=None)
        return output
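
# Illustrative usage sketch, not part of this commit. The tensor shapes are assumptions
# inferred from the docstrings: embed_dim is set equal to the default c_atom (128), and
# the pair representation is taken to be the full (bs, n_atoms, n_atoms, c_pair) tensor
# from which local biases are extracted. The collapsed test fixture may use different shapes.
def _example_atom_attention_pair_bias_usage():
    bs, n_atoms, embed_dim = 2, 256, 128
    block = AtomAttentionPairBias(embed_dim=embed_dim, num_heads=8)
    atom_single_repr = torch.randn(bs, n_atoms, embed_dim)    # per-atom single representation
    atom_single_proj = torch.randn(bs, n_atoms, embed_dim)    # conditioning signal passed to AdaLN
    atom_pair_repr = torch.randn(bs, n_atoms, n_atoms, 16)    # assumed full pair representation, c_pair=16
    out = block(atom_single_repr, atom_single_proj, atom_pair_repr)
    return out.shape                                          # expected: (bs, n_atoms, embed_dim)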


class AtomTransformer(nn.Module):
    """Applies multiple blocks of AtomAttentionPairBias and ConditionedTransitionBlock."""
    def __init__(
        self,
        embed_dim: int,
        num_blocks: int,
        num_heads: int = 8,
        dropout=0.0,
        n_queries: int = 32,
        n_keys: int = 128,
        c_atom: int = 128,
        c_pair: int = 16,
        device=None,
        dtype=None,
    ):
"""Initialize the AtomTransformer module.
Args:
embed_dim:
Total dimension of the model.
num_blocks:
Number of blocks.
num_heads:
Number of parallel attention heads. Note that embed_dim will be split across num_heads
(i.e. each head will have dimension embed_dim // num_heads).
dropout:
Dropout probability on attn_output_weights. Default: 0.0 (no dropout).
n_queries:
The size of the atom window. Defaults to 32.
n_keys:
Number of atoms each atom attends to in local sequence space. Defaults to 128.
c_atom:
The number of channels for the atom representation. Defaults to 128.
c_pair:
The number of channels for the pair representation. Defaults to 16.
"""
        super().__init__()
        self.embed_dim = embed_dim
        self.num_blocks = num_blocks
        self.num_heads = num_heads
        self.dropout = dropout
        self.n_queries = n_queries
        self.n_keys = n_keys
        self.c_atom = c_atom
        self.c_pair = c_pair
        self.device = device
        self.dtype = dtype

        self.attention_blocks = nn.ModuleList(
            [AtomAttentionPairBias(embed_dim, num_heads, dropout, n_queries, n_keys, c_atom, c_pair, device, dtype)
             for _ in range(num_blocks)]
        )
        self.conditioned_transition_blocks = nn.ModuleList(
            [ConditionedTransitionBlock(embed_dim) for _ in range(num_blocks)]
        )

    def forward(self, atom_single_repr, atom_single_proj, atom_pair_repr, mask=None):
        """Forward pass of the AtomTransformer module. Algorithm 23 in AlphaFold3 supplement."""
        for i in range(self.num_blocks):
            b = self.attention_blocks[i](atom_single_repr, atom_single_proj, atom_pair_repr, mask)
            atom_single_repr = b + self.conditioned_transition_blocks[i](atom_single_repr, b)
        return atom_single_repr
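
# Illustrative usage sketch, not part of this commit: stacking num_blocks pairs of
# AtomAttentionPairBias and ConditionedTransitionBlock as in Algorithm 23. Shapes follow
# the same assumptions as the sketch under AtomAttentionPairBias (embed_dim = c_atom = 128,
# full pair representation); the collapsed test fixture may differ.
def _example_atom_transformer_usage():
    bs, n_atoms, embed_dim = 2, 256, 128
    transformer = AtomTransformer(embed_dim=embed_dim, num_blocks=3, num_heads=8)
    atom_single_repr = torch.randn(bs, n_atoms, embed_dim)    # per-atom single representation
    atom_single_proj = torch.randn(bs, n_atoms, embed_dim)    # conditioning signal
    atom_pair_repr = torch.randn(bs, n_atoms, n_atoms, 16)    # assumed full pair representation, c_pair=16
    return transformer(atom_single_repr, atom_single_proj, atom_pair_repr).shape  # (bs, n_atoms, embed_dim)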


class AtomAttentionEncoder(nn.Module):
    pass

15 changes: 3 additions & 12 deletions tests/test_atom_attention.py
@@ -1,7 +1,7 @@
import unittest
import torch
import torch.nn as nn
from src.models.components.atom_attention import AttentionPairBias
from src.models.components.atom_attention import AtomAttentionPairBias


class TestAttentionPairBias(unittest.TestCase):
@@ -20,25 +20,16 @@ def setUp(self):

    def test_module_instantiation(self):
        """Test instantiation of the module with default parameters."""
        module = AttentionPairBias(embed_dim=self.embed_dim)
        module = AtomAttentionPairBias(embed_dim=self.embed_dim)
        self.assertIsInstance(module, nn.Module)

    def test_forward_output_shape(self):
        """Test the forward function output shape."""
        module = AttentionPairBias(embed_dim=self.embed_dim, num_heads=self.num_heads)
        module = AtomAttentionPairBias(embed_dim=self.embed_dim, num_heads=self.num_heads)
        output = module(self.atom_single_repr, self.atom_single_proj, self.atom_pair_repr)
        expected_shape = (self.batch_size, self.n_atoms, self.embed_dim)
        self.assertEqual(output.shape, expected_shape)

    def test_parameter_effects(self):
        """Test effects of different parameter settings."""
        # Test without bias in projections
        module_no_bias = AttentionPairBias(embed_dim=self.embed_dim, bias=False)
        output_no_bias = module_no_bias(self.atom_single_repr, self.atom_single_proj, self.atom_pair_repr)

        # Just check if it runs for now, since we are not setting exact expected outcomes
        self.assertEqual(output_no_bias.shape, (self.batch_size, self.n_atoms, self.embed_dim))


# Run the tests
if __name__ == '__main__':
