From be9acdf9e787a88b31d5a290755ac26ad4272e48 Mon Sep 17 00:00:00 2001
From: ardagoreci <62720042+ardagoreci@users.noreply.github.com>
Date: Fri, 23 Aug 2024 23:41:18 -0700
Subject: [PATCH] original alphafold3 learning rates etc., add mini-alphafold3 configs

---
 configs/model/alphafold3.yaml       |   6 +-
 configs/model/mini-alphafold3.yaml  | 176 ++++++++++++++++++++++++++++
 configs/model/proteus.yaml          |   4 +-
 configs/model/small-alphafold3.yaml |   8 +-
 src/utils/loss.py                   |   2 +-
 5 files changed, 186 insertions(+), 10 deletions(-)
 create mode 100644 configs/model/mini-alphafold3.yaml

diff --git a/configs/model/alphafold3.yaml b/configs/model/alphafold3.yaml
index ca5d00f..553fb6e 100644
--- a/configs/model/alphafold3.yaml
+++ b/configs/model/alphafold3.yaml
@@ -2,9 +2,9 @@
 # _target_: src.models.model_wrapper.AlphaFoldWrapper
 # config:
 optimizer:
-  _target_: torch.optim.Adam # deepspeed.ops.adam.FusedAdam
+  _target_: deepspeed.ops.adam.FusedAdam # torch.optim.Adam #
   _partial_: true
-  lr: 0.00018
+  lr: 0.0018
   betas:
     - 0.9
     - 0.95
@@ -17,7 +17,7 @@ scheduler:
   last_epoch: -1
   verbose: false
   base_lr: 0.0 # the starting learning rate
-  max_lr: 0.00018
+  max_lr: 0.0018
   warmup_no_steps: 1000
   start_decay_after_n_steps: 50_000
   decay_every_n_steps: 50_000
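The hunks above raise `lr` and `max_lr` tenfold, to the 1.8e-3 used in the AlphaFold3 paper, while leaving the warmup and decay fields untouched. As a rough illustration of how those scheduler fields combine (linear warmup, plateau, stepwise decay), here is a sketch; `lr_at_step` and its exact boundary handling are assumptions for illustration, not the repository's `AlphaFoldLRScheduler` implementation.

```python
# Hypothetical helper showing the assumed meaning of the scheduler fields above;
# not the repository's AlphaFoldLRScheduler code.
def lr_at_step(step: int,
               base_lr: float = 0.0,
               max_lr: float = 0.0018,
               warmup_no_steps: int = 1000,
               start_decay_after_n_steps: int = 50_000,
               decay_every_n_steps: int = 50_000,
               decay_factor: float = 0.95) -> float:
    """Piecewise schedule: linear warmup, hold at max_lr, then stepwise exponential decay."""
    if step < warmup_no_steps:
        # Linear warmup from base_lr to max_lr.
        return base_lr + (max_lr - base_lr) * (step / warmup_no_steps)
    if step <= start_decay_after_n_steps:
        # Hold at the peak learning rate.
        return max_lr
    # Multiply by decay_factor once per decay interval after the decay start.
    n_decays = (step - start_decay_after_n_steps) // decay_every_n_steps
    return max_lr * (decay_factor ** n_decays)

print(lr_at_step(500))      # mid-warmup, 0.0009
print(lr_at_step(10_000))   # plateau at 0.0018
print(lr_at_step(150_000))  # two decay intervals past 50k: 0.0018 * 0.95**2
```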
diff --git a/configs/model/mini-alphafold3.yaml b/configs/model/mini-alphafold3.yaml
new file mode 100644
index 0000000..09bc110
--- /dev/null
+++ b/configs/model/mini-alphafold3.yaml
@@ -0,0 +1,176 @@
+# AlphaFold3 configs
+# _target_: src.models.model_wrapper.AlphaFoldWrapper
+# config:
+optimizer:
+  _target_: deepspeed.ops.adam.FusedAdam # torch.optim.Adam #
+  _partial_: true
+  lr: 0.0018
+  betas:
+    - 0.9
+    - 0.95
+  eps: 1e-08
+  weight_decay: 0.0
+
+scheduler:
+  _target_: src.utils.lr_schedulers.AlphaFoldLRScheduler
+  _partial_: true
+  last_epoch: -1
+  verbose: false
+  base_lr: 0.0 # the starting learning rate
+  max_lr: 0.0018
+  warmup_no_steps: 1000
+  start_decay_after_n_steps: 50_000
+  decay_every_n_steps: 50_000
+  decay_factor: 0.95
+
+# Loss configs
+loss:
+  mse_loss:
+    sd_data: 16.0
+    weight: 4.0
+  smooth_lddt_loss:
+    weight: 4.0
+    epsilon: 1e-5
+
+  distogram:
+    min_bin: 0.0
+    max_bin: 32.0
+    no_bins: 64
+    eps: 0.000006 # 1e-6
+    weight: 0.03
+
+  experimentally_resolved:
+    eps: 0.00000001 # 1e-8,
+    # min_resolution: 0.1,
+    # max_resolution: 3.0,
+    weight: 0.0004
+
+  plddt_loss:
+    min_resolution: 0.1
+    max_resolution: 3.0
+    cutoff: 15.0
+    no_bins: 50
+    eps: 0.0000000001 # 1e-10,
+    weight: 0.0004
+
+
+# TODO: fix the model.model notation in interpolation
+model:
+  c_token: 128 # the token representation dim, original=384
+  c_pair: 128 # the pair representation dim
+  c_atom: 128 # the atom representation dim
+  c_atompair: 16 # the atom pair representation dim
+
+  # Pair stack parameters (used in Pairformer, MSA module, and Confidence head)
+  c_hidden_tri_mul: 128 # the hidden dim for the triangle multiplicative update
+  c_hidden_pair_attn: 32 # the hidden dim for the pair attention ${common.c_hidden_pair_attn}
+  no_heads_tri_attn: 4
+  transition_n: 4
+  pair_dropout: 0.25
+  fuse_projection_weights: false
+  blocks_per_ckpt: 1 # number of blocks per checkpoint, if none, no checkpointing
+  clear_cache_between_blocks: false # whether to clear GPU memory cache between blocks
+  # Pairformer attention pair bias
+  no_heads_single_attn: 16
+
+  # Input Embedder
+  input_embedder:
+    c_token: ${model.model.c_token}
+    c_trunk_pair: ${model.model.c_pair}
+    c_atom: ${model.model.c_atom}
+    c_atompair: ${model.model.c_atompair}
+
+  # MSA module
+  msa_module:
+    no_blocks: 4
+    c_msa: 64
+    c_token: ${model.model.c_token}
+    c_z: ${model.model.c_pair}
+    c_hidden: 32
+    no_heads: 8
+    c_hidden_tri_mul: ${model.model.c_hidden_tri_mul}
+    c_hidden_pair_attn: ${model.model.c_hidden_pair_attn}
+    no_heads_tri_attn: ${model.model.no_heads_tri_attn}
+    transition_n: ${model.model.transition_n}
+    pair_dropout: ${model.model.pair_dropout}
+    fuse_projection_weights: ${model.model.fuse_projection_weights}
+    clear_cache_between_blocks: ${model.model.clear_cache_between_blocks}
+    blocks_per_ckpt: ${model.model.blocks_per_ckpt}
+    inf: 1e8
+
+  # Template Embedder
+  template_embedder:
+    no_blocks: 2
+    c_template: 64
+    c_z: ${model.model.c_pair}
+    clear_cache_between_blocks: ${model.model.clear_cache_between_blocks}
+
+  # PairformerStack
+  pairformer_stack:
+    c_s: ${model.model.c_token}
+    c_z: ${model.model.c_pair}
+    no_blocks: 12
+    c_hidden_mul: ${model.model.c_hidden_tri_mul}
+    c_hidden_pair_attn: ${model.model.c_hidden_pair_attn}
+    no_heads_tri_attn: ${model.model.no_heads_tri_attn}
+    no_heads_single_attn: ${model.model.no_heads_single_attn}
+    transition_n: ${model.model.transition_n}
+    pair_dropout: ${model.model.pair_dropout}
+    fuse_projection_weights: ${model.model.fuse_projection_weights}
+    blocks_per_ckpt: ${model.model.blocks_per_ckpt}
+    clear_cache_between_blocks: false
+    inf: 1e8
+
+  # Diffusion module
+  diffusion_module:
+    c_atom: ${model.model.c_atom}
+    c_atompair: ${model.model.c_atompair}
+    c_token: ${model.model.c_token}
+    c_tokenpair: ${model.model.c_pair}
+    atom_encoder_blocks: 3
+    atom_encoder_heads: 16
+    dropout: 0.0
+    atom_attention_n_queries: 32
+    atom_attention_n_keys: 128
+    atom_decoder_blocks: 3
+    atom_decoder_heads: 16
+    token_transformer_blocks: 6
+    token_transformer_heads: 16
+    sd_data: 16.0
+    s_max: 160.0
+    s_min: 0.0004
+    p: 7.0
+    clear_cache_between_blocks: ${model.model.clear_cache_between_blocks}
+    blocks_per_ckpt: ${model.model.blocks_per_ckpt}
+
+  confidence_head:
+    c_s: ${model.model.c_token}
+    c_z: ${model.model.c_pair}
+    no_blocks: 4
+    no_bins_pde: 64
+    no_bins_plddt: 64
+    no_bins_pae: 64
+    c_hidden_mul: ${model.model.c_hidden_tri_mul}
+    c_hidden_pair_attn: ${model.model.c_hidden_pair_attn}
+    no_heads_tri_attn: ${model.model.no_heads_tri_attn}
+    no_heads_single_attn: ${model.model.no_heads_single_attn}
+    transition_n: ${model.model.transition_n}
+    pair_dropout: ${model.model.pair_dropout}
+    fuse_projection_weights: ${model.model.fuse_projection_weights}
+
+  distogram_head:
+    c_z: ${model.model.c_pair}
+    no_bins: 64
+
+# Exponential moving average decay rate
+ema_decay: 0.999
+
+globals:
+  chunk_size: null # 4
+  # Use DeepSpeed memory-efficient attention kernel in supported modules.
+  use_deepspeed_evo_attention: true
+  samples_per_trunk: 48 # Number of diffusion module replicas per trunk
+  rollout_samples_per_trunk: 1 # Number of mini rollouts per trunk
+  eps: 0.00000001
+  # internal precision of float32 matrix multiplications. "high" or "medium" will utilize Tensor cores
+  matmul_precision: "high"
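In the new `diffusion_module` section, `sd_data`, `s_max`, `s_min`, and `p` correspond to the EDM/Karras-style noise-level parameters described in the AlphaFold3 supplement. The sketch below shows one way such a sampling noise schedule is commonly built from those values; `noise_schedule` and `num_steps` are illustrative names, and the repository's diffusion module may construct the schedule differently.

```python
import numpy as np

# Assumed Karras/EDM-style noise schedule derived from the configured
# sd_data, s_max, s_min and p values; illustrative, not the repository's code.
def noise_schedule(num_steps: int = 200,
                   sd_data: float = 16.0,
                   s_max: float = 160.0,
                   s_min: float = 4e-4,
                   p: float = 7.0) -> np.ndarray:
    tau = np.linspace(0.0, 1.0, num_steps)
    # Interpolate between s_max and s_min in 1/p-power space, then rescale by sd_data.
    return sd_data * (s_max ** (1 / p) + tau * (s_min ** (1 / p) - s_max ** (1 / p))) ** p

sigmas = noise_schedule()
print(sigmas[0], sigmas[-1])  # from sd_data * s_max (2560.0) down to sd_data * s_min (0.0064)
```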
"high" or "medium" will utilize Tensor cores + matmul_precision: "high" diff --git a/configs/model/proteus.yaml b/configs/model/proteus.yaml index ca2b834..53380d4 100644 --- a/configs/model/proteus.yaml +++ b/configs/model/proteus.yaml @@ -3,7 +3,7 @@ _target_: src.models.proteus_module.ProteusLitModule optimizer: _target_: torch.optim.Adam # deepspeed.ops.adam.FusedAdam _partial_: true - lr: 0.00018 # 0.0018 + lr: 0.0018 betas: - 0.9 - 0.95 @@ -16,7 +16,7 @@ scheduler: last_epoch: -1 verbose: false base_lr: 0.0 # the starting learning rate - max_lr: 0.00018 + max_lr: 0.0018 warmup_no_steps: 1000 start_decay_after_n_steps: 50_000 decay_every_n_steps: 50_000 diff --git a/configs/model/small-alphafold3.yaml b/configs/model/small-alphafold3.yaml index 2ca0db8..6b5a244 100644 --- a/configs/model/small-alphafold3.yaml +++ b/configs/model/small-alphafold3.yaml @@ -2,9 +2,9 @@ # _target_: src.models.model_wrapper.AlphaFoldWrapper # config: optimizer: - _target_: torch.optim.Adam # deepspeed.ops.adam.FusedAdam + _target_: deepspeed.ops.adam.FusedAdam # torch.optim.Adam # _partial_: true - lr: 0.00018 + lr: 0.0018 betas: - 0.9 - 0.95 @@ -17,7 +17,7 @@ scheduler: last_epoch: -1 verbose: false base_lr: 0.0 # the starting learning rate - max_lr: 0.00018 + max_lr: 0.0018 warmup_no_steps: 1000 start_decay_after_n_steps: 50_000 decay_every_n_steps: 50_000 @@ -130,7 +130,7 @@ model: atom_encoder_blocks: 3 atom_encoder_heads: 16 dropout: 0.0 - atom_attention_n_queries: 32 # TODO: with sliding window attention this is not used. + atom_attention_n_queries: 32 atom_attention_n_keys: 128 atom_decoder_blocks: 3 atom_decoder_heads: 16 diff --git a/src/utils/loss.py b/src/utils/loss.py index 12fc215..51fa345 100644 --- a/src/utils/loss.py +++ b/src/utils/loss.py @@ -330,7 +330,7 @@ def loss(self, out, batch, _return_breakdown=False): atom_is_dna=batch["ref_mask"].new_zeros(batch["ref_mask"].shape), # (bs, n_atoms) mask=batch["atom_exists"], ), - "diffusion_loss": lambda: mse_loss( + "mse_loss": lambda: mse_loss( pred_atoms=out["denoised_atoms"], gt_atoms=out["augmented_gt_atoms"], # rotated gt atoms from diffusion module timesteps=out["timesteps"],