diff --git a/configs/data/protein.yaml b/configs/data/protein.yaml
index 1cb459e..f5351e9 100644
--- a/configs/data/protein.yaml
+++ b/configs/data/protein.yaml
@@ -2,7 +2,7 @@ _target_: src.data.protein_datamodule.ProteinDataModule
 data_dir: "./data/"
 resolution_thr: 3.5 # Resolution threshold for PDB structures
 min_seq_id: 0.3 # Minimum sequence identity for MMSeq2 clustering
-crop_size: 256 # The number of residues to crop the proteins to.
+crop_size: 384 # The number of residues to crop the proteins to.
 max_length: 10_000 # Entries with total length of chains larger than max_length will be disregarded.
 use_fraction: 1.0 # the fraction of the clusters to use (first N in alphabetic order)
 entry_type: "chain" # { "biounit", "chain", "pair" } the type of entries to generate
@@ -14,7 +14,7 @@ mask_frac: None # if given, the number of residues to mask is mask_frac times t
 mask_sequential: False # if True, the masked residues will be neighbors in the sequence; otherwise geometric mask
 mask_whole_chains: False # if True, the whole chain is masked
 force_binding_sites_frac: 0.15
-batch_size: 8 # The batch size. Defaults to `64`.
+batch_size: 2 # The batch size. Defaults to `64`.
 num_workers: 7 # The number of workers. Defaults to `0`.
 pin_memory: False # Whether to pin memory. Defaults to `False`.
 debug: False
diff --git a/configs/model/proteus.yaml b/configs/model/proteus.yaml
new file mode 100644
index 0000000..bfecd76
--- /dev/null
+++ b/configs/model/proteus.yaml
@@ -0,0 +1,49 @@
+_target_: src.models.proteus_module.ProteusLitModule
+
+optimizer:
+  _target_: torch.optim.Adam
+  _partial_: true
+  lr: 0.00018 # 1.8e-4
+  # betas: [0.9, 0.95] # TODO: YAML has no tuple literal, so pass betas as a list when enabling this
+  eps: 1e-08
+  weight_decay: 0.0
+  fused: false
+
+scheduler:
+  _target_: torch.optim.lr_scheduler.StepLR
+  _partial_: true
+  step_size: 50_000 # 5e4; YAML does not evaluate expressions such as "5 * 1e4"
+  gamma: 0.95
+
+diffusion_module:
+  _target_: src.models.diffusion_module.DiffusionModule
+  c_atom: 128
+  c_atompair: 16
+  c_token: 128 # original: 384
+  c_tokenpair: 128
+  n_tokens: 384
+  atom_encoder_blocks: 3
+  atom_encoder_heads: 16
+  dropout: 0.0
+  atom_attention_n_queries: 32
+  atom_attention_n_keys: 128
+  atom_decoder_blocks: 3
+  atom_decoder_heads: 16
+  token_transformer_blocks: 12 # original: 24
+  token_transformer_heads: 16
+
+feature_embedder:
+  _target_: src.models.input_feature_embedder.ProteusFeatureEmbedder
+  n_tokens: 384
+  c_token: 128 # original: 384
+  c_atom: 128
+  c_atompair: 16
+  c_trunk_pair: 128
+  num_blocks: 3
+  num_heads: 4
+  dropout: 0.0
+  n_queries: 32
+  n_keys: 128
+
+# compile model for faster training with pytorch 2.0
+compile: false
\ No newline at end of file
diff --git a/configs/train.yaml b/configs/train.yaml
index cbcf0b8..d9b86f6 100644
--- a/configs/train.yaml
+++ b/configs/train.yaml
@@ -6,7 +6,7 @@ defaults:
   - _self_
   - callbacks: default
   - data: protein
-  - model: kestrel
+  - model: proteus
   - logger: wandb # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
   - trainer: gpu
diff --git a/configs/trainer/default.yaml b/configs/trainer/default.yaml
index f9de697..ca250cd 100644
--- a/configs/trainer/default.yaml
+++ b/configs/trainer/default.yaml
@@ -16,10 +16,10 @@ precision: '32-true' # 'transformer-engine', 'transformer-engine-float16', '16-t
 check_val_every_n_epoch: 1

 # frequency of logging
-log_every_n_steps: 100
+log_every_n_steps: 10

 # gradient clipping
-gradient_clip_val: null
+gradient_clip_val: 10.0 # clip gradients when the global norm exceeds 10

 # How much of training/test/validation dataset to check.
 # Useful when debugging or testing something that happens at the end of an epoch
diff --git a/configs/trainer/gpu.yaml b/configs/trainer/gpu.yaml
index 65e8de1..6c2e96f 100644
--- a/configs/trainer/gpu.yaml
+++ b/configs/trainer/gpu.yaml
@@ -4,8 +4,8 @@ defaults:
 accelerator: gpu
 devices: 1

-precision: '32-true' # 'transformer-engine', 'transformer-engine-float16', '16-true', '16-mixed', 'bf16-true',
+precision: 'bf16' # 'transformer-engine', 'transformer-engine-float16', '16-true', '16-mixed', 'bf16-true',
 # 'bf16-mixed', '32-true',

 # Gradient accumulation
-accumulate_grad_batches: 1
\ No newline at end of file
+accumulate_grad_batches: 4
\ No newline at end of file
diff --git a/src/data/protein_datamodule.py b/src/data/protein_datamodule.py
index 44ae380..240cb5a 100644
--- a/src/data/protein_datamodule.py
+++ b/src/data/protein_datamodule.py
@@ -129,52 +129,55 @@ def forward(
                 a dictionary of chain ids (keys are chain ids, e.g. 'A', values are the indices used in
                 'chain_id' and 'chain_encoding_all' objects)
         Returns:
-            a dictionary containing the features of AlphaFold3 containing the following elements:
-            "residue_index":
-                [n_tokens] Residue number in the token’s original input chain.
-            "token_index":
-                [n_tokens] Token number. Increases monotonically; does not restart at 1 for new chains.
-            "asym_id":
-                [n_tokens] Unique integer for each distinct chain.
-            "entity_id":
-                [n_tokens] Unique integer for each distinct entity.
-            "sym_id":
-                [N_tokens] Unique integer within chains of this sequence. E.g. if chains
-                A, B and C share a sequence but D does not, their sym_ids would be [0, 1, 2, 0]
-            "ref_pos":
-                [N_atoms, 3] atom positions in the reference conformers, with
-                a random rotation and translation applied. Atom positions in Angstroms.
-            "ref_mask":
-                [N_atoms] Mask indicating which atom slots are used in the reference
-                conformer.
-            "ref_element":
-                [N_atoms, 128] One-hot encoding of the element atomic number for each atom
-                in the reference conformer, up to atomic number 128.
-            "ref_charge":
-                [N_atoms] Charge for each atom in the reference conformer.
-            "ref_atom_name_chars":
-                [N_atom, 4, 64] One-hot encoding of the unique atom names in the reference
-                conformer. Each character is encoded as ord(c - 32), and names are padded to
-                length 4.
-            "ref_space_uid":
-                [N_atoms] Numerical encoding of the chain id and residue index associated
-                with this reference conformer. Each (chain id, residue index) tuple is assigned
-                an integer on first appearance.
-            "atom_to_token":
-                [N_atoms] Token index for each atom in the flat atom representation.
- "atom_exists": - [N_atoms] binary mask for atoms, whether atom exists, used for loss masking - "token_mask": - [n_tokens] Mask indicating which tokens are non-padding tokens - "atom_mask": - [N_atoms] Mask indicating which atoms are non-padding atoms + a dictionary with the following elements: + "features": + a dictionary containing the features of AlphaFold3 containing the following elements: + "residue_index": + [n_tokens] Residue number in the token’s original input chain. + "token_index": + [n_tokens] Token number. Increases monotonically; does not restart at 1 for new chains. + "asym_id": + [n_tokens] Unique integer for each distinct chain. + "entity_id": + [n_tokens] Unique integer for each distinct entity. + "sym_id": + [N_tokens] Unique integer within chains of this sequence. E.g. if chains + A, B and C share a sequence but D does not, their sym_ids would be [0, 1, 2, 0] + "ref_pos": + [N_atoms, 3] atom positions in the reference conformers, with + a random rotation and translation applied. Atom positions in Angstroms. + "ref_mask": + [N_atoms] Mask indicating which atom slots are used in the reference + conformer. + "ref_element": + [N_atoms, 128] One-hot encoding of the element atomic number for each atom + in the reference conformer, up to atomic number 128. + "ref_charge": + [N_atoms] Charge for each atom in the reference conformer. + "ref_atom_name_chars": + [N_atom, 4, 64] One-hot encoding of the unique atom names in the reference + conformer. Each character is encoded as ord(c - 32), and names are padded to + length 4. + "ref_space_uid": + [N_atoms] Numerical encoding of the chain id and residue index associated + with this reference conformer. Each (chain id, residue index) tuple is assigned + an integer on first appearance. + "atom_to_token": + [N_atoms] Token index for each atom in the flat atom representation. + "atom_positions": + [N_atoms, 3] ground truth atom positions in Angstroms. + "atom_exists": + [N_atoms] binary mask for atoms, whether atom exists, used for loss masking + "token_mask": + [n_tokens] Mask indicating which tokens are non-padding tokens + "atom_mask": + [N_atoms] Mask indicating which atoms are non-padding atoms + TODO: this should return a dictionary of dictionaries, where batch["features"] returns the actual AF3 features + and the rest of the keys are for masks, ground truth atom positions, etc. This way, there is no danger of + information leakage and everything is more organized. 
""" total_L = protein_dict["residue_idx"].shape[0] # crop_size - masks = { - # Masks - "token_mask": protein_dict["token_mask"], # (n_tokens,) - "atom_mask": protein_dict["token_mask"].unsqueeze(-1).expand(total_L, 4).reshape(total_L * 4) - } + af3_features = { "residue_index": protein_dict["residue_idx"], "token_index": torch.arange(total_L, dtype=torch.float32), @@ -190,13 +193,22 @@ def forward( ["N", "CA", "C", "O"]).unsqueeze(0).expand(total_L, 4, 4, 64).reshape(total_L * 4, 4, 64), "ref_space_uid": protein_dict["residue_idx"].unsqueeze(-1).expand(total_L, 4).reshape(total_L * 4), "atom_to_token": torch.arange(total_L).unsqueeze(-1).expand(total_L, 4).reshape(total_L * 4), - "atom_exists": protein_dict["mask"].unsqueeze(-1).expand(total_L, 4).reshape(total_L * 4) * masks[ - "atom_mask"], - # Actual positions - "atom_positions": protein_dict["X"].reshape(total_L * 4, 3), } - return af3_features | masks + + # Compute masks + token_mask = protein_dict["token_mask"] + atom_mask = token_mask.unsqueeze(-1).expand(total_L, 4).reshape(total_L * 4) + + # Final output dictionary + output_dict = { + "features": af3_features, + "atom_positions": protein_dict["X"].reshape(total_L * 4, 3).float(), + "atom_exists": protein_dict["mask"].unsqueeze(-1).expand(total_L, 4).reshape(total_L * 4) * atom_mask, + "token_mask": token_mask, + "atom_mask": atom_mask, + } + return output_dict @staticmethod def compute_atom_name_chars(atom_names: List[str]) -> torch.Tensor: diff --git a/src/diffusion/attention.py b/src/diffusion/attention.py index fd61985..93c292a 100644 --- a/src/diffusion/attention.py +++ b/src/diffusion/attention.py @@ -5,6 +5,7 @@ import torch.nn.functional as F from src.models.components.transition import ConditionedTransitionBlock from src.models.components.primitives import AttentionPairBias +from torch.utils.checkpoint import checkpoint class DiffusionTransformer(nn.Module): @@ -53,6 +54,6 @@ def __init__( def forward(self, single_repr, single_proj, pair_repr, mask=None): """Forward pass of the AtomTransformer module. 
Algorithm 23 in AlphaFold3 supplement.""" for i in range(self.num_blocks): - b = self.attention_blocks[i](single_repr, single_proj, pair_repr, mask) + b = self.attention_blocks[i](single_repr, single_proj, pair_repr, mask) # checkpoint( single_repr = b + self.conditioned_transition_blocks[i](single_repr, single_proj) return single_repr diff --git a/src/diffusion/conditioning.py b/src/diffusion/conditioning.py index 96fb788..5b23a8b 100644 --- a/src/diffusion/conditioning.py +++ b/src/diffusion/conditioning.py @@ -7,25 +7,28 @@ from src.models.components.transition import Transition from typing import Dict, Tuple from torch.nn import functional as F +from src.utils.tensor_utils import one_hot class FourierEmbedding(nn.Module): """Fourier embedding for diffusion conditioning.""" + def __init__(self, embed_dim): super(FourierEmbedding, self).__init__() self.embed_dim = embed_dim # Randomly generate weight/bias once before training self.weight = nn.Parameter(torch.randn((1, embed_dim))) - self.bias = nn.Parameter(torch.randn((1, embed_dim,))) + self.bias = nn.Parameter(torch.randn((1, embed_dim))) def forward(self, t): """Compute embeddings""" - two_pi = torch.tensor(2 * math.pi, device=t.device) + two_pi = torch.tensor(2 * 3.1415, device=t.device, dtype=t.dtype) return torch.cos(two_pi * (t * self.weight + self.bias)) class RelativePositionEncoding(nn.Module): """Relative position encoding for diffusion conditioning.""" + def __init__( self, c_pair: int, @@ -95,7 +98,7 @@ def forward(self, features: Dict[str, torch.Tensor], mask=None) -> torch.Tensor: # Mask the output if mask is not None: - mask = (mask[:, :, None] & mask[:, None, :]).unsqueeze(-1).float() # (bs, n_tokens, n_tokens, 1) + mask = (mask[:, :, None] * mask[:, None, :]).unsqueeze(-1) # (bs, n_tokens, n_tokens, 1) p_ij = mask * p_ij return p_ij @@ -107,14 +110,18 @@ def encode(feature_tensor: torch.Tensor, relative_dists = feature_tensor[:, None, :] - feature_tensor[:, :, None] d_ij = torch.where( condition_tensor, - torch.clamp(torch.add(relative_dists, clamp_max), min=0, max=2*clamp_max), - torch.full_like(relative_dists, 2*clamp_max + 1) + torch.clamp(torch.add(relative_dists, clamp_max), min=0, max=2 * clamp_max), + torch.full_like(relative_dists, 2 * clamp_max + 1) ) - return F.one_hot(d_ij, num_classes=2 * clamp_max + 2) # (bs, n_tokens, n_tokens, 2 * clamp_max + 2) + a_ij = one_hot(d_ij, v_bins=torch.arange(0, (2 * clamp_max + 2), + device=feature_tensor.device, + dtype=feature_tensor.dtype)) + return a_ij # (bs, n_tokens, n_tokens, 2 * clamp_max + 2) class DiffusionConditioning(nn.Module): """Diffusion conditioning module.""" + def __init__( self, c_token: int = 384, @@ -133,13 +140,13 @@ def __init__( # Pair conditioning self.relative_position_encoding = RelativePositionEncoding(c_pair) - self.pair_layer_norm = nn.LayerNorm(2*c_pair) # z_trunk + relative_position_encoding - self.linear_pair = Linear(2*c_pair, c_pair, bias=False) + self.pair_layer_norm = nn.LayerNorm(2 * c_pair) # z_trunk + relative_position_encoding + self.linear_pair = Linear(2 * c_pair, c_pair, bias=False) self.pair_transitions = nn.ModuleList([Transition(input_dim=c_pair, n=2) for _ in range(2)]) # Single conditioning - self.single_layer_norm = nn.LayerNorm(2*c_token) # s_trunk + s_inputs - self.linear_single = Linear(2*c_token, c_token, bias=False) + self.single_layer_norm = nn.LayerNorm(2 * c_token) # s_trunk + s_inputs + self.linear_single = Linear(2 * c_token, c_token, bias=False) self.fourier_embedding = FourierEmbedding(embed_dim=256) # 256 is 
the default value in the paper self.fourier_layer_norm = nn.LayerNorm(256) self.linear_fourier = Linear(256, c_token, bias=False) @@ -201,7 +208,7 @@ def forward( # Mask outputs if mask is not None: token_repr = mask.unsqueeze(-1) * token_repr - pair_mask = (mask[:, :, None] & mask[:, None, :]).unsqueeze(-1).float() # (bs, n_tokens, n_tokens, 1) + pair_mask = (mask[:, :, None] * mask[:, None, :]).unsqueeze(-1) # (bs, n_tokens, n_tokens, 1) pair_repr = pair_mask * pair_repr return token_repr, pair_repr diff --git a/src/diffusion/loss.py b/src/diffusion/loss.py index 1569601..6a32b5a 100644 --- a/src/diffusion/loss.py +++ b/src/diffusion/loss.py @@ -1,4 +1,4 @@ -"""Diffusion losses""" +"""Diffusion losses.""" import torch from src.utils.geometry.vector import Vec3Array, square_euclidean_distance, euclidean_distance @@ -26,7 +26,7 @@ def smooth_lddt_loss( F.sigmoid(torch.sub(2.0, delta_lm)) + F.sigmoid(torch.sub(4.0, delta_lm))), 4.0) # Restrict to bespoke inclusion radius - atom_is_nucleotide = atom_is_dna + atom_is_rna + atom_is_nucleotide = (atom_is_dna + atom_is_rna).unsqueeze(-1).expand_as(delta_x_gt_lm) atom_not_nucleotide = torch.add(torch.neg(atom_is_nucleotide), 1.0) # (1 - atom_is_nucleotide) c_lm = (delta_x_gt_lm < 30.0).float() * atom_is_nucleotide + (delta_x_gt_lm < 15.0).float() * atom_not_nucleotide diff --git a/src/models/components/atom_attention.py b/src/models/components/atom_attention.py index 1cd6cb7..9c06ed0 100644 --- a/src/models/components/atom_attention.py +++ b/src/models/components/atom_attention.py @@ -14,6 +14,7 @@ from src.utils.tensor_utils import partition_tensor from src.utils.geometry.vector import Vec3Array from typing import Dict, Tuple, NamedTuple +from torch.utils.checkpoint import checkpoint def _split_heads(x, n_heads): @@ -290,8 +291,9 @@ def __init__( def forward(self, atom_single_repr, atom_single_proj, atom_pair_repr, mask=None): """Forward pass of the AtomTransformer module. 
Algorithm 23 in AlphaFold3 supplement.""" for i in range(self.num_blocks): - b = self.attention_blocks[i](atom_single_repr, atom_single_proj, atom_pair_repr, mask) + b = self.attention_blocks[i](atom_single_repr, atom_single_proj, atom_pair_repr, mask) # checkpoint() atom_single_repr = b + self.conditioned_transition_blocks[i](atom_single_repr, atom_single_proj) + # checkpoint( return atom_single_repr @@ -372,7 +374,7 @@ def aggregate_atom_to_token( bs, n_atoms, c_atom = atom_representation.shape # Initialize the token representation tensor with zeros - token_representation = torch.zeros(bs, n_tokens, c_atom, + token_representation = torch.zeros((bs, n_tokens, c_atom), device=atom_representation.device, dtype=atom_representation.dtype) @@ -680,7 +682,7 @@ def __init__( ) self.linear_atom = Linear(c_token, c_atom, init='default', bias=False) - self.linear_update = Linear(c_atom, 3, init='default', bias=False) + self.linear_update = Linear(c_atom, 3, init='final', bias=False) self.layer_norm = nn.LayerNorm(c_atom) def forward( diff --git a/src/models/components/primitives.py b/src/models/components/primitives.py index 225d2fa..c9e5861 100644 --- a/src/models/components/primitives.py +++ b/src/models/components/primitives.py @@ -395,7 +395,7 @@ def _slice_bias(b): def compute_pair_attention_mask(mask, large_number=-1e6): # Compute boolean pair mask - pair_mask = (mask[:, :, None] & mask[:, None, :]).unsqueeze(-1).float() # (bs, n, n, 1) + pair_mask = (mask[:, :, None] * mask[:, None, :]).unsqueeze(-1) # (bs, n, n, 1) # Invert such that 0.0 indicates attention, 1.0 indicates no attention pair_mask_inv = torch.add(1, -pair_mask) diff --git a/src/models/diffusion_module.py b/src/models/diffusion_module.py index 5a10a34..81134a8 100644 --- a/src/models/diffusion_module.py +++ b/src/models/diffusion_module.py @@ -22,7 +22,7 @@ class DiffusionModule(torch.nn.Module): def __init__( self, c_atom: int = 128, - c_atompair=16, + c_atompair: int = 16, c_token: int = 768, c_tokenpair: int = 128, n_tokens: int = 384, @@ -133,8 +133,6 @@ def forward( an integer on first appearance. "atom_to_token": [*, N_atoms] Token index for each atom in the flat atom representation. - "atom_exists": - [*, N_atoms] binary mask for atoms, whether atom exists, used for loss masking "residue_index": [*, N_tokens] Residue number in the token’s original input chain. "token_index":
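The attention.py and atom_attention.py hunks above import torch.utils.checkpoint.checkpoint but leave the actual call as a commented-out TODO. A minimal sketch of how the wrapper could be applied to one of these blocks, assuming a PyTorch release that supports the use_reentrant keyword; the helper name _maybe_checkpoint is illustrative and not part of the patch:

```python
# Illustrative sketch only (not part of the patch): wrapping an attention block
# in torch.utils.checkpoint.checkpoint so its activations are recomputed in the
# backward pass instead of being stored.
import torch
from torch.utils.checkpoint import checkpoint


def _maybe_checkpoint(block, *args, enabled: bool = True):
    """Run `block(*args)`, optionally under activation checkpointing."""
    if enabled and torch.is_grad_enabled():
        return checkpoint(block, *args, use_reentrant=False)
    return block(*args)


# Inside DiffusionTransformer.forward this would replace the direct call:
#   b = _maybe_checkpoint(self.attention_blocks[i], single_repr, single_proj, pair_repr, mask)
```

Checkpointing trades extra compute in the backward pass for a lower activation-memory footprint, which is presumably the motivation behind these imports given the larger crop_size of 384.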