Fixed MSAModule, implemented trunk iteration of AlphaFold3
ardagoreci committed Jul 8, 2024
1 parent 6ad5ab3 commit ca086e3
Showing 4 changed files with 223 additions and 39 deletions.
2 changes: 1 addition & 1 deletion src/models/diffusion_module.py
@@ -268,6 +268,6 @@ def forward(
mask=atom_mask, # (bs, n_atoms)
) # (bs, n_atoms, 3)

# Rescale updates to positions and combine with x positions
# Rescale updates to positions and combine with input positions
output_pos = self.rescale_with_updates(atom_pos_updates, noisy_atoms, timesteps)
return output_pos
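
For context, rescale_with_updates combines the raw position updates with the noisy input coordinates; its body is not part of this diff. The sketch below is an assumption modeled on the EDM-style skip scaling described in the AlphaFold3 supplement, with sigma_data as an assumed data-scale constant and the (bs, 1) timesteps layout inferred from the surrounding shape comments.

import torch
from torch import Tensor

def rescale_with_updates_sketch(
    r_updates: Tensor,    # (bs, n_atoms, 3) network output
    noisy_atoms: Tensor,  # (bs, n_atoms, 3) noised input positions
    timesteps: Tensor,    # (bs, 1) noise level per sample
    sigma_data: float = 16.0,  # assumed data-scale constant
) -> Tensor:
    """Skip-scaled combination of noisy positions and predicted updates."""
    skip_scale = sigma_data ** 2 / (sigma_data ** 2 + timesteps ** 2)
    out_scale = sigma_data * timesteps / torch.sqrt(sigma_data ** 2 + timesteps ** 2)
    # Broadcast the per-sample scales over atoms and coordinates
    return skip_scale[..., None] * noisy_atoms + out_scale[..., None] * r_updates
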
155 changes: 127 additions & 28 deletions src/models/model.py
@@ -1,44 +1,143 @@
"""AlphaFold3 model implementation."""


class AlphaFold3:
import torch
from torch import nn
from torch import Tensor
from src.models.embedders import InputEmbedder, TemplateEmbedder
from src.models.pairformer import PairformerStack
from src.models.msa_module import MSAModule
from src.models.diffusion_module import DiffusionModule
from src.models.components.primitives import LinearNoBias, LayerNorm
from src.utils.tensor_utils import add
from typing import Tuple, Dict


class AlphaFold3(nn.Module):
def __init__(self, config):
# InputFeatureEmbedder
# RelativePositionEncoder
# TemplateEmbedder
# MsaModule
# PairformerStack
# DiffusionModule
# ConfidenceHead
pass

def embed_templates(self, batch, feats, z, pair_mask, template_dim, inplace_safe):
pass

def iteration(self, features, prevs, _recycle=True):
# initialize output dictionary
super(AlphaFold3, self).__init__()
self.globals = config.globals
self.config = config.model

self.input_embedder = InputEmbedder(
**self.config["input_embedder"]
)

self.template_embedder = TemplateEmbedder(
**self.config["template_embedder"]
)

self.msa_module = MSAModule(
**self.config["msa_module"]
)

self.pairformer_stack = PairformerStack(
**self.config["pairformer_stack"]
)

self.diffusion_module = DiffusionModule(
**self.config["diffusion_module"]
)

# Projections during recycling
c_token = self.config.input_embedder.c_token
c_trunk_pair = self.config.input_embedder.c_trunk_pair
self.recycling_s_proj = nn.Sequential(
LayerNorm(c_token),
LinearNoBias(c_token, c_token)
)
self.recycling_z_proj = nn.Sequential(
LayerNorm(c_trunk_pair),
LinearNoBias(c_trunk_pair, c_trunk_pair)
)

def run_trunk(
self,
feats: Dict[str, Tensor],
s_inputs: Tensor,
s_init: Tensor,
z_init: Tensor,
s_prev: Tensor,
z_prev: Tensor
) -> Tuple[Tensor, Tensor]:
"""
Run a single recycling iteration.
Args:
feats:
dictionary containing the AlphaFold3 features and a
few additional keys such as "msa_mask" and "token_mask"
s_inputs:
[*, N_token, C_token] single inputs embedding from InputFeatureEmbedder
s_init:
[*, N_token, C_token] initial token representation
z_init:
[*, N_token, N_token, C_z] initial pair representation
s_prev:
[*, N_token, C_token] previous token representation from recycling.
If this is the first iteration, it should be zeros.
z_prev:
[*, N_token, N_token, C_z] previous pair representation from recycling.
If this is the first iteration, it should be zeros.
"""

# This needs to be done manually for DeepSpeed's sake
dtype = next(self.parameters()).dtype
for k in feats:
if feats[k].dtype == torch.float32:
feats[k] = feats[k].to(dtype=dtype)

# dtype cast the features
# Controls whether the model uses in-place operations throughout
# The dual condition accounts for activation checkpoints
inplace_safe = not (self.training or torch.is_grad_enabled())

# Grab some data about the input
# Prep masks
token_mask = feats["token_mask"]
pair_mask = token_mask[..., None] * token_mask[..., None, :]

# Controls whether the model uses in-place operations throughout
# Embed the input features
z = add(z_init, self.recycling_z_proj(z_prev), inplace=inplace_safe)
s = add(s_init, self.recycling_s_proj(s_prev), inplace=inplace_safe)

# Unpack the recycling embeddings. Removing them from the list allows
# them to be freed further down in this function, saving memory

# Initialize the recycling embeddings, if needs be
del s_prev, z_prev, s_init, z_init

# Embed the templates
z = add(z,
self.template_embedder(
feats,
z,
pair_mask=pair_mask,
chunk_size=self.globals.chunk_size,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_lma=self.globals.use_lma,
inplace_safe=inplace_safe
),
inplace=inplace_safe
)

# Process the MSA
z = add(z,
self.msa_module(
feats=feats,
z=z,
s_inputs=s_inputs,
z_mask=pair_mask,
chunk_size=self.globals.chunk_size,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_lma=self.globals.use_lma,
inplace_safe=inplace_safe
),
inplace=inplace_safe
)

# Run the pairformer stack

# outputs, s_prev, z_prev

pass
s, z = self.pairformer_stack(
s, z,
single_mask=token_mask,
pair_mask=pair_mask,
chunk_size=self.globals.chunk_size,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_lma=self.globals.use_lma,
inplace_safe=inplace_safe
)
return s, z

def _disable_activation_checkpointing(self):
pass
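The commit implements the per-cycle run_trunk, but the outer recycling loop that produces s_prev and z_prev is not part of this diff. Below is a minimal sketch of how such a loop could drive run_trunk; the zero initialization of the recycled tensors, the number of cycles, and the grad handling for intermediate cycles are assumptions modeled on OpenFold-style recycling rather than code from this repository.

import torch

def run_recycling_sketch(model, feats, s_inputs, s_init, z_init, n_cycles: int = 4):
    """Hypothetical recycling driver around AlphaFold3.run_trunk."""
    s_prev = torch.zeros_like(s_init)
    z_prev = torch.zeros_like(z_init)
    for cycle in range(n_cycles):
        is_final = cycle == n_cycles - 1
        # Only the final cycle needs gradients during training
        with torch.set_grad_enabled(torch.is_grad_enabled() and is_final):
            s, z = model.run_trunk(
                feats=feats,
                s_inputs=s_inputs,
                s_init=s_init,
                z_init=z_init,
                s_prev=s_prev,
                z_prev=z_prev,
            )
        s_prev, z_prev = s, z
    return s, z
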
80 changes: 71 additions & 9 deletions src/models/msa_module.py
@@ -27,7 +27,9 @@
from src.models.components.dropout import DropoutRowwise
from src.utils.tensor_utils import add
from functools import partial
from src.utils.checkpointing import checkpoint_blocks
from typing import Dict
from src.utils.checkpointing import checkpoint_blocks, get_checkpoint_fn
checkpoint = get_checkpoint_fn()


class MSAPairWeightedAveraging(nn.Module):
@@ -299,6 +301,7 @@ def __init__(
self,
no_blocks: int = 4,
c_msa: int = 64,
c_token: int = 384,
c_z: int = 128,
c_hidden: int = 32,
no_heads: int = 8,
@@ -319,6 +322,8 @@
number of MSAModuleBlocks
c_msa:
MSA representation dim
c_token:
Single representation dim
c_z:
pair representation dim
c_hidden:
@@ -359,13 +364,15 @@ def __init__(
inf=inf)
for _ in range(no_blocks)
])

# MSA featurization
self.linear_msa_feat = LinearNoBias(34, c_msa)  # 32 (one-hot msa) + 1 (has_deletion) + 1 (deletion_value)
self.proj_s_inputs = LinearNoBias(c_token, c_msa, init='final')
self.blocks_per_ckpt = blocks_per_ckpt
self.clear_cache_between_blocks = clear_cache_between_blocks

def _prep_blocks(
self,
m: Tensor,
z: Tensor,
msa_mask: Optional[Tensor] = None,
z_mask: Optional[Tensor] = None,
chunk_size: Optional[int] = None,
@@ -396,22 +403,74 @@ def block_with_cache_clear(block, *args, **kwargs):

return blocks

def init_msa_repr(
self,
feats: Dict[str, Tensor],
s_inputs: Tensor,
msa_mask: Optional[Tensor] = None,
inplace_safe: bool = False,
) -> Tensor:
"""Initializes the MSA representation."""
msa_feats = torch.cat([
feats["msa"],
feats["has_deletion"][..., None],
feats["deletion_value"][..., None]],
dim=-1)
m = self.linear_msa_feat(msa_feats)
m = add(m,
self.proj_s_inputs(s_inputs[..., None, :, :]),  # broadcast single inputs over the N_msa dimension
inplace=inplace_safe)
if msa_mask is not None:
m = m * msa_mask[..., None]
return m

def forward(
self,
m: Tensor,
feats: Dict[str, Tensor],
z: Tensor,
msa_mask: Optional[Tensor] = None,
s_inputs: Tensor,
z_mask: Optional[Tensor] = None,
chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
) -> Tensor:
# TODO: combine the input single representation here
# this module should also receive the features dict and embed MSA features
"""
Args:
feats:
Dictionary containing the MSA features, including:
"msa":
[*, N_msa, N_token, 32] One-hot encoding of the processed MSA, using the same classes
as restype.
"has_deletion":
[*, N_msa, N_token] Binary feature indicating if there is a deletion to the left of
each position in the MSA.
"deletion_value":
[*, N_msa, N_token] Raw deletion counts (the number of deletions to the left of each MSA
position) are transformed to [0, 1] using 2/π * arctan(d/3).
"msa_mask":
[*, N_msa, N_token] MSA mask
z:
[*, N_token, N_token, C_z] pair embeddings
s_inputs:
[*, N_token, C_token] single input embeddings
z_mask:
[*, N_token, N_token] pair mask
chunk_size:
inference-time subbatch size for the attention and triangle operations
use_deepspeed_evo_attention:
whether to use DeepSpeed's optimized attention kernels
use_lma:
whether to use low-memory attention. Mutually exclusive with
use_deepspeed_evo_attention.
inplace_safe:
whether to perform ops inplace
"""
# Prep MSA mask
msa_mask = feats["msa_mask"]

# Prep the blocks
blocks = self._prep_blocks(
m=m,
z=z,
msa_mask=msa_mask,
z_mask=z_mask,
chunk_size=chunk_size,
@@ -423,6 +482,9 @@
if not torch.is_grad_enabled():
blocks_per_ckpt = None

# Initialize the MSA embedding
m = checkpoint(self.init_msa_repr, feats, s_inputs, msa_mask, inplace_safe)

# Run with grad checkpointing
m, z = checkpoint_blocks(
blocks,
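Note that init_msa_repr expects feats["msa"] to already be one-hot over 32 classes and feats["deletion_value"] to already be squashed into [0, 1], which is why linear_msa_feat takes 34 input channels (32 + has_deletion + deletion_value). A small sketch of how raw MSA arrays could be turned into these features is below; the raw inputs (msa_indices, deletions) and the helper name are hypothetical, only the transforms themselves come from the docstring above.

import math
import torch
from torch import Tensor

def featurize_msa_sketch(msa_indices: Tensor, deletions: Tensor) -> dict:
    """Builds the three MSA features consumed by MSAModule.init_msa_repr.

    Args:
        msa_indices: [*, N_msa, N_token] integer class indices in [0, 32)
        deletions:   [*, N_msa, N_token] raw deletion counts to the left of each position
    """
    msa = torch.nn.functional.one_hot(msa_indices.long(), num_classes=32).float()
    has_deletion = (deletions > 0).float()
    # Squash raw counts into [0, 1]: 2/pi * arctan(d / 3)
    deletion_value = (2.0 / math.pi) * torch.atan(deletions.float() / 3.0)
    return {"msa": msa, "has_deletion": has_deletion, "deletion_value": deletion_value}

The separate checkpoint(self.init_msa_repr, ...) call in forward presumably keeps the initial embedding's activations out of memory during training, alongside the block-level checkpointing done by checkpoint_blocks.
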
25 changes: 24 additions & 1 deletion tests/test_msa_module.py
@@ -1,7 +1,7 @@
"""Tests for the MSAModule."""
import torch
import unittest
from src.models.msa_module import MSAPairWeightedAveraging, MSAModuleBlock
from src.models.msa_module import MSAPairWeightedAveraging, MSAModuleBlock, MSAModule


class TestMSAPairWeightedAveraging(unittest.TestCase):
@@ -46,3 +46,26 @@ def test_forward(self):
self.assertEqual(z_out.shape, (self.batch_size, self.n_tokens, self.n_tokens, self.c_z))


class TestMSAModule(unittest.TestCase):
def setUp(self):
self.batch_size = 1
self.n_tokens = 384
self.n_seq = 3
self.c_msa = 64
self.c_z = 128
self.c_token = 768
self.no_blocks = 2
self.module = MSAModule(c_msa=self.c_msa, c_token=self.c_token, c_z=self.c_z, no_blocks=self.no_blocks)

def test_forward(self):
s_inputs = torch.randn((self.batch_size, self.n_tokens, self.c_token))
z = torch.randn((self.batch_size, self.n_tokens, self.n_tokens, self.c_z))
z_mask = torch.randint(0, 2, (self.batch_size, self.n_tokens, self.n_tokens))
feats = {
"msa": torch.randn((self.batch_size, self.n_seq, self.n_tokens, 32)),
"has_deletion": torch.randn((self.batch_size, self.n_seq, self.n_tokens)),
"deletion_value": torch.randn((self.batch_size, self.n_seq, self.n_tokens)),
"msa_mask": torch.randint(0, 2, (self.batch_size, self.n_seq, self.n_tokens))
}
z_out = self.module(feats, z, s_inputs, z_mask)
self.assertEqual(z_out.shape, (self.batch_size, self.n_tokens, self.n_tokens, self.c_z))
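
For a quick check of the new test in isolation, something like the following works from the repository root (assuming the package layout above is importable); note that with n_tokens = 384 and two blocks the forward pass is fairly heavy for a CPU-only run.

import unittest
from tests.test_msa_module import TestMSAModule

suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestMSAModule)
unittest.TextTestRunner(verbosity=2).run(suite)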
