implemented templates
ardagoreci committed Aug 13, 2024
Parent: 7bd3f73 · Commit: db35c6a
Showing 8 changed files with 143 additions and 58 deletions.
7 changes: 6 additions & 1 deletion configs/data/erebor.yaml
@@ -220,5 +220,10 @@ config:
- "between_segment_residues"
- "deletion_matrix"
- "no_recycling_iters"
use_templates: false
use_templates: true
use_template_torsion_angles: false
template_features:
- "template_all_atom_positions"
- "template_sum_probs"
- "template_aatype"
- "template_all_atom_mask"
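The template_features list above names the raw template inputs the data pipeline should carry through to the model. As an illustration only (the helper name and the raw feature dict are hypothetical, not from this commit), such a list can be used to subset a feature dictionary:

TEMPLATE_FEATURES = (
    "template_all_atom_positions",
    "template_sum_probs",
    "template_aatype",
    "template_all_atom_mask",
)

def select_template_features(raw_feats: dict) -> dict:
    # Keep only the template features named in configs/data/erebor.yaml.
    return {k: v for k, v in raw_feats.items() if k in TEMPLATE_FEATURES}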
35 changes: 35 additions & 0 deletions configs/experiment/template.yaml
@@ -0,0 +1,35 @@
# @package _global_

# to execute this experiment run:
# python train.py experiment=template

defaults:
- override /data: erebor
- override /model: alphafold3
- override /trainer: ddp

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["single-chain", "template", "smooth_lddt_loss"]

seed: 12345

model:
optimizer:
lr: 0.002
net:
lin1_size: 128
lin2_size: 256
lin3_size: 64
compile: false

data:
batch_size: 64

logger:
wandb:
tags: ${tags}
group: "mnist"
aim:
experiment: "mnist"
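As a usage sketch only (not from the repo): this experiment can also be composed programmatically with Hydra's compose API. The root config name "train" and the relative config path are assumptions based on the lightning-hydra layout these configs follow.

from hydra import compose, initialize

with initialize(version_base=None, config_path="configs"):
    cfg = compose(config_name="train", overrides=["experiment=template"])
    print(cfg.tags)  # ['single-chain', 'template', 'smooth_lddt_loss']
    print(cfg.seed)  # 12345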
9 changes: 8 additions & 1 deletion configs/model/alphafold3.yaml
@@ -103,6 +103,13 @@ model:
blocks_per_ckpt: ${model.model.blocks_per_ckpt}
inf: 1e8

# Template Embedder
template_embedder:
no_blocks: 2
c_template: 64
c_z: ${model.model.c_pair}
clear_cache_between_blocks: ${model.model.clear_cache_between_blocks}

# PairformerStack
pairformer_stack:
c_s: ${model.model.c_token}
@@ -172,4 +179,4 @@ globals:
rollout_samples_per_trunk: 1 # Number of mini rollouts per trunk
eps: 0.00000001
# internal precision of float32 matrix multiplications. "high" or "medium" will utilize Tensor cores
matmul_precision: "medium"
matmul_precision: "high"
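With the interpolations resolved (assuming model.model.c_pair is 128, consistent with the pair-representation width used elsewhere in this commit, and clear_cache_between_blocks is false), the template_embedder block corresponds roughly to the following construction; a sketch, not code from the repo:

from src.models.embedders import TemplateEmbedder

template_embedder = TemplateEmbedder(
    no_blocks=2,    # number of TemplatePairStack blocks
    c_template=64,  # template embedding channels
    c_z=128,        # pair representation channels
    clear_cache_between_blocks=False,
)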
13 changes: 11 additions & 2 deletions configs/model/mini-af3.yaml
@@ -103,6 +103,13 @@ model:
blocks_per_ckpt: ${model.model.blocks_per_ckpt}
inf: 1e8

# Template Embedder
template_embedder:
no_blocks: 1 # 2
c_template: 64
c_z: ${model.model.c_pair}
clear_cache_between_blocks: ${model.model.clear_cache_between_blocks}

# PairformerStack
pairformer_stack:
c_s: ${model.model.c_token}
@@ -167,7 +174,9 @@ ema_decay: 0.999
globals:
chunk_size: null # 4
# Use DeepSpeed memory-efficient attention kernel in supported modules.
use_deepspeed_evo_attention: true
samples_per_trunk: 48 # Number of diffusion module replicas per trunk
use_deepspeed_evo_attention: false
samples_per_trunk: 1 # Number of diffusion module replicas per trunk
rollout_samples_per_trunk: 1 # Number of mini rollouts per trunk
eps: 0.00000001
# internal precision of float32 matrix multiplications. "high" or "medium" will utilize Tensor cores
matmul_precision: "high"
92 changes: 60 additions & 32 deletions src/models/embedders.py
@@ -11,7 +11,7 @@
from src.utils.tensor_utils import add
from src.utils.checkpointing import get_checkpoint_fn
from src.utils.geometry.vector import Vec3Array

from src.common import residue_constants as rc
checkpoint = get_checkpoint_fn()


@@ -216,11 +216,31 @@ def forward(
return s_inputs, s_init, z_init


# Template Embedder #

def dgram_from_positions(
pos: torch.Tensor,
min_bin: float = 3.25,
max_bin: float = 50.75,
no_bins: int = 39,
inf: float = 1e8,
):
"""Computes a distogram given a position tensor."""
dgram = torch.sum(
(pos[..., None, :] - pos[..., None, :, :]) ** 2, dim=-1, keepdim=True
)
lower = torch.linspace(min_bin, max_bin, no_bins, device=pos.device) ** 2
upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)

return dgram
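As a quick illustrative check of dgram_from_positions (not part of the commit): each pair of positions falls into at most one of the 39 bins, i.e. 38 equal-width bins spanning 3.25–50.75 Å plus one overflow bin, while pairs closer than 3.25 Å (including the i == j diagonal) hit no bin at all.

import torch

pos = torch.randn(1, 2, 64, 3) * 20.0      # [batch, N_templ, N_token, 3]
dgram = dgram_from_positions(pos)           # [batch, N_templ, N_token, N_token, 39]
assert dgram.shape == (1, 2, 64, 64, 39)
assert torch.all(dgram.sum(dim=-1) <= 1.0)  # one-hot or all-zero per pair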


class TemplateEmbedder(nn.Module):
def __init__(
self,
no_blocks: int = 2,
c_template: int = 32,
c_template: int = 64,
c_z: int = 128,
clear_cache_between_blocks: bool = False
):
@@ -230,7 +250,7 @@ def __init__(
LayerNorm(c_z),
LinearNoBias(c_z, c_template)
)
no_template_features = 108
no_template_features = 84 # 108
self.linear_templ_feat = LinearNoBias(no_template_features, c_template)
self.pair_stack = TemplatePairStack(
no_blocks=no_blocks,
@@ -240,7 +260,7 @@ def __init__(
self.v_to_u_ln = LayerNorm(c_template)
self.output_proj = nn.Sequential(
nn.ReLU(),
LinearNoBias(c_template, c_template)
LinearNoBias(c_template, c_z)
)
self.clear_cache_between_blocks = clear_cache_between_blocks

@@ -251,7 +271,6 @@ def forward(
pair_mask: Tensor,
chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
) -> Tensor:
"""
@@ -260,21 +279,13 @@
Args:
features:
Dictionary containing the template features:
"template_restype":
"template_aatype":
[*, N_templ, N_token, 32] One-hot encoding of the template sequence.
"template_pseudo_beta":
[*, N_templ, N_token, 3] coordinates of the representative atoms
"template_pseudo_beta_mask":
[*, N_templ, N_token] Mask indicating if the Cβ (Cα for glycine)
has coordinates for the template at this residue.
"template_backbone_frame_mask":
[*, N_templ, N_token] Mask indicating if coordinates exist for all
atoms required to compute the backbone frame (used in the template_unit_vector feature).
"template_distogram":
[*, N_templ, N_token, N_token, 39] A one-hot pairwise feature indicating the distance
between Cβ atoms (Cα for glycine). Pairwise distances are discretized into 38 bins of
equal width between 3.25 Å and 50.75 Å; one more bin contains any larger distances.
"template_unit_vector":
[*, N_templ, N_token, N_token, 3] The unit vector of the displacement of the Cα atom of
all residues within the local frame of each residue.
"asym_id":
[*, N_token] Unique integer for each distinct chain.
z_trunk:
@@ -285,41 +296,58 @@ def forward(
Chunk size for the pair stack.
use_deepspeed_evo_attention:
Whether to use DeepSpeed Evo attention within the pair stack.
use_lma:
Whether to use LMA within the pair stack.
inplace_safe:
Whether to use inplace operations.
"""
# Grab data about the inputs
*bs, n_templ, n_token, _ = features["template_restype"].shape
bs = tuple(bs)
bs, n_templ, n_token = features["template_aatype"].shape

# Compute template distogram
template_distogram = dgram_from_positions(features["template_pseudo_beta"])

# Compute the unit vector
# pos = Vec3Array.from_array(features["template_pseudo_beta"])
# template_unit_vector = (pos / pos.norm()).to_tensor().to(template_distogram.dtype)
# print(f"template_unit_vector shape: {template_unit_vector.shape}")

# One-hot encode template restype
template_restype = F.one_hot( # [*, N_templ, N_token, 22]
features["template_aatype"],
num_classes=22 # 20 amino acids + UNK + gap
).to(template_distogram.dtype)

# TODO: add template backbone frame feature

# Compute masks
b_frame_mask = features["template_backbone_frame_mask"]
b_frame_mask = b_frame_mask[..., None] * b_frame_mask[..., None, :] # [*, n_templ, n_token, n_token]
# b_frame_mask = features["template_backbone_frame_mask"]
# b_frame_mask = b_frame_mask[..., None] * b_frame_mask[..., None, :] # [*, n_templ, n_token, n_token]
b_pseudo_beta_mask = features["template_pseudo_beta_mask"]
b_pseudo_beta_mask = b_pseudo_beta_mask[..., None] * b_pseudo_beta_mask[..., None, :]

template_feat = torch.cat([
features["template_distogram"],
b_frame_mask[..., None], # [*, n_templ, n_token, n_token, 1]
features["template_unit_vector"],
template_distogram,
# b_frame_mask[..., None], # [*, n_templ, n_token, n_token, 1]
# template_unit_vector,
b_pseudo_beta_mask[..., None]
], dim=-1)

# Mask out features that are not in the same chain
asym_id_i = features["asym_id"][..., None, :].expand((bs + (n_templ, n_token, n_token)))
asym_id_j = features["asym_id"][..., None].expand((bs + (n_templ, n_token, n_token)))
same_asym_id = torch.isclose(asym_id_i, asym_id_j).to(template_feat.dtype)
asym_id_i = features["asym_id"][..., None, :].expand((bs, n_token, n_token))
asym_id_j = features["asym_id"][..., None].expand((bs, n_token, n_token))
same_asym_id = torch.isclose(asym_id_i, asym_id_j).to(template_feat.dtype) # [*, n_token, n_token]
same_asym_id = same_asym_id.unsqueeze(-3) # for N_templ broadcasting
template_feat = template_feat * same_asym_id.unsqueeze(-1)

# Add residue type information
temp_restype_i = features["template_restype"][..., None, :].expand(bs + (n_templ, n_token, n_token, -1))
temp_restype_j = features["template_restype"][..., None, :, :].expand(bs + (n_templ, n_token, n_token, -1))
temp_restype_i = template_restype[..., None, :].expand((bs, n_templ, n_token, n_token, -1))
temp_restype_j = template_restype[..., None, :, :].expand((bs, n_templ, n_token, n_token, -1))
template_feat = torch.cat([template_feat, temp_restype_i, temp_restype_j], dim=-1)

# Mask the invalid features
template_feat = template_feat * b_pseudo_beta_mask[..., None]

# Run the pair stack per template
single_templates = torch.unbind(template_feat, dim=-4) # each element shape [*, n_token, n_token, c_template]
single_templates = torch.unbind(template_feat, dim=-4) # each element shape [*, n_token, n_token, no_feat]
z_proj = self.proj_pair(z_trunk)
u = torch.zeros_like(z_proj)
for t in range(len(single_templates)):
Expand All @@ -331,7 +359,7 @@ def forward(
pair_mask=pair_mask,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma, inplace_safe=inplace_safe),
inplace_safe=inplace_safe),
inplace=inplace_safe
)
# Normalize and add to u
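The drop from 108 to 84 template pair features follows from what the forward pass now concatenates; a short bookkeeping sketch (the earlier 108 presumably also counted the backbone-frame mask, the unit vector, and a wider 32-class restype encoding, all disabled in this commit):

# Per-pair template features assembled in TemplateEmbedder.forward:
no_distogram_bins = 39     # 38 equal-width bins in 3.25-50.75 A plus 1 overflow bin
pseudo_beta_pair_mask = 1  # b_pseudo_beta_mask[..., None]
restype_classes = 22       # 20 amino acids + UNK + gap, one-hot for both i and j
no_template_features = no_distogram_bins + pseudo_beta_pair_mask + 2 * restype_classes
assert no_template_features == 84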
31 changes: 15 additions & 16 deletions src/models/model.py
@@ -22,9 +22,9 @@ def __init__(self, config):
**self.config["input_embedder"]
)

# self.template_embedder = TemplateEmbedder(
# **self.config["template_embedder"]
# )
self.template_embedder = TemplateEmbedder(
**self.config["template_embedder"]
)

self.msa_module = MSAModule(
**self.config["msa_module"]
@@ -106,18 +106,18 @@ def run_trunk(

del s_prev, z_prev, s_init, z_init

# Embed the templates TODO: add the templates
# z = add(z,
# self.template_embedder(
# feats,
# z,
# pair_mask=pair_mask,
# chunk_size=self.globals.chunk_size,
# use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
# inplace_safe=inplace_safe
# ),
# inplace=inplace_safe
# )
# Embed the templates
z = add(z,
self.template_embedder(
feats,
z,
pair_mask=pair_mask,
chunk_size=self.globals.chunk_size,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
inplace_safe=inplace_safe
),
inplace=inplace_safe
)

# Process the MSA
z = add(z,
@@ -392,4 +392,3 @@ def forward(self, batch, train: bool = True) -> Dict[str, Tensor]:
# update the outputs dictionary with the confidence head outputs
# outputs.update(confidences)
return outputs
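Throughout run_trunk the template and MSA contributions are accumulated into z through src.utils.tensor_utils.add rather than a bare +. A minimal sketch of that residual-update pattern, assuming add behaves like the usual memory-saving helper (an assumption, not code copied from the repo):

import torch

def add(m1: torch.Tensor, m2: torch.Tensor, inplace: bool) -> torch.Tensor:
    # When inplace is True, reuse m1's storage to reduce peak activation memory.
    if inplace:
        m1 += m2
        return m1
    return m1 + m2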

5 changes: 5 additions & 0 deletions src/models/model_wrapper.py
@@ -240,12 +240,17 @@ def reshape_features(batch):
batch["atom_mask"] = batch["all_atom_mask"].reshape(-1, n_res * 4, n_cycle)
batch["token_mask"] = batch["seq_mask"]
batch["token_index"] = batch["residue_index"]

# TODO: One-hot encode the restypes: aatype and template_aatype

# Add assembly features
batch["asym_id"] = torch.zeros_like(batch["seq_mask"]) # int
batch["entity_id"] = torch.zeros_like(batch["seq_mask"]) # int
batch["sym_id"] = torch.zeros_like(batch["seq_mask"]) # , dtype=torch.float32

# Remove gt_features key, the item is usually none
batch.pop("gt_features")

# Compute and add atom_to_token
atom_to_token = torch.arange(n_res).unsqueeze(-1).expand(n_res, 4).long() # (n_res, 4)
atom_to_token = atom_to_token[None, ..., None].expand(bs, n_res, 4, n_cycle) # (bs, n_res, 4, n_cycle)
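On the TODO above: a hypothetical sketch of the restype one-hot encoding, mirroring the 22-class convention (20 amino acids + UNK + gap) that TemplateEmbedder.forward already applies to template_aatype. The helper name and its placement here are assumptions.

import torch
import torch.nn.functional as F

def one_hot_restype(aatype: torch.Tensor, num_classes: int = 22) -> torch.Tensor:
    # Integer residue-type indices -> float one-hot encoding.
    return F.one_hot(aatype.long(), num_classes=num_classes).float()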
9 changes: 3 additions & 6 deletions tests/test_embedders.py
@@ -10,22 +10,19 @@ def setUp(self):
self.n_tokens = 64
self.c_template = 32
self.c_z = 128

self.module = TemplateEmbedder(c_template=self.c_template, c_z=self.c_z)

def test_forward(self):
features = {
"template_restype": torch.randn((self.batch_size, self.n_templates, self.n_tokens, 32)),
"template_aatype": torch.randint(0, 20, (self.batch_size, self.n_templates, self.n_tokens)),
"template_pseudo_beta_mask": torch.randn((self.batch_size, self.n_templates, self.n_tokens)),
"template_backbone_frame_mask": torch.randn((self.batch_size, self.n_templates, self.n_tokens)),
"template_distogram": torch.randn((self.batch_size, self.n_templates, self.n_tokens, self.n_tokens, 39)),
"template_unit_vector": torch.randn((self.batch_size, self.n_templates, self.n_tokens, self.n_tokens, 3)),
"template_pseudo_beta": torch.randn((self.batch_size, self.n_templates, self.n_tokens, 3)),
"asym_id": torch.ones((self.batch_size, self.n_tokens))
}
z = torch.randn((self.batch_size, self.n_tokens, self.n_tokens, self.c_z))
pair_mask = torch.randint(0, 2, (self.batch_size, self.n_tokens, self.n_tokens))
embeddings = self.module(features, z, pair_mask)
self.assertEqual(embeddings.shape, (self.batch_size, self.n_tokens, self.n_tokens, self.c_template))
self.assertEqual(embeddings.shape, (self.batch_size, self.n_tokens, self.n_tokens, self.c_z))


class TestInputEmbedder(unittest.TestCase):
