Implemented InputEmbedder and tests
ardagoreci committed Jul 7, 2024
1 parent 5640d32 commit 6ad5ab3
Showing 4 changed files with 135 additions and 16 deletions.
10 changes: 4 additions & 6 deletions src/models/components/atom_attention.py
@@ -61,8 +61,6 @@ def extract_locals(
The value to use for padding.
Returns:
A tensor of shape [batch_size, N_atoms // partition_increment, partition_length, channels].
TODO: my representation of atoms should not be a full pairwise tensor, should always remain local
"""
batch_size, n_atoms, _, channels = bias_tensor.shape
# Pad bias tensor column-wise by n_keys // 2 - n_queries // 2 on each side
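For readers skimming the diff, the local-window extraction documented in this hunk can be pictured with a small standalone sketch: pad the key axis by n_keys // 2 - n_queries // 2 on each side, then take one n_queries x n_keys block per partition. This is an illustration under assumed shapes, not the repository's extract_locals; the helper name extract_locals_sketch and the 5-D output layout are assumptions.

import torch
import torch.nn.functional as F

def extract_locals_sketch(bias_tensor: torch.Tensor,
                          n_queries: int = 32,
                          n_keys: int = 128,
                          pad_value: float = 0.0) -> torch.Tensor:
    """Illustrative only: slice a [bs, n_atoms, n_atoms, c] pairwise tensor into
    local blocks of shape [bs, n_atoms // n_queries, n_queries, n_keys, c]."""
    bs, n_atoms, _, c = bias_tensor.shape
    pad = n_keys // 2 - n_queries // 2
    # Pad the key (column) dimension on both sides, as in the comment above.
    padded = F.pad(bias_tensor, (0, 0, pad, pad), value=pad_value)
    blocks = []
    for i in range(0, n_atoms, n_queries):
        # Queries i : i + n_queries see the n_keys columns centred on their block.
        blocks.append(padded[:, i:i + n_queries, i:i + n_keys, :])
    return torch.stack(blocks, dim=1)

local = extract_locals_sketch(torch.randn(1, 64, 64, 16), n_queries=32, n_keys=128)
assert local.shape == (1, 2, 32, 128, 16)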
@@ -438,7 +436,7 @@ def map_token_pairs_to_local_atom_pairs(
pair embeddings. For each atom pair (l, m), the corresponding token pair's embeddings are extracted."""
bs, n_atoms = tok_idx.shape
_, n_tokens, _, c_pair = token_pairs.shape
# tok_idx = tok_idx.int() # convert to int for indexing
# tok_idx = tok_idx.long() # convert to int for indexing

# Expand tok_idx for efficient gather operation
tok_idx_l = tok_idx.unsqueeze(2).expand(-1, -1, n_atoms).unsqueeze(-1)
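The lookup being prepared here (one token-pair embedding per atom pair) can also be written with plain advanced indexing. The following is a functionally similar sketch, not the repository's expand-and-gather implementation; the helper name and shapes are assumptions.

import torch

def gather_token_pairs_for_atom_pairs(token_pairs: torch.Tensor, tok_idx: torch.Tensor) -> torch.Tensor:
    """Illustrative only: for every atom pair (l, m), pick the pair embedding of
    the tokens those atoms belong to.
    token_pairs: [bs, n_tokens, n_tokens, c_pair], tok_idx: [bs, n_atoms] (long)."""
    bs, n_atoms = tok_idx.shape
    batch = torch.arange(bs)[:, None, None]   # [bs, 1, 1]
    idx_l = tok_idx[:, :, None]               # token of atom l -> [bs, n_atoms, 1]
    idx_m = tok_idx[:, None, :]               # token of atom m -> [bs, 1, n_atoms]
    # Advanced indexing broadcasts the three index tensors to [bs, n_atoms, n_atoms]
    # and keeps the trailing channel dimension.
    return token_pairs[batch, idx_l, idx_m]

# Atoms 0-1 belong to token 0, atoms 2-3 to token 1.
pairs = gather_token_pairs_for_atom_pairs(torch.randn(1, 2, 2, 8), torch.tensor([[0, 0, 1, 1]]))
assert pairs.shape == (1, 4, 4, 8)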
@@ -476,7 +474,7 @@ def aggregate_atom_to_token(
masked_atom -> legitimate_token
"""
bs, n_atoms, c_atom = atom_representation.shape
# tok_idx = tok_idx.int() # convert to int for indexing
# tok_idx = tok_idx.long() # convert to int for indexing

# Initialize the token representation tensor with zeros
token_representation = torch.zeros((bs, n_tokens, c_atom),
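As a companion to this hunk, atom-to-token aggregation can be sketched with scatter_add_: sum each atom's features into its parent token slot, then normalise by the number of atoms per token. Only the zero-initialised output tensor is visible in the diff, so the mean pooling and the empty-token handling below are assumptions.

import torch

def aggregate_atom_to_token_sketch(atom_repr: torch.Tensor, tok_idx: torch.Tensor, n_tokens: int) -> torch.Tensor:
    """Illustrative only: mean-pool [bs, n_atoms, c_atom] atom features into
    [bs, n_tokens, c_atom] token features using tok_idx ([bs, n_atoms], long)."""
    bs, n_atoms, c_atom = atom_repr.shape
    token_repr = torch.zeros(bs, n_tokens, c_atom, dtype=atom_repr.dtype)
    # Sum atom features into their token slots.
    token_repr.scatter_add_(1, tok_idx[..., None].expand(-1, -1, c_atom), atom_repr)
    # Count atoms per token; clamp avoids dividing by zero for tokens that own no atoms.
    counts = torch.zeros(bs, n_tokens, dtype=atom_repr.dtype)
    counts.scatter_add_(1, tok_idx, torch.ones(bs, n_atoms, dtype=atom_repr.dtype))
    return token_repr / counts.clamp(min=1)[..., None]

token_repr = aggregate_atom_to_token_sketch(torch.randn(1, 4, 16), torch.tensor([[0, 0, 1, 1]]), n_tokens=2)
assert token_repr.shape == (1, 2, 16)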
@@ -622,7 +620,7 @@ def init_pair_repr(
atom pair representations are large and can be checkpointed to reduce memory usage.
Args:
features:
Dictionary of x features.
Dictionary of input features.
atom_single:
[bs, n_atoms, c_atom] The single atom representation from init_single_repr
z_trunk:
@@ -679,7 +677,7 @@ def init_single_repr(
atom single representations are large and can be checkpointed to reduce memory usage.
Args:
features:
Dictionary of x features.
Dictionary of input features.
s_trunk:
[*, n_tokens, c_token] the token representation from the trunk
noisy_pos:
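Both init_pair_repr and init_single_repr note that these activations are large and can be checkpointed to reduce memory. A minimal, generic illustration of activation checkpointing with torch.utils.checkpoint follows; build_pair is a stand-in function, not the repository's code.

import torch
from torch.utils.checkpoint import checkpoint

def build_pair(atom_single: torch.Tensor) -> torch.Tensor:
    # Stand-in for an expensive pairwise computation: an outer sum over atoms.
    return atom_single[:, :, None, :] + atom_single[:, None, :, :]

atom_single = torch.randn(1, 64, 16, requires_grad=True)
# Intermediates inside build_pair are not stored; they are recomputed during backward.
pair = checkpoint(build_pair, atom_single, use_reentrant=False)
pair.sum().backward()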
2 changes: 1 addition & 1 deletion src/models/components/relative_position_encoding.py
@@ -38,7 +38,7 @@ def forward(self, features: Dict[str, torch.Tensor], mask=None) -> torch.Tensor:
"""Computes relative position encoding. AlphaFold3 Supplement Algorithm 3.
Args:
features:
x feature dictionary containing:
input feature dictionary containing:
"residue_index":
[*, n_tokens] Residue number in the token's original input chain.
"token_index":
96 changes: 88 additions & 8 deletions src/models/embedders.py
@@ -93,7 +93,7 @@ def forward(
mask:
[*, N_atoms] mask indicating which atoms are valid (non-padding).
Returns:
[*, N_tokens, c_token] Embedding of the x features.
[*, N_tokens, c_token] Embedding of the input features.
"""
# Encode the input features
output = self.encoder(features=features, mask=mask, n_tokens=n_tokens)
@@ -109,13 +109,95 @@ def __init__(
c_atom: int = 128,
c_atompair: int = 16,
c_trunk_pair: int = 128,

):
super(InputEmbedder, self).__init__()
pass

def forward(self):
pass
# InputFeatureEmbedder for the s_inputs representation
self.input_feature_embedder = InputFeatureEmbedder(
c_token=c_token,
c_atom=c_atom,
c_atompair=c_atompair,
c_trunk_pair=c_trunk_pair
)

# Projections
self.linear_single = LinearNoBias(c_token, c_token)
self.linear_proj_i = LinearNoBias(c_token, c_trunk_pair)
self.linear_proj_j = LinearNoBias(c_token, c_trunk_pair)
self.linear_bonds = LinearNoBias(1, c_trunk_pair)

# Relative position encoding
self.relpos = RelativePositionEncoding(c_pair=c_trunk_pair)

def forward(
self,
features: Dict[str, Tensor],
atom_mask: Optional[Tensor] = None,
token_mask: Optional[Tensor] = None,
inplace_safe: bool = False
) -> Tuple[Tensor, Tensor, Tensor]:
"""
Args:
features:
Dictionary containing the following input features:
"ref_pos":
[*, N_atoms, 3] atom positions in the reference conformers, with
a random rotation and translation applied. Atom positions in Angstroms.
"ref_charge":
[*, N_atoms] Charge for each atom in the reference conformer.
"ref_mask":
[*, N_atoms] Mask indicating which atom slots are used in the reference
conformer.
"ref_element":
[*, N_atoms, 128] One-hot encoding of the element atomic number for each atom
in the reference conformer, up to atomic number 128.
"ref_atom_name_chars":
[*, N_atoms, 4, 64] One-hot encoding of the unique atom names in the reference
conformer. Each character is encoded as ord(c) - 32, and names are padded to
length 4.
"ref_space_uid":
[*, N_atoms] Numerical encoding of the chain id and residue index associated
with this reference conformer. Each (chain id, residue index) tuple is assigned
an integer on first appearance.
"atom_to_token":
[*, N_atoms] Token index for each atom in the flat atom representation.
"token_bonds":
[*, N_tokens, N_tokens] feature indicating which tokens are bonded to each other.
Only includes polymer-ligand and ligand-ligand bonds
atom_mask:
[*, n_atoms] mask indicating which atoms are valid (non-padding).
token_mask:
[*, n_tokens] mask indicating which tokens are valid (non-padding).
inplace_safe:
whether to use inplace operations
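Returns:
s_inputs:
[*, n_tokens, c_token] per-token embedding of the input features
s_init:
[*, n_tokens, c_token] initial single representation
z_init:
[*, n_tokens, n_tokens, c_trunk_pair] initial pair representation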
"""
*_, n_tokens, _ = features["token_bonds"].shape

# Single representation with input feature embedder
s_inputs = self.input_feature_embedder(features, n_tokens=n_tokens, mask=atom_mask)

# Projections
s_init = self.linear_single(s_inputs)
z_init = add(
self.linear_proj_i(s_inputs[..., None, :]),
self.linear_proj_j(s_inputs[..., None, :, :]),
inplace=inplace_safe
) # (*, n_tokens, n_tokens, c_trunk_pair)

# Add relative position encoding
z_init = add(
z_init,
self.relpos(features, mask=token_mask),
inplace=inplace_safe
)

# Add token bond information
z_init = add(
z_init,
self.linear_bonds(features["token_bonds"][..., None]),
inplace=inplace_safe
)
return s_inputs, s_init, z_init
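For clarity, the broadcasting used to build z_init above works as follows: the row term gains a singleton column axis, the column term gains a singleton row axis, and their sum is a full pair tensor whose entry (i, j) combines token i with token j. A self-contained illustration, where a and b stand in for the projected s_inputs:

import torch

bs, n_tokens, c_pair = 1, 5, 4
a = torch.randn(bs, n_tokens, c_pair)   # stands in for the linear_proj_i term
b = torch.randn(bs, n_tokens, c_pair)   # stands in for the linear_proj_j term
z = a[..., None, :] + b[..., None, :, :]              # [bs, n_tokens, n_tokens, c_pair]
assert z.shape == (bs, n_tokens, n_tokens, c_pair)
assert torch.allclose(z[0, 2, 3], a[0, 2] + b[0, 3])  # entry (i, j) = row i + column j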


class TemplateEmbedder(nn.Module):
@@ -290,9 +372,7 @@ def __init__(
num_heads=num_heads,
dropout=dropout,
n_queries=n_queries,
n_keys=n_keys,
device=device,
dtype=dtype
n_keys=n_keys
)
self.linear_s_init = LinearNoBias(c_token, c_token)
self.linear_z_col = LinearNoBias(c_token, c_trunk_pair)
43 changes: 42 additions & 1 deletion tests/test_embedders.py
@@ -1,6 +1,6 @@
import unittest
import torch
from src.models.embedders import TemplateEmbedder
from src.models.embedders import TemplateEmbedder, InputEmbedder


class TestTemplateEmbedder(unittest.TestCase):
@@ -26,3 +26,44 @@ def test_forward(self):
pair_mask = torch.randint(0, 2, (self.batch_size, self.n_tokens, self.n_tokens))
embeddings = self.module(features, z, pair_mask)
self.assertEqual(embeddings.shape, (self.batch_size, self.n_tokens, self.n_tokens, self.c_template))


class TestInputEmbedder(unittest.TestCase):
def setUp(self):
self.batch_size = 1
self.n_atoms = 64
self.n_tokens = 64
self.c_token = 384
self.c_atom = 128
self.c_atompair = 16
self.c_z = 128

self.module = InputEmbedder(
c_token=self.c_token,
c_atom=self.c_atom,
c_atompair=self.c_atompair,
c_trunk_pair=self.c_z
)

def test_forward(self):
features = {
'ref_pos': torch.randn(self.batch_size, self.n_atoms, 3),
'ref_charge': torch.randn(self.batch_size, self.n_atoms),
'ref_mask': torch.ones(self.batch_size, self.n_atoms),
'ref_element': torch.randn(self.batch_size, self.n_atoms, 128),
'ref_atom_name_chars': torch.randint(0, 2, (self.batch_size, self.n_atoms, 4, 64)),
'ref_space_uid': torch.ones((self.batch_size, self.n_atoms)).float(),
"residue_index": torch.randint(0, self.n_tokens, (self.batch_size, self.n_tokens)),
"token_index": torch.randint(0, self.n_tokens, (self.batch_size, self.n_tokens)),
"asym_id": torch.randint(0, self.n_tokens, (self.batch_size, self.n_tokens)),
"entity_id": torch.randint(0, self.n_tokens, (self.batch_size, self.n_tokens)),
"sym_id": torch.randint(0, self.n_tokens, (self.batch_size, self.n_tokens)),
'atom_to_token': torch.randint(0, self.n_tokens, (self.batch_size, self.n_atoms)),
'token_bonds': torch.randint(0, 2, (self.batch_size, self.n_tokens, self.n_tokens)).float()
}
token_mask = torch.randint(0, 2, (self.batch_size, self.n_tokens))
atom_mask = torch.randint(0, 2, (self.batch_size, self.n_atoms))
s_inputs, s_init, z_init = self.module(features, atom_mask, token_mask)
self.assertEqual(s_inputs.shape, (self.batch_size, self.n_tokens, self.c_token))
self.assertEqual(s_init.shape, (self.batch_size, self.n_tokens, self.c_token))
self.assertEqual(z_init.shape, (self.batch_size, self.n_tokens, self.n_tokens, self.c_z))
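If the project is laid out as these imports suggest (a src package importable from the repository root), the new test class can presumably be run from the top-level directory with something like python -m unittest tests.test_embedders.TestInputEmbedder -v; the exact invocation depends on the repository's test setup, for example whether pytest is used instead.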
