Commit be9acdf: original alphafold3 learning rates etc., add mini-alphafold3 configs
ardagoreci committed on Aug 24, 2024
1 parent: 9d492a3
Showing 5 changed files with 186 additions and 10 deletions.
# AlphaFold3 configs
# _target_: src.models.model_wrapper.AlphaFoldWrapper
# config:
optimizer:
  _target_: deepspeed.ops.adam.FusedAdam  # torch.optim.Adam
  _partial_: true
  lr: 0.0018
  betas:
    - 0.9
    - 0.95
  eps: 1e-08
  weight_decay: 0.0
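# Per the commit message, these are the original AlphaFold3 training settings
# (Adam with base lr 1.8e-3 and betas (0.9, 0.95)). Note that FusedAdam needs
# a working DeepSpeed install; the commented torch.optim.Adam is a drop-in
# fallback with the same hyperparameters.
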
scheduler:
  _target_: src.utils.lr_schedulers.AlphaFoldLRScheduler
  _partial_: true
  last_epoch: -1
  verbose: false
  base_lr: 0.0  # the starting learning rate
  max_lr: 0.0018
  warmup_no_steps: 1000
  start_decay_after_n_steps: 50_000
  decay_every_n_steps: 50_000
  decay_factor: 0.95
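# A sketch of the resulting schedule, assuming AlphaFoldLRScheduler follows
# the usual AlphaFold2/OpenFold convention (linear warmup, plateau, then
# stepwise exponential decay):
#   step    500: lr = 0.0 + (500 / 1000) * 0.0018 = 0.0009   (mid-warmup)
#   steps 1000+: lr = 0.0018                                  (plateau)
#   once decay starts, lr is multiplied by 0.95 every 50_000 steps,
#   e.g. 0.0018 * 0.95 = 0.00171, then 0.00171 * 0.95 ≈ 0.00162, ...
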
# Loss configs
loss:
  mse_loss:
    sd_data: 16.0
    weight: 4.0
  smooth_lddt_loss:
    weight: 4.0
    epsilon: 1e-5
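  # Note: mse_loss.sd_data is the sigma_data scale of the diffusion loss
  # weighting; it should stay in sync with diffusion_module.sd_data below
  # (both are 16.0 here).
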
  distogram:
    min_bin: 0.0
    max_bin: 32.0
    no_bins: 64
    eps: 1e-6
    weight: 0.03
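  # 64 bins over the 0.0-32.0 Å range gives a 0.5 Å distogram bin width.
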
  experimentally_resolved:
    eps: 1e-8
    # min_resolution: 0.1
    # max_resolution: 3.0
    weight: 0.0004

  plddt_loss:
    min_resolution: 0.1
    max_resolution: 3.0
    cutoff: 15.0
    no_bins: 50
    eps: 1e-10
    weight: 0.0004
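  # Presumably the [0.1, 3.0] Å resolution window restricts this loss to
  # structures with reliable experimental coordinates, and cutoff is the 15 Å
  # inclusion radius conventionally used when computing LDDT.
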
# TODO: fix the model.model notation in interpolation
model:
  c_token: 128  # the token representation dim, original=384
  c_pair: 128  # the pair representation dim
  c_atom: 128  # the atom representation dim
  c_atompair: 16  # the atom pair representation dim
||
# Pair stack parameters (used in Pairformer, MSA module, and Confidence head) | ||
c_hidden_tri_mul: 128 # the hidden dim for the triangle multiplicative update | ||
c_hidden_pair_attn: 32 # the hidden dim for the pair attention ${common.c_hidden_pair_attn} | ||
no_heads_tri_attn: 4 | ||
transition_n: 4 | ||
pair_dropout: 0.25 | ||
fuse_projection_weights: false | ||
blocks_per_ckpt: 1 # number of blocks per checkpoint, if none, no checkpointing | ||
clear_cache_between_blocks: false # whether to clear GPU memory cache between blocks | ||
# Pairformer attention pair bias | ||
no_heads_single_attn: 16 | ||
|
||
# Input Embedder | ||
input_embedder: | ||
c_token: ${model.model.c_token} | ||
c_trunk_pair: ${model.model.c_pair} | ||
c_atom: ${model.model.c_atom} | ||
c_atompair: ${model.model.c_atompair} | ||
|
||
# MSA module | ||
msa_module: | ||
no_blocks: 4 | ||
c_msa: 64 | ||
c_token: ${model.model.c_token} | ||
c_z: ${model.model.c_pair} | ||
c_hidden: 32 | ||
no_heads: 8 | ||
c_hidden_tri_mul: ${model.model.c_hidden_tri_mul} | ||
c_hidden_pair_attn: ${model.model.c_hidden_pair_attn} | ||
no_heads_tri_attn: ${model.model.no_heads_tri_attn} | ||
transition_n: ${model.model.transition_n} | ||
pair_dropout: ${model.model.pair_dropout} | ||
fuse_projection_weights: ${model.model.fuse_projection_weights} | ||
clear_cache_between_blocks: ${model.model.clear_cache_between_blocks} | ||
blocks_per_ckpt: ${model.model.blocks_per_ckpt} | ||
inf: 1e8 | ||
|
||
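  # inf is presumably the large finite stand-in for infinity used to mask
  # attention logits (the usual OpenFold-style masking trick).
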
  # Template Embedder
  template_embedder:
    no_blocks: 2
    c_template: 64
    c_z: ${model.model.c_pair}
    clear_cache_between_blocks: ${model.model.clear_cache_between_blocks}

  # PairformerStack
  pairformer_stack:
    c_s: ${model.model.c_token}
    c_z: ${model.model.c_pair}
    no_blocks: 12
    c_hidden_mul: ${model.model.c_hidden_tri_mul}
    c_hidden_pair_attn: ${model.model.c_hidden_pair_attn}
    no_heads_tri_attn: ${model.model.no_heads_tri_attn}
    no_heads_single_attn: ${model.model.no_heads_single_attn}
    transition_n: ${model.model.transition_n}
    pair_dropout: ${model.model.pair_dropout}
    fuse_projection_weights: ${model.model.fuse_projection_weights}
    blocks_per_ckpt: ${model.model.blocks_per_ckpt}
    clear_cache_between_blocks: false
    inf: 1e8
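  # no_blocks: 12 is the mini setting; the full AlphaFold3 Pairformer stack
  # described in the paper has 48 blocks.
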
  # Diffusion module
  diffusion_module:
    c_atom: ${model.model.c_atom}
    c_atompair: ${model.model.c_atompair}
    c_token: ${model.model.c_token}
    c_tokenpair: ${model.model.c_pair}
    atom_encoder_blocks: 3
    atom_encoder_heads: 16
    dropout: 0.0
    atom_attention_n_queries: 32
    atom_attention_n_keys: 128
    atom_decoder_blocks: 3
    atom_decoder_heads: 16
    token_transformer_blocks: 6
    token_transformer_heads: 16
    sd_data: 16.0
    s_max: 160.0
    s_min: 0.0004
    p: 7.0
    clear_cache_between_blocks: ${model.model.clear_cache_between_blocks}
    blocks_per_ckpt: ${model.model.blocks_per_ckpt}
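  # sd_data, s_max, s_min, and p parametrize the EDM-style (Karras et al.)
  # noise schedule of the AlphaFold3 diffusion module: at inference the noise
  # level at normalized time tau in [0, 1] is
  #   sigma(tau) = sd_data * (s_max^(1/p) + tau * (s_min^(1/p) - s_max^(1/p)))^p
  # sweeping from sd_data * s_max = 2560 down to sd_data * s_min = 0.0064.
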
  confidence_head:
    c_s: ${model.model.c_token}
    c_z: ${model.model.c_pair}
    no_blocks: 4
    no_bins_pde: 64
    no_bins_plddt: 64
    no_bins_pae: 64
    c_hidden_mul: ${model.model.c_hidden_tri_mul}
    c_hidden_pair_attn: ${model.model.c_hidden_pair_attn}
    no_heads_tri_attn: ${model.model.no_heads_tri_attn}
    no_heads_single_attn: ${model.model.no_heads_single_attn}
    transition_n: ${model.model.transition_n}
    pair_dropout: ${model.model.pair_dropout}
    fuse_projection_weights: ${model.model.fuse_projection_weights}

  distogram_head:
    c_z: ${model.model.c_pair}
    no_bins: 64

# Exponential moving average decay rate
ema_decay: 0.999
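# i.e. after each optimizer step the shadow weights are updated as
#   w_ema <- 0.999 * w_ema + (1 - 0.999) * w,
# an averaging horizon of roughly 1 / (1 - 0.999) = 1000 steps.
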
globals:
  chunk_size: null  # 4
  # Use DeepSpeed memory-efficient attention kernel in supported modules.
  use_deepspeed_evo_attention: true
  samples_per_trunk: 48  # Number of diffusion module replicas per trunk
  rollout_samples_per_trunk: 1  # Number of mini rollouts per trunk
  eps: 1e-8
  # Internal precision of float32 matrix multiplications; "high" or "medium" will utilize Tensor Cores.
  matmul_precision: "high"
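# matmul_precision presumably feeds torch.set_float32_matmul_precision(); on
# Ampere-or-newer GPUs, "high" lets float32 matmuls run on TF32 Tensor Cores,
# trading a small amount of precision for throughput.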