diff --git a/configs/data/erebor.yaml b/configs/data/erebor.yaml
index a81b1cd..4d84353 100644
--- a/configs/data/erebor.yaml
+++ b/configs/data/erebor.yaml
@@ -220,5 +220,10 @@ config:
      - "between_segment_residues"
      - "deletion_matrix"
      - "no_recycling_iters"
-    use_templates: false
+    use_templates: true
     use_template_torsion_angles: false
+    template_features:
+      - "template_all_atom_positions"
+      - "template_sum_probs"
+      - "template_aatype"
+      - "template_all_atom_mask"
diff --git a/configs/experiment/template.yaml b/configs/experiment/template.yaml
new file mode 100644
index 0000000..074b263
--- /dev/null
+++ b/configs/experiment/template.yaml
@@ -0,0 +1,35 @@
+# @package _global_
+
+# to execute this experiment run:
+# python train.py experiment=template
+
+defaults:
+  - override /data: erebor
+  - override /model: alphafold3
+  - override /trainer: ddp
+
+# all parameters below will be merged with parameters from default configurations set above
+# this allows you to overwrite only specified parameters
+
+tags: ["single-chain", "template", "smooth_lddt_loss"]
+
+seed: 12345
+
+model:
+  optimizer:
+    lr: 0.002
+  net:
+    lin1_size: 128
+    lin2_size: 256
+    lin3_size: 64
+  compile: false
+
+data:
+  batch_size: 64
+
+logger:
+  wandb:
+    tags: ${tags}
+    group: "mnist"
+  aim:
+    experiment: "mnist"
diff --git a/configs/model/alphafold3.yaml b/configs/model/alphafold3.yaml
index 19db468..382f266 100644
--- a/configs/model/alphafold3.yaml
+++ b/configs/model/alphafold3.yaml
@@ -103,6 +103,13 @@ model:
     blocks_per_ckpt: ${model.model.blocks_per_ckpt}
     inf: 1e8
 
+  # Template Embedder
+  template_embedder:
+    no_blocks: 2
+    c_template: 64
+    c_z: ${model.model.c_pair}
+    clear_cache_between_blocks: ${model.model.clear_cache_between_blocks}
+
   # PairformerStack
   pairformer_stack:
     c_s: ${model.model.c_token}
@@ -172,4 +179,4 @@ globals:
   rollout_samples_per_trunk: 1 # Number of mini rollouts per trunk
   eps: 0.00000001
   # internal precision of float32 matrix multiplications. "high" or "medium" will utilize Tensor cores
-  matmul_precision: "medium"
+  matmul_precision: "high"
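
For reference, the new template_embedder block is passed straight through as constructor kwargs in src/models/model.py (see that diff below). A minimal sketch of the resolved call, assuming the base config sets c_pair to 128 and clear_cache_between_blocks to false (both are interpolated values not shown in this patch):

    # Sketch only: c_pair=128 and clear_cache_between_blocks=False are assumed
    # resolutions of the ${model.model.*} references above.
    from src.models.embedders import TemplateEmbedder

    template_embedder = TemplateEmbedder(
        no_blocks=2,        # number of TemplatePairStack blocks
        c_template=64,      # template channel dimension
        c_z=128,            # ${model.model.c_pair}
        clear_cache_between_blocks=False,
    )
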
"high" or "medium" will utilize Tensor cores + matmul_precision: "high" diff --git a/src/models/embedders.py b/src/models/embedders.py index 1f47c1c..fd190dd 100644 --- a/src/models/embedders.py +++ b/src/models/embedders.py @@ -11,7 +11,7 @@ from src.utils.tensor_utils import add from src.utils.checkpointing import get_checkpoint_fn from src.utils.geometry.vector import Vec3Array - +from src.common import residue_constants as rc checkpoint = get_checkpoint_fn() @@ -216,11 +216,31 @@ def forward( return s_inputs, s_init, z_init +# Template Embedder # + +def dgram_from_positions( + pos: torch.Tensor, + min_bin: float = 3.25, + max_bin: float = 50.75, + no_bins: int = 39, + inf: float = 1e8, +): + """Computes a distogram given a position tensor.""" + dgram = torch.sum( + (pos[..., None, :] - pos[..., None, :, :]) ** 2, dim=-1, keepdim=True + ) + lower = torch.linspace(min_bin, max_bin, no_bins, device=pos.device) ** 2 + upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1) + dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype) + + return dgram + + class TemplateEmbedder(nn.Module): def __init__( self, no_blocks: int = 2, - c_template: int = 32, + c_template: int = 64, c_z: int = 128, clear_cache_between_blocks: bool = False ): @@ -230,7 +250,7 @@ def __init__( LayerNorm(c_z), LinearNoBias(c_z, c_template) ) - no_template_features = 108 + no_template_features = 84 # 108 self.linear_templ_feat = LinearNoBias(no_template_features, c_template) self.pair_stack = TemplatePairStack( no_blocks=no_blocks, @@ -240,7 +260,7 @@ def __init__( self.v_to_u_ln = LayerNorm(c_template) self.output_proj = nn.Sequential( nn.ReLU(), - LinearNoBias(c_template, c_template) + LinearNoBias(c_template, c_z) ) self.clear_cache_between_blocks = clear_cache_between_blocks @@ -251,7 +271,6 @@ def forward( pair_mask: Tensor, chunk_size: Optional[int] = None, use_deepspeed_evo_attention: bool = False, - use_lma: bool = False, inplace_safe: bool = False, ) -> Tensor: """ @@ -260,21 +279,13 @@ def forward( Args: features: Dictionary containing the template features: - "template_restype": + "template_aatype": [*, N_templ, N_token, 32] One-hot encoding of the template sequence. + "template_pseudo_beta": + [*, N_templ, N_token, 3] coordinates of the representative atoms "template_pseudo_beta_mask": [*, N_templ, N_token] Mask indicating if the Cβ (Cα for glycine) has coordinates for the template at this residue. - "template_backbone_frame_mask": - [*, N_templ, N_token] Mask indicating if coordinates exist for all - atoms required to compute the backbone frame (used in the template_unit_vector feature). - "template_distogram": - [*, N_templ, N_token, N_token, 39] A one-hot pairwise feature indicating the distance - between Cβ atoms (Cα for glycine). Pairwise distances are discretized into 38 bins of - equal width between 3.25 Å and 50.75 Å; one more bin contains any larger distances. - "template_unit_vector": - [*, N_templ, N_token, N_token, 3] The unit vector of the displacement of the Cα atom of - all residues within the local frame of each residue. "asym_id": [*, N_token] Unique integer for each distinct chain. z_trunk: @@ -285,41 +296,58 @@ def forward( Chunk size for the pair stack. use_deepspeed_evo_attention: Whether to use DeepSpeed Evo attention within the pair stack. - use_lma: - Whether to use LMA within the pair stack. inplace_safe: Whether to use inplace operations. 
""" # Grab data about the inputs - *bs, n_templ, n_token, _ = features["template_restype"].shape - bs = tuple(bs) + bs, n_templ, n_token = features["template_aatype"].shape + + # Compute template distogram + template_distogram = dgram_from_positions(features["template_pseudo_beta"]) + + # Compute the unit vector + # pos = Vec3Array.from_array(features["template_pseudo_beta"]) + # template_unit_vector = (pos / pos.norm()).to_tensor().to(template_distogram.dtype) + # print(f"template_unit_vector shape: {template_unit_vector.shape}") + + # One-hot encode template restype + template_restype = F.one_hot( # [*, N_templ, N_token, 22] + features["template_aatype"], + num_classes=22 # 20 amino acids + UNK + gap + ).to(template_distogram.dtype) + + # TODO: add template backbone frame feature # Compute masks - b_frame_mask = features["template_backbone_frame_mask"] - b_frame_mask = b_frame_mask[..., None] * b_frame_mask[..., None, :] # [*, n_templ, n_token, n_token] + # b_frame_mask = features["template_backbone_frame_mask"] + # b_frame_mask = b_frame_mask[..., None] * b_frame_mask[..., None, :] # [*, n_templ, n_token, n_token] b_pseudo_beta_mask = features["template_pseudo_beta_mask"] b_pseudo_beta_mask = b_pseudo_beta_mask[..., None] * b_pseudo_beta_mask[..., None, :] template_feat = torch.cat([ - features["template_distogram"], - b_frame_mask[..., None], # [*, n_templ, n_token, n_token, 1] - features["template_unit_vector"], + template_distogram, + # b_frame_mask[..., None], # [*, n_templ, n_token, n_token, 1] + # template_unit_vector, b_pseudo_beta_mask[..., None] ], dim=-1) # Mask out features that are not in the same chain - asym_id_i = features["asym_id"][..., None, :].expand((bs + (n_templ, n_token, n_token))) - asym_id_j = features["asym_id"][..., None].expand((bs + (n_templ, n_token, n_token))) - same_asym_id = torch.isclose(asym_id_i, asym_id_j).to(template_feat.dtype) + asym_id_i = features["asym_id"][..., None, :].expand((bs, n_token, n_token)) + asym_id_j = features["asym_id"][..., None].expand((bs, n_token, n_token)) + same_asym_id = torch.isclose(asym_id_i, asym_id_j).to(template_feat.dtype) # [*, n_token, n_token] + same_asym_id = same_asym_id.unsqueeze(-3) # for N_templ broadcasting template_feat = template_feat * same_asym_id.unsqueeze(-1) # Add residue type information - temp_restype_i = features["template_restype"][..., None, :].expand(bs + (n_templ, n_token, n_token, -1)) - temp_restype_j = features["template_restype"][..., None, :, :].expand(bs + (n_templ, n_token, n_token, -1)) + temp_restype_i = template_restype[..., None, :].expand((bs, n_templ, n_token, n_token, -1)) + temp_restype_j = template_restype[..., None, :, :].expand((bs, n_templ, n_token, n_token, -1)) template_feat = torch.cat([template_feat, temp_restype_i, temp_restype_j], dim=-1) + # Mask the invalid features + template_feat = template_feat * b_pseudo_beta_mask[..., None] + # Run the pair stack per template - single_templates = torch.unbind(template_feat, dim=-4) # each element shape [*, n_token, n_token, c_template] + single_templates = torch.unbind(template_feat, dim=-4) # each element shape [*, n_token, n_token, no_feat] z_proj = self.proj_pair(z_trunk) u = torch.zeros_like(z_proj) for t in range(len(single_templates)): @@ -331,7 +359,7 @@ def forward( pair_mask=pair_mask, chunk_size=chunk_size, use_deepspeed_evo_attention=use_deepspeed_evo_attention, - use_lma=use_lma, inplace_safe=inplace_safe), + inplace_safe=inplace_safe), inplace=inplace_safe ) # Normalize and add to u diff --git a/src/models/model.py 
diff --git a/src/models/model.py b/src/models/model.py
index cb1e003..4363512 100644
--- a/src/models/model.py
+++ b/src/models/model.py
@@ -22,9 +22,9 @@ def __init__(self, config):
             **self.config["input_embedder"]
         )
 
-        # self.template_embedder = TemplateEmbedder(
-        #     **self.config["template_embedder"]
-        # )
+        self.template_embedder = TemplateEmbedder(
+            **self.config["template_embedder"]
+        )
 
         self.msa_module = MSAModule(
             **self.config["msa_module"]
@@ -106,18 +106,18 @@ def run_trunk(
 
         del s_prev, z_prev, s_init, z_init
 
-        # Embed the templates TODO: add the templates
-        # z = add(z,
-        #         self.template_embedder(
-        #             feats,
-        #             z,
-        #             pair_mask=pair_mask,
-        #             chunk_size=self.globals.chunk_size,
-        #             use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
-        #             inplace_safe=inplace_safe
-        #         ),
-        #         inplace=inplace_safe
-        #         )
+        # Embed the templates
+        z = add(z,
+                self.template_embedder(
+                    feats,
+                    z,
+                    pair_mask=pair_mask,
+                    chunk_size=self.globals.chunk_size,
+                    use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
+                    inplace_safe=inplace_safe
+                ),
+                inplace=inplace_safe
+                )
 
         # Process the MSA
         z = add(z,
@@ -392,4 +392,3 @@ def forward(self, batch, train: bool = True) -> Dict[str, Tensor]:
         # update the outputs dictionary with the confidence head outputs
         # outputs.update(confidences)
         return outputs
-
diff --git a/src/models/model_wrapper.py b/src/models/model_wrapper.py
index 4056f45..4623e62 100644
--- a/src/models/model_wrapper.py
+++ b/src/models/model_wrapper.py
@@ -240,12 +240,17 @@ def reshape_features(batch):
     batch["atom_mask"] = batch["all_atom_mask"].reshape(-1, n_res * 4, n_cycle)
     batch["token_mask"] = batch["seq_mask"]
     batch["token_index"] = batch["residue_index"]
+
+    # TODO: One-hot encode the restypes: aatype and template_aatype
+
     # Add assembly features
     batch["asym_id"] = torch.zeros_like(batch["seq_mask"])  # int
     batch["entity_id"] = torch.zeros_like(batch["seq_mask"])  # int
     batch["sym_id"] = torch.zeros_like(batch["seq_mask"])  # , dtype=torch.float32
+
     # Remove gt_features key, the item is usually none
     batch.pop("gt_features")
+
     # Compute and add atom_to_token
     atom_to_token = torch.arange(n_res).unsqueeze(-1).expand(n_res, 4).long()  # (n_res, 4)
     atom_to_token = atom_to_token[None, ..., None].expand(bs, n_res, 4, n_cycle)  # (bs, n_res, 4, n_cycle)
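
The atom_to_token tensor above maps each of the four atoms per residue to its residue index. A minimal illustration (that the four atoms per token are the backbone N, CA, C, O is an assumption about this repo's reduced atom representation):

    import torch

    n_res = 3
    # Same expression as in reshape_features above.
    atom_to_token = torch.arange(n_res).unsqueeze(-1).expand(n_res, 4).long()
    print(atom_to_token.flatten().tolist())  # [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
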
diff --git a/tests/test_embedders.py b/tests/test_embedders.py
index 8fc6867..9b20020 100644
--- a/tests/test_embedders.py
+++ b/tests/test_embedders.py
@@ -10,22 +10,19 @@ def setUp(self):
         self.n_tokens = 64
         self.c_template = 32
         self.c_z = 128
-
         self.module = TemplateEmbedder(c_template=self.c_template, c_z=self.c_z)
 
     def test_forward(self):
         features = {
-            "template_restype": torch.randn((self.batch_size, self.n_templates, self.n_tokens, 32)),
+            "template_aatype": torch.randint(0, 20, (self.batch_size, self.n_templates, self.n_tokens)),
             "template_pseudo_beta_mask": torch.randn((self.batch_size, self.n_templates, self.n_tokens)),
-            "template_backbone_frame_mask": torch.randn((self.batch_size, self.n_templates, self.n_tokens)),
-            "template_distogram": torch.randn((self.batch_size, self.n_templates, self.n_tokens, self.n_tokens, 39)),
-            "template_unit_vector": torch.randn((self.batch_size, self.n_templates, self.n_tokens, self.n_tokens, 3)),
+            "template_pseudo_beta": torch.randn((self.batch_size, self.n_templates, self.n_tokens, 3)),
             "asym_id": torch.ones((self.batch_size, self.n_tokens))
         }
         z = torch.randn((self.batch_size, self.n_tokens, self.n_tokens, self.c_z))
         pair_mask = torch.randint(0, 2, (self.batch_size, self.n_tokens, self.n_tokens))
         embeddings = self.module(features, z, pair_mask)
-        self.assertEqual(embeddings.shape, (self.batch_size, self.n_tokens, self.n_tokens, self.c_template))
+        self.assertEqual(embeddings.shape, (self.batch_size, self.n_tokens, self.n_tokens, self.c_z))
 
 
 class TestInputEmbedder(unittest.TestCase):
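
With the config and wiring above, the experiment should launch with python train.py experiment=template. A standalone smoke test of the embedder under the new feature set, mirroring the updated unit test (random placeholder tensors; assumes the repo root is on PYTHONPATH):

    import torch
    from src.models.embedders import TemplateEmbedder

    bs, n_templ, n_tok, c_z = 1, 2, 8, 128
    embedder = TemplateEmbedder(no_blocks=1, c_template=32, c_z=c_z)
    features = {
        "template_aatype": torch.randint(0, 20, (bs, n_templ, n_tok)),
        "template_pseudo_beta": torch.randn(bs, n_templ, n_tok, 3),
        "template_pseudo_beta_mask": torch.ones(bs, n_templ, n_tok),
        "asym_id": torch.ones(bs, n_tok),
    }
    z = torch.randn(bs, n_tok, n_tok, c_z)
    pair_mask = torch.ones(bs, n_tok, n_tok)
    out = embedder(features, z, pair_mask)
    assert out.shape == (bs, n_tok, n_tok, c_z)  # projected back to c_z by output_proj
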