Bug fixes for training
ardagoreci committed Jun 4, 2024
1 parent: 5432984 · commit: 956352e
Showing 12 changed files with 146 additions and 77 deletions.
4 changes: 2 additions & 2 deletions configs/data/protein.yaml
@@ -2,7 +2,7 @@ _target_: src.data.protein_datamodule.ProteinDataModule
data_dir: "./data/"
resolution_thr: 3.5 # Resolution threshold for PDB structures
min_seq_id: 0.3 # Minimum sequence identity for MMSeq2 clustering
crop_size: 256 # The number of residues to crop the proteins to.
crop_size: 384 # The number of residues to crop the proteins to.
max_length: 10_000 # Entries with total length of chains larger than max_length will be disregarded.
use_fraction: 1.0 # the fraction of the clusters to use (first N in alphabetic order)
entry_type: "chain" # { "biounit", "chain", "pair" } the type of entries to generate
@@ -14,7 +14,7 @@ mask_frac: None # if given, the number of residues to mask is mask_frac times t
mask_sequential: False # if True, the masked residues will be neighbors in the sequence; otherwise geometric mask
mask_whole_chains: False # if True, the whole chain is masked
force_binding_sites_frac: 0.15
batch_size: 8 # The batch size. Defaults to `64`.
batch_size: 2 # The batch size. Defaults to `64`.
num_workers: 7 # The number of workers. Defaults to `0`.
pin_memory: False # Whether to pin memory. Defaults to `False`.
debug: False
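
For reference, the new crop_size of 384 matches n_tokens: 384 in the new configs/model/proteus.yaml below, so each training crop fills the model's token dimension exactly; the smaller batch_size of 2 is presumably offset by the accumulate_grad_batches change in configs/trainer/gpu.yaml further down.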
49 changes: 49 additions & 0 deletions configs/model/proteus.yaml
@@ -0,0 +1,49 @@
_target_: src.models.proteus_module.ProteusLitModule

optimizer:
_target_: torch.optim.Adam
_partial_: true
lr: 0.00018 # 1.8 * 1e-3 # 0.00018
# betas: (0.9, 0.95) # TODO: problem here!
eps: 1e-08
weight_decay: 0.0
fused: false

scheduler:
_target_: torch.optim.lr_scheduler.StepLR
_partial_: true
step_size: 5 * 1e4
gamma: 0.95

diffusion_module:
_target_: src.models.diffusion_module.DiffusionModule
c_atom: 128
c_atompair: 16
c_token: 128 # original: 384
c_tokenpair: 128
n_tokens: 384
atom_encoder_blocks: 3
atom_encoder_heads: 16
dropout: 0.0
atom_attention_n_queries: 32
atom_attention_n_keys: 128
atom_decoder_blocks: 3
atom_decoder_heads: 16
token_transformer_blocks: 12 # original: 24
token_transformer_heads: 16

feature_embedder:
_target_: src.models.input_feature_embedder.ProteusFeatureEmbedder
n_tokens: 384
c_token: 128 # original: 384
c_atom: 128
c_atompair: 16
c_trunk_pair: 128
num_blocks: 3
num_heads: 4
dropout: 0.0
n_queries: 32
n_keys: 128

# compile model for faster training with pytorch 2.0
compile: false
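
For context, a minimal sketch of how the `_partial_: true` optimizer and scheduler entries above are typically resolved inside configure_optimizers(), assuming the usual lightning-hydra-template pattern; the stand-in model and the numeric step_size are illustrative (YAML reads `5 * 1e4` as a string, not 50000):

import hydra
import torch
from omegaconf import OmegaConf

optim_cfg = OmegaConf.create({
    "_target_": "torch.optim.Adam", "_partial_": True,
    "lr": 0.00018, "eps": 1e-08, "weight_decay": 0.0,
})
sched_cfg = OmegaConf.create({
    "_target_": "torch.optim.lr_scheduler.StepLR", "_partial_": True,
    "step_size": 50_000,   # illustrative numeric value; YAML loads `5 * 1e4` as a string
    "gamma": 0.95,
})

model = torch.nn.Linear(8, 8)   # stand-in for the real Proteus module
# With _partial_: true, hydra.utils.instantiate returns a functools.partial,
# which is completed later with the model parameters / the optimizer.
optimizer = hydra.utils.instantiate(optim_cfg)(params=model.parameters())
scheduler = hydra.utils.instantiate(sched_cfg)(optimizer=optimizer)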
2 changes: 1 addition & 1 deletion configs/train.yaml
@@ -6,7 +6,7 @@ defaults:
- _self_
- callbacks: default
- data: protein
- model: kestrel
- model: proteus

- logger: wandb # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
- trainer: gpu
4 changes: 2 additions & 2 deletions configs/trainer/default.yaml
@@ -16,10 +16,10 @@ precision: '32-true' # 'transformer-engine', 'transformer-engine-float16', '16-t
check_val_every_n_epoch: 1

# frequency of logging
log_every_n_steps: 100
log_every_n_steps: 10

# gradient clipping
gradient_clip_val: null
gradient_clip_val: 10.0 # gradient clipping if global norm is greater than 10

# How much of training/test/validation dataset to check.
# Useful when debugging or testing something that happens at the end of an epoch
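
For context, a small illustrative snippet of how Lightning consumes this setting; clipping by global norm is the default gradient_clip_algorithm, matching the inline comment above (whether the project imports lightning or pytorch_lightning is an assumption here):

import lightning.pytorch as pl   # or: import pytorch_lightning as pl, depending on the project

# Clip gradients when their global norm exceeds 10.0
trainer = pl.Trainer(gradient_clip_val=10.0, gradient_clip_algorithm="norm")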
4 changes: 2 additions & 2 deletions configs/trainer/gpu.yaml
@@ -4,8 +4,8 @@ defaults:
accelerator: gpu
devices: 1

precision: '32-true' # 'transformer-engine', 'transformer-engine-float16', '16-true', '16-mixed', 'bf16-true',
precision: 'bf16' # 'transformer-engine', 'transformer-engine-float16', '16-true', '16-mixed', 'bf16-true',
# 'bf16-mixed', '32-true',

# Gradient accumulation
accumulate_grad_batches: 1
accumulate_grad_batches: 4
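
Together with configs/data/protein.yaml above, the effective optimizer batch stays the same: 2 (batch_size) × 4 (accumulate_grad_batches) × 1 (device) = 8 crops per optimizer step, matching the previous batch_size of 8 while leaving memory headroom for the larger 384-residue crops.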
110 changes: 61 additions & 49 deletions src/data/protein_datamodule.py
@@ -129,52 +129,55 @@ def forward(
a dictionary of chain ids (keys are chain ids, e.g. 'A', values are the indices
used in 'chain_id' and 'chain_encoding_all' objects)
Returns:
a dictionary containing the features of AlphaFold3 containing the following elements:
"residue_index":
[n_tokens] Residue number in the token’s original input chain.
"token_index":
[n_tokens] Token number. Increases monotonically; does not restart at 1 for new chains.
"asym_id":
[n_tokens] Unique integer for each distinct chain.
"entity_id":
[n_tokens] Unique integer for each distinct entity.
"sym_id":
[N_tokens] Unique integer within chains of this sequence. E.g. if chains
A, B and C share a sequence but D does not, their sym_ids would be [0, 1, 2, 0]
"ref_pos":
[N_atoms, 3] atom positions in the reference conformers, with
a random rotation and translation applied. Atom positions in Angstroms.
"ref_mask":
[N_atoms] Mask indicating which atom slots are used in the reference
conformer.
"ref_element":
[N_atoms, 128] One-hot encoding of the element atomic number for each atom
in the reference conformer, up to atomic number 128.
"ref_charge":
[N_atoms] Charge for each atom in the reference conformer.
"ref_atom_name_chars":
[N_atom, 4, 64] One-hot encoding of the unique atom names in the reference
conformer. Each character is encoded as ord(c - 32), and names are padded to
length 4.
"ref_space_uid":
[N_atoms] Numerical encoding of the chain id and residue index associated
with this reference conformer. Each (chain id, residue index) tuple is assigned
an integer on first appearance.
"atom_to_token":
[N_atoms] Token index for each atom in the flat atom representation.
"atom_exists":
[N_atoms] binary mask for atoms, whether atom exists, used for loss masking
"token_mask":
[n_tokens] Mask indicating which tokens are non-padding tokens
"atom_mask":
[N_atoms] Mask indicating which atoms are non-padding atoms
a dictionary with the following elements:
"features":
a dictionary containing the features of AlphaFold3 containing the following elements:
"residue_index":
[n_tokens] Residue number in the token’s original input chain.
"token_index":
[n_tokens] Token number. Increases monotonically; does not restart at 1 for new chains.
"asym_id":
[n_tokens] Unique integer for each distinct chain.
"entity_id":
[n_tokens] Unique integer for each distinct entity.
"sym_id":
[N_tokens] Unique integer within chains of this sequence. E.g. if chains
A, B and C share a sequence but D does not, their sym_ids would be [0, 1, 2, 0]
"ref_pos":
[N_atoms, 3] atom positions in the reference conformers, with
a random rotation and translation applied. Atom positions in Angstroms.
"ref_mask":
[N_atoms] Mask indicating which atom slots are used in the reference
conformer.
"ref_element":
[N_atoms, 128] One-hot encoding of the element atomic number for each atom
in the reference conformer, up to atomic number 128.
"ref_charge":
[N_atoms] Charge for each atom in the reference conformer.
"ref_atom_name_chars":
[N_atom, 4, 64] One-hot encoding of the unique atom names in the reference
conformer. Each character is encoded as ord(c - 32), and names are padded to
length 4.
"ref_space_uid":
[N_atoms] Numerical encoding of the chain id and residue index associated
with this reference conformer. Each (chain id, residue index) tuple is assigned
an integer on first appearance.
"atom_to_token":
[N_atoms] Token index for each atom in the flat atom representation.
"atom_positions":
[N_atoms, 3] ground truth atom positions in Angstroms.
"atom_exists":
[N_atoms] binary mask for atoms, whether atom exists, used for loss masking
"token_mask":
[n_tokens] Mask indicating which tokens are non-padding tokens
"atom_mask":
[N_atoms] Mask indicating which atoms are non-padding atoms
TODO: this should return a dictionary of dictionaries, where batch["features"] returns the actual AF3 features
and the rest of the keys are for masks, ground truth atom positions, etc. This way, there is no danger of
information leakage and everything is more organized.
"""
total_L = protein_dict["residue_idx"].shape[0] # crop_size
masks = {
# Masks
"token_mask": protein_dict["token_mask"], # (n_tokens,)
"atom_mask": protein_dict["token_mask"].unsqueeze(-1).expand(total_L, 4).reshape(total_L * 4)
}

af3_features = {
"residue_index": protein_dict["residue_idx"],
"token_index": torch.arange(total_L, dtype=torch.float32),
@@ -190,13 +193,22 @@ def forward(
["N", "CA", "C", "O"]).unsqueeze(0).expand(total_L, 4, 4, 64).reshape(total_L * 4, 4, 64),
"ref_space_uid": protein_dict["residue_idx"].unsqueeze(-1).expand(total_L, 4).reshape(total_L * 4),
"atom_to_token": torch.arange(total_L).unsqueeze(-1).expand(total_L, 4).reshape(total_L * 4),
"atom_exists": protein_dict["mask"].unsqueeze(-1).expand(total_L, 4).reshape(total_L * 4) * masks[
"atom_mask"],

# Actual positions
"atom_positions": protein_dict["X"].reshape(total_L * 4, 3),
}
return af3_features | masks

# Compute masks
token_mask = protein_dict["token_mask"]
atom_mask = token_mask.unsqueeze(-1).expand(total_L, 4).reshape(total_L * 4)

# Final output dictionary
output_dict = {
"features": af3_features,
"atom_positions": protein_dict["X"].reshape(total_L * 4, 3).float(),
"atom_exists": protein_dict["mask"].unsqueeze(-1).expand(total_L, 4).reshape(total_L * 4) * atom_mask,
"token_mask": token_mask,
"atom_mask": atom_mask,
}
return output_dict

@staticmethod
def compute_atom_name_chars(atom_names: List[str]) -> torch.Tensor:
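
For context, a minimal illustrative sketch (toy shapes, stand-in model; not code from the repository) of how the restructured output keeps model inputs under "features" while ground truth and masks stay outside it, which is what avoids the information leakage mentioned in the old TODO:

import torch

n_tokens = 8                       # toy value; the real crops use 384 tokens
n_atoms = n_tokens * 4             # 4 backbone atoms (N, CA, C, O) per token

batch = {
    "features": {                  # model inputs only
        "residue_index": torch.arange(n_tokens, dtype=torch.float32),
        "atom_to_token": torch.arange(n_tokens).repeat_interleave(4),
    },
    "atom_positions": torch.randn(n_atoms, 3),   # ground truth, kept outside "features"
    "atom_exists": torch.ones(n_atoms),
    "token_mask": torch.ones(n_tokens),
    "atom_mask": torch.ones(n_atoms),
}

def stand_in_model(features):      # placeholder for the diffusion module
    return torch.zeros(features["atom_to_token"].shape[0], 3)

pred = stand_in_model(batch["features"])          # the ground truth never enters the model
sq_err = ((pred - batch["atom_positions"]) ** 2).sum(dim=-1)
loss = (sq_err * batch["atom_exists"]).sum() / batch["atom_exists"].sum().clamp(min=1)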
3 changes: 2 additions & 1 deletion src/diffusion/attention.py
@@ -5,6 +5,7 @@
import torch.nn.functional as F
from src.models.components.transition import ConditionedTransitionBlock
from src.models.components.primitives import AttentionPairBias
from torch.utils.checkpoint import checkpoint


class DiffusionTransformer(nn.Module):
@@ -53,6 +54,6 @@ def __init__(
def forward(self, single_repr, single_proj, pair_repr, mask=None):
"""Forward pass of the AtomTransformer module. Algorithm 23 in AlphaFold3 supplement."""
for i in range(self.num_blocks):
b = self.attention_blocks[i](single_repr, single_proj, pair_repr, mask)
b = self.attention_blocks[i](single_repr, single_proj, pair_repr, mask) # checkpoint(
single_repr = b + self.conditioned_transition_blocks[i](single_repr, single_proj)
return single_repr
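
The commented-out `checkpoint(` above, together with the new torch.utils.checkpoint import, hints at activation checkpointing that the commit leaves disabled. A minimal illustrative sketch of the technique, using a stand-in block rather than the project's attention block:

import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Linear(16, 16)                  # stand-in for an AttentionPairBias block
x = torch.randn(2, 16, requires_grad=True)

# Activations inside `block` are recomputed during backward instead of stored,
# trading extra compute for lower peak memory.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()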
29 changes: 18 additions & 11 deletions src/diffusion/conditioning.py
@@ -7,25 +7,28 @@
from src.models.components.transition import Transition
from typing import Dict, Tuple
from torch.nn import functional as F
from src.utils.tensor_utils import one_hot


class FourierEmbedding(nn.Module):
"""Fourier embedding for diffusion conditioning."""

def __init__(self, embed_dim):
super(FourierEmbedding, self).__init__()
self.embed_dim = embed_dim
# Randomly generate weight/bias once before training
self.weight = nn.Parameter(torch.randn((1, embed_dim)))
self.bias = nn.Parameter(torch.randn((1, embed_dim,)))
self.bias = nn.Parameter(torch.randn((1, embed_dim)))

def forward(self, t):
"""Compute embeddings"""
two_pi = torch.tensor(2 * math.pi, device=t.device)
two_pi = torch.tensor(2 * 3.1415, device=t.device, dtype=t.dtype)
return torch.cos(two_pi * (t * self.weight + self.bias))


class RelativePositionEncoding(nn.Module):
"""Relative position encoding for diffusion conditioning."""

def __init__(
self,
c_pair: int,
@@ -95,7 +98,7 @@ def forward(self, features: Dict[str, torch.Tensor], mask=None) -> torch.Tensor:

# Mask the output
if mask is not None:
mask = (mask[:, :, None] & mask[:, None, :]).unsqueeze(-1).float() # (bs, n_tokens, n_tokens, 1)
mask = (mask[:, :, None] * mask[:, None, :]).unsqueeze(-1) # (bs, n_tokens, n_tokens, 1)
p_ij = mask * p_ij
return p_ij

@@ -107,14 +110,18 @@ def encode(feature_tensor: torch.Tensor,
relative_dists = feature_tensor[:, None, :] - feature_tensor[:, :, None]
d_ij = torch.where(
condition_tensor,
torch.clamp(torch.add(relative_dists, clamp_max), min=0, max=2*clamp_max),
torch.full_like(relative_dists, 2*clamp_max + 1)
torch.clamp(torch.add(relative_dists, clamp_max), min=0, max=2 * clamp_max),
torch.full_like(relative_dists, 2 * clamp_max + 1)
)
return F.one_hot(d_ij, num_classes=2 * clamp_max + 2) # (bs, n_tokens, n_tokens, 2 * clamp_max + 2)
a_ij = one_hot(d_ij, v_bins=torch.arange(0, (2 * clamp_max + 2),
device=feature_tensor.device,
dtype=feature_tensor.dtype))
return a_ij # (bs, n_tokens, n_tokens, 2 * clamp_max + 2)


class DiffusionConditioning(nn.Module):
"""Diffusion conditioning module."""

def __init__(
self,
c_token: int = 384,
@@ -133,13 +140,13 @@ def __init__(

# Pair conditioning
self.relative_position_encoding = RelativePositionEncoding(c_pair)
self.pair_layer_norm = nn.LayerNorm(2*c_pair) # z_trunk + relative_position_encoding
self.linear_pair = Linear(2*c_pair, c_pair, bias=False)
self.pair_layer_norm = nn.LayerNorm(2 * c_pair) # z_trunk + relative_position_encoding
self.linear_pair = Linear(2 * c_pair, c_pair, bias=False)
self.pair_transitions = nn.ModuleList([Transition(input_dim=c_pair, n=2) for _ in range(2)])

# Single conditioning
self.single_layer_norm = nn.LayerNorm(2*c_token) # s_trunk + s_inputs
self.linear_single = Linear(2*c_token, c_token, bias=False)
self.single_layer_norm = nn.LayerNorm(2 * c_token) # s_trunk + s_inputs
self.linear_single = Linear(2 * c_token, c_token, bias=False)
self.fourier_embedding = FourierEmbedding(embed_dim=256) # 256 is the default value in the paper
self.fourier_layer_norm = nn.LayerNorm(256)
self.linear_fourier = Linear(256, c_token, bias=False)
@@ -201,7 +208,7 @@ def forward(
# Mask outputs
if mask is not None:
token_repr = mask.unsqueeze(-1) * token_repr
pair_mask = (mask[:, :, None] & mask[:, None, :]).unsqueeze(-1).float() # (bs, n_tokens, n_tokens, 1)
pair_mask = (mask[:, :, None] * mask[:, None, :]).unsqueeze(-1) # (bs, n_tokens, n_tokens, 1)
pair_repr = pair_mask * pair_repr

return token_repr, pair_repr
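
For context on the switch from F.one_hot to the project's one_hot with v_bins: torch.nn.functional.one_hot only accepts int64 index tensors, whereas d_ij here is built from float feature tensors. An illustrative sketch of a bin-based encoding with the same result for integer-valued inputs (the exact semantics of src.utils.tensor_utils.one_hot are assumed, since they are not shown in this diff):

import torch
import torch.nn.functional as F

d_ij = torch.tensor([[0.0, 3.0, 7.0]])          # float-valued clamped relative distances
v_bins = torch.arange(0, 8, dtype=d_ij.dtype)

# F.one_hot(d_ij, num_classes=8) would raise a RuntimeError (index tensor required);
# comparing against v_bins instead keeps the float dtype throughout.
a_ij = (d_ij.unsqueeze(-1) == v_bins).to(d_ij.dtype)
print(a_ij.shape)                                # torch.Size([1, 3, 8])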
4 changes: 2 additions & 2 deletions src/diffusion/loss.py
@@ -1,4 +1,4 @@
"""Diffusion losses"""
"""Diffusion losses."""

import torch
from src.utils.geometry.vector import Vec3Array, square_euclidean_distance, euclidean_distance
@@ -26,7 +26,7 @@ def smooth_lddt_loss(
F.sigmoid(torch.sub(2.0, delta_lm)) + F.sigmoid(torch.sub(4.0, delta_lm))), 4.0)

# Restrict to bespoke inclusion radius
atom_is_nucleotide = atom_is_dna + atom_is_rna
atom_is_nucleotide = (atom_is_dna + atom_is_rna).unsqueeze(-1).expand_as(delta_x_gt_lm)
atom_not_nucleotide = torch.add(torch.neg(atom_is_nucleotide), 1.0) # (1 - atom_is_nucleotide)
c_lm = (delta_x_gt_lm < 30.0).float() * atom_is_nucleotide + (delta_x_gt_lm < 15.0).float() * atom_not_nucleotide

8 changes: 5 additions & 3 deletions src/models/components/atom_attention.py
@@ -14,6 +14,7 @@
from src.utils.tensor_utils import partition_tensor
from src.utils.geometry.vector import Vec3Array
from typing import Dict, Tuple, NamedTuple
from torch.utils.checkpoint import checkpoint


def _split_heads(x, n_heads):
@@ -290,8 +291,9 @@ def __init__(
def forward(self, atom_single_repr, atom_single_proj, atom_pair_repr, mask=None):
"""Forward pass of the AtomTransformer module. Algorithm 23 in AlphaFold3 supplement."""
for i in range(self.num_blocks):
b = self.attention_blocks[i](atom_single_repr, atom_single_proj, atom_pair_repr, mask)
b = self.attention_blocks[i](atom_single_repr, atom_single_proj, atom_pair_repr, mask) # checkpoint()
atom_single_repr = b + self.conditioned_transition_blocks[i](atom_single_repr, atom_single_proj)
# checkpoint(
return atom_single_repr


@@ -372,7 +374,7 @@ def aggregate_atom_to_token(
bs, n_atoms, c_atom = atom_representation.shape

# Initialize the token representation tensor with zeros
token_representation = torch.zeros(bs, n_tokens, c_atom,
token_representation = torch.zeros((bs, n_tokens, c_atom),
device=atom_representation.device,
dtype=atom_representation.dtype)

@@ -680,7 +682,7 @@ def __init__(
)

self.linear_atom = Linear(c_token, c_atom, init='default', bias=False)
self.linear_update = Linear(c_atom, 3, init='default', bias=False)
self.linear_update = Linear(c_atom, 3, init='final', bias=False)
self.layer_norm = nn.LayerNorm(c_atom)

def forward(
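
The init='default' to init='final' change on linear_update presumably zero-initialises the weights of the layer producing the 3-D position update, as the 'final' scheme does in OpenFold-style Linear layers. An illustrative sketch of the effect using a plain nn.Linear:

import torch
from torch import nn

linear_update = nn.Linear(128, 3, bias=False)
nn.init.zeros_(linear_update.weight)      # what init='final' is assumed to do

x = torch.randn(10, 128)
print(linear_update(x).abs().max())       # tensor(0.) -> position updates start at zero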
2 changes: 1 addition & 1 deletion src/models/components/primitives.py
@@ -395,7 +395,7 @@ def _slice_bias(b):

def compute_pair_attention_mask(mask, large_number=-1e6):
# Compute boolean pair mask
pair_mask = (mask[:, :, None] & mask[:, None, :]).unsqueeze(-1).float() # (bs, n, n, 1)
pair_mask = (mask[:, :, None] * mask[:, None, :]).unsqueeze(-1) # (bs, n, n, 1)

# Invert such that 0.0 indicates attention, 1.0 indicates no attention
pair_mask_inv = torch.add(1, -pair_mask)
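
An illustrative note on the & to * change here and in conditioning.py above: bitwise AND requires bool or integer tensors, so it would fail if the masks arrive as floats (which the dropped .float() cast suggests), while multiplication gives the same {0, 1} result and already returns a float tensor.

import torch

mask = torch.tensor([[1.0, 1.0, 0.0]])                              # float mask, (bs, n)
pair_mask = (mask[:, :, None] * mask[:, None, :]).unsqueeze(-1)     # (bs, n, n, 1), float
# mask[:, :, None] & mask[:, None, :]  -> RuntimeError on float tensors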