original alphafold3 learning rates etc., add mini-alphafold3 configs
ardagoreci committed Aug 24, 2024
1 parent 9d492a3 commit be9acdf
Showing 5 changed files with 186 additions and 10 deletions.
6 changes: 3 additions & 3 deletions configs/model/alphafold3.yaml
@@ -2,9 +2,9 @@
# _target_: src.models.model_wrapper.AlphaFoldWrapper
# config:
optimizer:
_target_: torch.optim.Adam # deepspeed.ops.adam.FusedAdam
_target_: deepspeed.ops.adam.FusedAdam # torch.optim.Adam #
_partial_: true
lr: 0.00018
lr: 0.0018
betas:
- 0.9
- 0.95
@@ -17,7 +17,7 @@ scheduler:
last_epoch: -1
verbose: false
base_lr: 0.0 # the starting learning rate
max_lr: 0.00018
max_lr: 0.0018
warmup_no_steps: 1000
start_decay_after_n_steps: 50_000
decay_every_n_steps: 50_000
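Both hunks above raise the peak learning rate from 1.8e-4 to the 1.8e-3 used in the AlphaFold3 paper and switch the optimizer to DeepSpeed's FusedAdam (plain torch.optim.Adam is kept as a commented-out fallback). For orientation, here is a minimal sketch of the warmup-then-staircase-decay schedule these fields describe; the function name and the exact step bookkeeping are assumptions, not the internals of src.utils.lr_schedulers.AlphaFoldLRScheduler:

```python
def lr_at_step(step: int,
               base_lr: float = 0.0,
               max_lr: float = 1.8e-3,
               warmup_no_steps: int = 1_000,
               start_decay_after_n_steps: int = 50_000,
               decay_every_n_steps: int = 50_000,
               decay_factor: float = 0.95) -> float:
    """Linear warmup to max_lr, hold, then scale by decay_factor every decay interval."""
    if step < warmup_no_steps:
        # linear warmup from base_lr (0.0 here) to max_lr
        return base_lr + (max_lr - base_lr) * step / warmup_no_steps
    if step < start_decay_after_n_steps:
        return max_lr
    # staircase decay; the exact off-by-one handling may differ from the repo's scheduler
    n_decays = 1 + (step - start_decay_after_n_steps) // decay_every_n_steps
    return max_lr * decay_factor ** n_decays
```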
176 changes: 176 additions & 0 deletions configs/model/mini-alphafold3.yaml
@@ -0,0 +1,176 @@
# AlphaFold3 configs
# _target_: src.models.model_wrapper.AlphaFoldWrapper
# config:
optimizer:
_target_: deepspeed.ops.adam.FusedAdam # torch.optim.Adam #
_partial_: true
lr: 0.0018
betas:
- 0.9
- 0.95
eps: 1e-08
weight_decay: 0.0

scheduler:
_target_: src.utils.lr_schedulers.AlphaFoldLRScheduler
_partial_: true
last_epoch: -1
verbose: false
base_lr: 0.0 # the starting learning rate
max_lr: 0.0018
warmup_no_steps: 1000
start_decay_after_n_steps: 50_000
decay_every_n_steps: 50_000
decay_factor: 0.95

# Loss configs
loss:
mse_loss:
sd_data: 16.0
weight: 4.0
smooth_lddt_loss:
weight: 4.0
epsilon: 1e-5

distogram:
min_bin: 0.0
max_bin: 32.0
no_bins: 64
eps: 0.000001 # 1e-6
weight: 0.03

experimentally_resolved:
eps: 0.00000001 # 1e-8,
# min_resolution: 0.1,
# max_resolution: 3.0,
weight: 0.0004

plddt_loss:
min_resolution: 0.1
max_resolution: 3.0
cutoff: 15.0
no_bins": 50
eps: 0.0000000001 # 1e-10,
weight: 0.0004


# TODO: fix the model.model notation in interpolation
model:
c_token: 128 # the token representation dim original=384
c_pair: 128 # the pair representation dim
c_atom: 128 # the atom representation dim
c_atompair: 16 # the atom pair representation dim

# Pair stack parameters (used in Pairformer, MSA module, and Confidence head)
c_hidden_tri_mul: 128 # the hidden dim for the triangle multiplicative update
c_hidden_pair_attn: 32 # the hidden dim for the pair attention ${common.c_hidden_pair_attn}
no_heads_tri_attn: 4
transition_n: 4
pair_dropout: 0.25
fuse_projection_weights: false
blocks_per_ckpt: 1 # number of blocks per checkpoint, if none, no checkpointing
clear_cache_between_blocks: false # whether to clear GPU memory cache between blocks
# Pairformer attention pair bias
no_heads_single_attn: 16

# Input Embedder
input_embedder:
c_token: ${model.model.c_token}
c_trunk_pair: ${model.model.c_pair}
c_atom: ${model.model.c_atom}
c_atompair: ${model.model.c_atompair}

# MSA module
msa_module:
no_blocks: 4
c_msa: 64
c_token: ${model.model.c_token}
c_z: ${model.model.c_pair}
c_hidden: 32
no_heads: 8
c_hidden_tri_mul: ${model.model.c_hidden_tri_mul}
c_hidden_pair_attn: ${model.model.c_hidden_pair_attn}
no_heads_tri_attn: ${model.model.no_heads_tri_attn}
transition_n: ${model.model.transition_n}
pair_dropout: ${model.model.pair_dropout}
fuse_projection_weights: ${model.model.fuse_projection_weights}
clear_cache_between_blocks: ${model.model.clear_cache_between_blocks}
blocks_per_ckpt: ${model.model.blocks_per_ckpt}
inf: 1e8

# Template Embedder
template_embedder:
no_blocks: 2
c_template: 64
c_z: ${model.model.c_pair}
clear_cache_between_blocks: ${model.model.clear_cache_between_blocks}

# PairformerStack
pairformer_stack:
c_s: ${model.model.c_token}
c_z: ${model.model.c_pair}
no_blocks: 12
c_hidden_mul: ${model.model.c_hidden_tri_mul}
c_hidden_pair_attn: ${model.model.c_hidden_pair_attn}
no_heads_tri_attn: ${model.model.no_heads_tri_attn}
no_heads_single_attn: ${model.model.no_heads_single_attn}
transition_n: ${model.model.transition_n}
pair_dropout: ${model.model.pair_dropout}
fuse_projection_weights: ${model.model.fuse_projection_weights}
blocks_per_ckpt: ${model.model.blocks_per_ckpt}
clear_cache_between_blocks: false
inf: 1e8

# Diffusion module
diffusion_module:
c_atom: ${model.model.c_atom}
c_atompair: ${model.model.c_atompair}
c_token: ${model.model.c_token}
c_tokenpair: ${model.model.c_pair}
atom_encoder_blocks: 3
atom_encoder_heads: 16
dropout: 0.0
atom_attention_n_queries: 32
atom_attention_n_keys: 128
atom_decoder_blocks: 3
atom_decoder_heads: 16
token_transformer_blocks: 6
token_transformer_heads: 16
sd_data: 16.0
s_max: 160.0
s_min: 0.0004
p: 7.0
clear_cache_between_blocks: ${model.model.clear_cache_between_blocks}
blocks_per_ckpt: ${model.model.blocks_per_ckpt}

confidence_head:
c_s: ${model.model.c_token}
c_z: ${model.model.c_pair}
no_blocks: 4
no_bins_pde: 64
no_bins_plddt: 64
no_bins_pae: 64
c_hidden_mul: ${model.model.c_hidden_tri_mul}
c_hidden_pair_attn: ${model.model.c_hidden_pair_attn}
no_heads_tri_attn: ${model.model.no_heads_tri_attn}
no_heads_single_attn: ${model.model.no_heads_single_attn}
transition_n: ${model.model.transition_n}
pair_dropout: ${model.model.pair_dropout}
fuse_projection_weights: ${model.model.fuse_projection_weights}

distogram_head:
c_z: ${model.model.c_pair}
no_bins: 64

# Exponential moving average decay rate
ema_decay: 0.999

globals:
chunk_size: null # 4
# Use DeepSpeed memory-efficient attention kernel in supported modules.
use_deepspeed_evo_attention: true
samples_per_trunk: 48 # Number of diffusion module replicas per trunk
rollout_samples_per_trunk: 1 # Number of mini rollouts per trunk
eps: 0.00000001
# internal precision of float32 matrix multiplications. "high" or "medium" will utilize Tensor cores
matmul_precision: "high"
4 changes: 2 additions & 2 deletions configs/model/proteus.yaml
@@ -3,7 +3,7 @@ _target_: src.models.proteus_module.ProteusLitModule
optimizer:
_target_: torch.optim.Adam # deepspeed.ops.adam.FusedAdam
_partial_: true
lr: 0.00018 # 0.0018
lr: 0.0018
betas:
- 0.9
- 0.95
@@ -16,7 +16,7 @@ scheduler:
last_epoch: -1
verbose: false
base_lr: 0.0 # the starting learning rate
max_lr: 0.00018
max_lr: 0.0018
warmup_no_steps: 1000
start_decay_after_n_steps: 50_000
decay_every_n_steps: 50_000
8 changes: 4 additions & 4 deletions configs/model/small-alphafold3.yaml
@@ -2,9 +2,9 @@
# _target_: src.models.model_wrapper.AlphaFoldWrapper
# config:
optimizer:
_target_: torch.optim.Adam # deepspeed.ops.adam.FusedAdam
_target_: deepspeed.ops.adam.FusedAdam # torch.optim.Adam #
_partial_: true
lr: 0.00018
lr: 0.0018
betas:
- 0.9
- 0.95
@@ -17,7 +17,7 @@ scheduler:
last_epoch: -1
verbose: false
base_lr: 0.0 # the starting learning rate
max_lr: 0.00018
max_lr: 0.0018
warmup_no_steps: 1000
start_decay_after_n_steps: 50_000
decay_every_n_steps: 50_000
@@ -130,7 +130,7 @@ model:
atom_encoder_blocks: 3
atom_encoder_heads: 16
dropout: 0.0
atom_attention_n_queries: 32 # TODO: with sliding window attention this is not used.
atom_attention_n_queries: 32
atom_attention_n_keys: 128
atom_decoder_blocks: 3
atom_decoder_heads: 16
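The hunk above drops a stale TODO next to atom_attention_n_queries. Together with atom_attention_n_keys, that field defines the sequence-local attention pattern of the atom transformer: each contiguous block of 32 query atoms attends to a window of 128 key atoms around it. A rough sketch of such a mask, under the assumption that key windows are centered on each query block and clipped at the sequence ends:

```python
import torch

def sequence_local_mask(n_atoms: int, n_queries: int = 32, n_keys: int = 128) -> torch.Tensor:
    """Boolean (n_atoms, n_atoms) mask: True where a query atom may attend to a key atom."""
    mask = torch.zeros(n_atoms, n_atoms, dtype=torch.bool)
    for block_start in range(0, n_atoms, n_queries):
        center = block_start + n_queries // 2
        key_start = max(0, center - n_keys // 2)
        key_end = min(n_atoms, center + n_keys // 2)
        mask[block_start:block_start + n_queries, key_start:key_end] = True
    return mask
```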
2 changes: 1 addition & 1 deletion src/utils/loss.py
@@ -330,7 +330,7 @@ def loss(self, out, batch, _return_breakdown=False):
atom_is_dna=batch["ref_mask"].new_zeros(batch["ref_mask"].shape), # (bs, n_atoms)
mask=batch["atom_exists"],
),
"diffusion_loss": lambda: mse_loss(
"mse_loss": lambda: mse_loss(
pred_atoms=out["denoised_atoms"],
gt_atoms=out["augmented_gt_atoms"], # rotated gt atoms from diffusion module
timesteps=out["timesteps"],
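The rename in src/utils/loss.py brings the key of this entry in the dict of loss closures in line with the mse_loss block of the loss config, presumably so the per-term weight is looked up under the same name. A hypothetical illustration of that registry pattern, not the exact code in the file:

```python
# Stand-in closures; in the repo each entry wraps the real loss call (mse_loss, smooth_lddt_loss, ...).
loss_fns = {
    "mse_loss": lambda: 0.123,
    "smooth_lddt_loss": lambda: 0.456,
}
loss_config = {"mse_loss": {"weight": 4.0}, "smooth_lddt_loss": {"weight": 4.0}}

# Each term is evaluated lazily and scaled by the weight configured under the same key.
total = sum(loss_config[name]["weight"] * fn() for name, fn in loss_fns.items())
```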
