diff --git a/src/models/components/atom_attention.py b/src/models/components/atom_attention.py
index fc5c466..3c472d8 100644
--- a/src/models/components/atom_attention.py
+++ b/src/models/components/atom_attention.py
@@ -4,11 +4,13 @@
 subset of the nearby 128 atoms (nearby in the sequence space). This gives the network the capacity
 to learn general rules about local atom constellations, independently of the coarse-grained
 tokenization where each standard residue is represented with a single token only."""
+
 import torch
 from torch import nn
 import numpy as np
 from torch.nn import functional as F
 from src.models.components.primitives import AdaLN, Linear
+from src.models.components.transition import ConditionedTransitionBlock
 from src.utils.tensor_utils import partition_tensor


@@ -98,7 +100,7 @@ def extract_local_biases(bias_tensor, partition_increment=32, partition_length=1
     return output_tensor


-class AttentionPairBias(nn.Module):
+class AtomAttentionPairBias(nn.Module):
     """Implements the sequence-local atom attention with pair bias.
     This is implemented separately to the attention module that performs full self-attention since
     sequence-local atom attention requires a memory-efficient implementation.
@@ -108,17 +110,13 @@ def __init__(
             self,
             embed_dim,
             num_heads=8,
-            dropout=0.0,
             bias=True,
-            add_bias_kv=False,
             add_zero_attn=False,
-            kdim=None,
-            vdim=None,
-            batch_first=False,
-            device=None,
-            dtype=None,
+            dropout=0.0,
             n_queries: int = 32,
             n_keys: int = 128,
             c_atom: int = 128,
             c_pair: int = 16,
+            device=None,
+            dtype=None,
     ):
         """Initialize the AttentionPairBias module.
         Args:
@@ -129,19 +127,6 @@ def __init__(
                 (i.e. each head will have dimension embed_dim // num_heads).
             dropout:
                 Dropout probability on attn_output_weights. Default: 0.0 (no dropout).
-            bias:
-                If specified, adds bias to input / output projection layers. Default: True.
-            add_bias_kv:
-                If specified, adds bias to the key and value sequences at dim=0. Default: False.
-            add_zero_attn:
-                If specified, adds a new batch of zeros to the key and value sequences at dim=1. Default: False.
-            kdim:
-                Total number of features for keys. Default: None (uses kdim=embed_dim).
-            vdim:
-                Total number of features for values. Default: None (uses vdim=embed_dim).
-            batch_first:
-                If True, then the input and output tensors are provided as (batch, seq, feature).
-                Default: False (seq, batch, feature).
             n_queries:
                 The size of the atom window. Defaults to 32.
             n_keys:
                 Number of atoms each atom attends to in local sequence space. Defaults to 128.
@@ -156,12 +141,6 @@ def __init__(
         self.embed_dim = embed_dim
         self.num_heads = num_heads
         self.dropout = dropout
-        self.bias = bias
-        self.add_bias_kv = add_bias_kv
-        self.add_zero_attn = add_zero_attn
-        self.kdim = kdim
-        self.vdim = vdim
-        self.batch_first = batch_first
         self.device = device
         self.dtype = dtype
         self.n_queries = n_queries
@@ -201,6 +180,7 @@ def forward(self, atom_single_repr, atom_single_proj, atom_pair_repr, mask=None)
                 tensor of shape (bs, n_atoms)
         Returns:
             tensor of shape (bs, n_atoms, embed_dim) after sequence-local atom attention
+            TODO: implement masking
         """
         # Input projections
         a = self.ada_ln(atom_single_repr, atom_single_proj)  # AdaLN(a, s)
@@ -241,6 +221,70 @@ def forward(self, atom_single_repr, atom_single_proj, atom_pair_repr, mask=None)
         return output


+class AtomTransformer(nn.Module):
+    """AtomTransformer that applies multiple blocks of AtomAttentionPairBias and ConditionedTransitionBlock."""
+    def __init__(
+            self,
+            embed_dim: int,
+            num_blocks: int,
+            num_heads: int = 8,
+            dropout=0.0,
+            n_queries: int = 32,
+            n_keys: int = 128,
+            c_atom: int = 128,
+            c_pair: int = 16,
+            device=None,
+            dtype=None,
+    ):
+        """Initialize the AtomTransformer module.
+        Args:
+            embed_dim:
+                Total dimension of the model.
+            num_blocks:
+                Number of blocks.
+            num_heads:
+                Number of parallel attention heads. Note that embed_dim will be split across num_heads
+                (i.e. each head will have dimension embed_dim // num_heads).
+            dropout:
+                Dropout probability on attn_output_weights. Default: 0.0 (no dropout).
+            n_queries:
+                The size of the atom window. Defaults to 32.
+            n_keys:
+                Number of atoms each atom attends to in local sequence space. Defaults to 128.
+            c_atom:
+                The number of channels for the atom representation. Defaults to 128.
+            c_pair:
+                The number of channels for the pair representation. Defaults to 16.
+
+        """
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_blocks = num_blocks
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.n_queries = n_queries
+        self.n_keys = n_keys
+        self.c_atom = c_atom
+        self.c_pair = c_pair
+        self.device = device
+        self.dtype = dtype
+
+        self.attention_blocks = nn.ModuleList(
+            # Keyword arguments are used here so that dropout, n_queries, etc. do not get
+            # silently mapped onto the bias / add_zero_attn positional slots of AtomAttentionPairBias.
+            [AtomAttentionPairBias(embed_dim, num_heads, dropout=dropout, n_queries=n_queries, n_keys=n_keys,
+                                   c_atom=c_atom, c_pair=c_pair, device=device, dtype=dtype)
+             for _ in range(num_blocks)]
+        )
+        self.conditioned_transition_blocks = nn.ModuleList(
+            [ConditionedTransitionBlock(embed_dim) for _ in range(num_blocks)]
+        )
+
+    def forward(self, atom_single_repr, atom_single_proj, atom_pair_repr, mask=None):
+        """Forward pass of the AtomTransformer module.
+        Algorithm 23 in AlphaFold3 supplement."""
+        for i in range(self.num_blocks):
+            b = self.attention_blocks[i](atom_single_repr, atom_single_proj, atom_pair_repr, mask)
+            # Algorithm 23: a = b + ConditionedTransitionBlock(a, s), where s is the conditioning signal
+            atom_single_repr = b + self.conditioned_transition_blocks[i](atom_single_repr, atom_single_proj)
+        return atom_single_repr
+
+
 class AtomAttentionEncoder(nn.Module):
     pass
diff --git a/tests/test_atom_attention.py b/tests/test_atom_attention.py
index 59d5ae3..c2e570f 100644
--- a/tests/test_atom_attention.py
+++ b/tests/test_atom_attention.py
@@ -1,7 +1,7 @@
 import unittest
 import torch
 import torch.nn as nn
-from src.models.components.atom_attention import AttentionPairBias
+from src.models.components.atom_attention import AtomAttentionPairBias


 class TestAttentionPairBias(unittest.TestCase):
@@ -20,25 +20,16 @@ def setUp(self):

     def test_module_instantiation(self):
         """Test instantiation of the module with default parameters."""
-        module = AttentionPairBias(embed_dim=self.embed_dim)
+        module = AtomAttentionPairBias(embed_dim=self.embed_dim)
         self.assertIsInstance(module, nn.Module)

     def test_forward_output_shape(self):
         """Test the forward function output shape."""
-        module = AttentionPairBias(embed_dim=self.embed_dim, num_heads=self.num_heads)
+        module = AtomAttentionPairBias(embed_dim=self.embed_dim, num_heads=self.num_heads)
         output = module(self.atom_single_repr, self.atom_single_proj, self.atom_pair_repr)
         expected_shape = (self.batch_size, self.n_atoms, self.embed_dim)
         self.assertEqual(output.shape, expected_shape)

-    def test_parameter_effects(self):
-        """Test effects of different parameter settings."""
-        # Test without bias in projections
-        module_no_bias = AttentionPairBias(embed_dim=self.embed_dim, bias=False)
-        output_no_bias = module_no_bias(self.atom_single_repr, self.atom_single_proj, self.atom_pair_repr)
-
-        # Just check if it runs for now, since we are not setting exact expected outcomes
-        self.assertEqual(output_no_bias.shape, (self.batch_size, self.n_atoms, self.embed_dim))
-

 # Run the tests
 if __name__ == '__main__':
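
Usage sketch (not part of the diff): the snippet below shows how the new AtomTransformer is expected to be driven, using only the constructor and forward signatures added above. The tensor shapes, the batch size, the number of blocks, and the choice of 256 atoms (a multiple of the 32-query / 128-key window sizes) are illustrative assumptions rather than values taken from the repository; in particular, the full (bs, n_atoms, n_atoms, c_pair) pair layout is assumed from the extract_local_biases helper.

# Illustrative only: all sizes and shapes below are assumptions, not repository values.
import torch
from src.models.components.atom_attention import AtomTransformer

batch_size, n_atoms = 1, 256        # assumed; 256 divides evenly into 32-atom query windows
c_atom, c_pair = 128, 16            # channel sizes matching the defaults above

atom_single_repr = torch.randn(batch_size, n_atoms, c_atom)          # a, per-atom features
atom_single_proj = torch.randn(batch_size, n_atoms, c_atom)          # s, conditioning signal
atom_pair_repr = torch.randn(batch_size, n_atoms, n_atoms, c_pair)   # assumed full pairwise bias features

model = AtomTransformer(embed_dim=c_atom, num_blocks=3, num_heads=8, c_atom=c_atom, c_pair=c_pair)
out = model(atom_single_repr, atom_single_proj, atom_pair_repr)
assert out.shape == (batch_size, n_atoms, c_atom)                    # documented output shape

If useful, the same setup could back a forward-shape unit test for AtomTransformer analogous to test_forward_output_shape above.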