From fa95f99bde3b9ae2ea4f931744083d7b650eeb95 Mon Sep 17 00:00:00 2001
From: ardagoreci <62720042+ardagoreci@users.noreply.github.com>
Date: Fri, 23 Aug 2024 23:53:28 -0700
Subject: [PATCH] add compilation

---
 configs/model/mini-af3.yaml    | 176 ---------------------------------
 src/models/diffusion_module.py |   1 +
 src/models/model.py            |   1 +
 3 files changed, 2 insertions(+), 176 deletions(-)
 delete mode 100644 configs/model/mini-af3.yaml

diff --git a/configs/model/mini-af3.yaml b/configs/model/mini-af3.yaml
deleted file mode 100644
index becff93..0000000
--- a/configs/model/mini-af3.yaml
+++ /dev/null
@@ -1,176 +0,0 @@
-# AlphaFold3 configs
-# _target_: src.models.model_wrapper.AlphaFoldWrapper
-# config:
-optimizer:
-  _target_: torch.optim.Adam # deepspeed.ops.adam.FusedAdam
-  _partial_: true
-  lr: 0.0018 # 0.0018
-  betas:
-    - 0.9
-    - 0.95
-  eps: 1e-08
-  weight_decay: 0.0
-
-scheduler:
-  _target_: src.utils.lr_schedulers.AlphaFoldLRScheduler
-  _partial_: true
-  last_epoch: -1
-  verbose: false
-  base_lr: 0.0 # the starting learning rate
-  max_lr: 0.0018
-  warmup_no_steps: 1000
-  start_decay_after_n_steps: 50_000
-  decay_every_n_steps: 50_000
-  decay_factor: 0.95
-
-# Loss configs
-loss:
-  mse_loss:
-    sd_data: 16.0
-    weight: 4.0
-  smooth_lddt_loss:
-    weight: 4.0
-    epsilon: 1e-5
-
-  distogram:
-    min_bin: 2.3125
-    max_bin: 21.6875
-    no_bins: 64
-    eps: 0.000006 # 1e-6
-    weight: 0.03
-
-  experimentally_resolved:
-    eps: 0.00000001 # 1e-8,
-    # min_resolution: 0.1,
-    # max_resolution: 3.0,
-    weight: 0.0004
-
-  plddt_loss:
-    min_resolution: 0.1
-    max_resolution: 3.0
-    cutoff: 15.0
-    no_bins": 50
-    eps: 0.0000000001 # 1e-10,
-    weight: 0.0004
-
-
-# TODO: fix the model.model notation in interpolation
-model:
-  c_token: 64 # the token representation dim
-  c_pair: 16 # the pair representation dim
-  c_atom: 16 # the atom representation dim
-  c_atompair: 16 # the atom pair representation dim
-
-  # Pair stack parameters (used in Pairformer, MSA module, and Confidence head)
-  c_hidden_tri_mul: 16 # the hidden dim for the triangle multiplicative update
-  c_hidden_pair_attn: 16 # the hidden dim for the pair attention ${common.c_hidden_pair_attn}
-  no_heads_tri_attn: 1
-  transition_n: 1
-  pair_dropout: 0.25
-  fuse_projection_weights: false
-  blocks_per_ckpt: null # number of blocks per checkpoint, if none, no checkpointing
-  clear_cache_between_blocks: false # whether to clear GPU memory cache between blocks
-  # Pairformer attention pair bias
-  no_heads_single_attn: 1
-
-  # Input Embedder
-  input_embedder:
-    c_token: ${model.model.c_token}
-    c_trunk_pair: ${model.model.c_pair}
-    c_atom: ${model.model.c_atom}
-    c_atompair: ${model.model.c_atompair}
-
-  # MSA module
-  msa_module:
-    no_blocks: 1 # 4
-    c_msa: 16
-    c_token: ${model.model.c_token}
-    c_z: ${model.model.c_pair}
-    c_hidden: 8
-    no_heads: 1
-    c_hidden_tri_mul: ${model.model.c_hidden_tri_mul}
-    c_hidden_pair_attn: ${model.model.c_hidden_pair_attn}
-    no_heads_tri_attn: ${model.model.no_heads_tri_attn}
-    transition_n: ${model.model.transition_n}
-    pair_dropout: ${model.model.pair_dropout}
-    fuse_projection_weights: ${model.model.fuse_projection_weights}
-    clear_cache_between_blocks: ${model.model.clear_cache_between_blocks}
-    blocks_per_ckpt: ${model.model.blocks_per_ckpt}
-    inf: 1e8
-
-  # Template Embedder
-  template_embedder:
-    no_blocks: 1 # 2
-    c_template: 64
-    c_z: ${model.model.c_pair}
-    clear_cache_between_blocks: ${model.model.clear_cache_between_blocks}
-
-  # PairformerStack
-  pairformer_stack:
-    c_s: ${model.model.c_token}
-    c_z: ${model.model.c_pair}
-    no_blocks: 1 # 48
-    c_hidden_mul: ${model.model.c_hidden_tri_mul}
-    c_hidden_pair_attn: ${model.model.c_hidden_pair_attn}
-    no_heads_tri_attn: ${model.model.no_heads_tri_attn}
-    no_heads_single_attn: ${model.model.no_heads_single_attn}
-    transition_n: ${model.model.transition_n}
-    pair_dropout: ${model.model.pair_dropout}
-    fuse_projection_weights: ${model.model.fuse_projection_weights}
-    blocks_per_ckpt: ${model.model.blocks_per_ckpt}
-    clear_cache_between_blocks: false
-    inf: 1e8
-
-  # Diffusion module
-  diffusion_module:
-    c_atom: ${model.model.c_atom}
-    c_atompair: ${model.model.c_atompair}
-    c_token: ${model.model.c_token}
-    c_tokenpair: ${model.model.c_pair}
-    atom_encoder_blocks: 1
-    atom_encoder_heads: 1
-    dropout: 0.0
-    atom_attention_n_queries: 32 # TODO: with sliding window attention this is not used.
-    atom_attention_n_keys: 128
-    atom_decoder_blocks: 1
-    atom_decoder_heads: 1
-    token_transformer_blocks: 1 # 24
-    token_transformer_heads: 1
-    sd_data: 16.0
-    s_max: 160.0
-    s_min: 4e-4
-    p: 7.0
-    clear_cache_between_blocks: ${model.model.clear_cache_between_blocks}
-    blocks_per_ckpt: ${model.model.blocks_per_ckpt}
-
-  confidence_head:
-    c_s: 384 # ${model.c_token}
-    c_z: ${model.model.c_pair}
-    no_blocks: 4
-    no_bins_pde: 64
-    no_bins_plddt: 64
-    no_bins_pae: 64
-    c_hidden_mul: ${model.model.c_hidden_tri_mul}
-    c_hidden_pair_attn: ${model.model.c_hidden_pair_attn}
-    no_heads_tri_attn: ${model.model.no_heads_tri_attn}
-    no_heads_single_attn: ${model.model.no_heads_single_attn}
-    transition_n: ${model.model.transition_n}
-    pair_dropout: ${model.model.pair_dropout}
-    fuse_projection_weights: ${model.model.fuse_projection_weights}
-
-  distogram_head:
-    c_z: ${model.model.c_pair}
-    no_bins: 64
-
-# Exponential moving average decay rate
-ema_decay: 0.999
-
-globals:
-  chunk_size: null # 4
-  # Use DeepSpeed memory-efficient attention kernel in supported modules.
-  use_deepspeed_evo_attention: false
-  samples_per_trunk: 1 # Number of diffusion module replicas per trunk
-  rollout_samples_per_trunk: 1 # Number of mini rollouts per trunk
-  eps: 0.00000001
-  # internal precision of float32 matrix multiplications. "high" or "medium" will utilize Tensor cores
-  matmul_precision: "high"

diff --git a/src/models/diffusion_module.py b/src/models/diffusion_module.py
index 0484f4b..6acc182 100644
--- a/src/models/diffusion_module.py
+++ b/src/models/diffusion_module.py
@@ -158,6 +158,7 @@ def rescale_with_updates(
         r_update_scale = torch.sqrt(noisy_pos_scale) * timesteps.unsqueeze(-1)
         return noisy_atoms * noisy_pos_scale + r_updates * r_update_scale
 
+    @torch.compile
     def forward(
             self,
             noisy_atoms: Tensor,  # (bs, S, n_atoms, 3)

diff --git a/src/models/model.py b/src/models/model.py
index d99623d..d8c4485 100644
--- a/src/models/model.py
+++ b/src/models/model.py
@@ -59,6 +59,7 @@ def __init__(self, config):
             LinearNoBias(self.c_z, self.c_z)
         )
 
+    @torch.compile
     def run_trunk(
             self,
             feats: Dict[str, Tensor],
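
Note (not part of the patch): the change above decorates DiffusionModule.forward and Model.run_trunk with torch.compile so their first invocation is traced and compiled. A minimal standalone sketch of that technique follows, assuming PyTorch >= 2.0; the TinyTrunk module, its dimensions, and the input shapes are hypothetical and only illustrate the decorator usage, they are not taken from this repository.

    import torch
    from torch import nn, Tensor


    class TinyTrunk(nn.Module):
        """Toy stand-in for a trunk module, used only to show the decorator."""

        def __init__(self, dim: int = 64):
            super().__init__()
            self.proj = nn.Linear(dim, dim)

        @torch.compile  # first call triggers graph capture/compilation; later calls reuse the compiled graph
        def run_trunk(self, feats: Tensor) -> Tensor:
            return torch.relu(self.proj(feats))


    if __name__ == "__main__":
        model = TinyTrunk()
        out = model.run_trunk(torch.randn(2, 64))  # compiled on this first invocation
        print(out.shape)  # torch.Size([2, 64])

One trade-off of decorating methods this way is that the compiled graph is specialized per input shape, so workloads with varying sequence or atom counts may pay repeated recompilation cost unless dynamic shapes are enabled.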