Structured AtomAttentionEncoder outputs
ardagoreci committed May 30, 2024
1 parent 7142e6e commit 9ad5da4
Showing 23 changed files with 174 additions and 70 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -87,7 +87,7 @@ ipython_config.py
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
-# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# having no cross-platform support, pipenv may install dependencies that don'timesteps work, or not
# install all needed dependencies.
#Pipfile.lock

2 changes: 1 addition & 1 deletion README.md
@@ -573,7 +573,7 @@ defaults:
- hparams_search: null

# optional local config for machine/user specific settings
-# it's optional since it doesn't need to exist and is excluded from version control
+# it's optional since it doesn'timesteps need to exist and is excluded from version control
- optional local: default.yaml

# debugging config (enable through command line, e.g. `python train.py debug=default)
2 changes: 1 addition & 1 deletion configs/callbacks/early_stopping.yaml
@@ -12,4 +12,4 @@ early_stopping:
stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold
divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold
check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch
-# log_rank_zero_only: False # this keyword argument isn't available in stable version
+# log_rank_zero_only: False # this keyword argument isn'timesteps available in stable version
6 changes: 3 additions & 3 deletions configs/debug/default.yaml
@@ -26,10 +26,10 @@ hydra:

trainer:
max_epochs: 1
-accelerator: cpu # debuggers don't like gpus
-devices: 1 # debuggers don't like multiprocessing
+accelerator: cpu # debuggers don'timesteps like gpus
+devices: 1 # debuggers don'timesteps like multiprocessing
detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor

data:
-num_workers: 0 # debuggers don't like multiprocessing
+num_workers: 0 # debuggers don'timesteps like multiprocessing
pin_memory: False # disable gpu memory pin
2 changes: 1 addition & 1 deletion configs/train.yaml
@@ -22,7 +22,7 @@ defaults:
- hparams_search: null

# optional local config for machine/user specific settings
-# it's optional since it doesn't need to exist and is excluded from version control
+# it's optional since it doesn'timesteps need to exist and is excluded from version control
- optional local: default

# debugging config (enable through command line, e.g. `python train.py debug=default)
2 changes: 1 addition & 1 deletion scripts/utils.py
@@ -10,7 +10,7 @@ def get_nvidia_cc():
installed in the system (formatted as a tuple of strings) and an error
message. When the former is provided, the latter is None, and vice versa.
-Adapted from script by Jan Schlüte t
+Adapted from script by Jan Schlüte timesteps
https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
"""
CUDA_SUCCESS = 0
8 changes: 4 additions & 4 deletions src/common/residue_constants.py
@@ -29,7 +29,7 @@
ca_ca = 3.80209737096

# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in
-# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have
+# this order (or a relevant subset from chi1 onwards). ALA and GLY don'timesteps have
# chi angles so their chi angle lists are empty.
chi_angles_atoms = {
'ALA': [],
@@ -553,7 +553,7 @@ def sequence_to_onehot(
sequence: An amino acid sequence.
mapping: A dictionary mapping amino acids to integers.
map_unknown_to_x: If True, any amino acid that is not in the mapping will be
-mapped to the unknown amino acid 'X'. If the mapping doesn't contain
+mapped to the unknown amino acid 'X'. If the mapping doesn'timesteps contain
amino acid 'X', an error will be thrown. If False, any amino acid not in
the mapping will throw an error.
@@ -562,7 +562,7 @@
the sequence.
Raises:
-ValueError: If the mapping doesn't contain values from 0 to
+ValueError: If the mapping doesn'timesteps contain values from 0 to
num_unique_aas - 1 without any gaps.
"""
num_entries = max(mapping.values()) + 1
@@ -642,7 +642,7 @@ def atom_id_to_type(atom_id: str) -> str:
# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple
# 1-to-1 mapping of 3 letter names to one letter names. The latter contains
# many more, and less common, three letter names as keys and maps many of these
-# to the same one letter name (including 'X' and 'U' which we don't use here).
+# to the same one letter name (including 'X' and 'U' which we don'timesteps use here).
restype_3to1 = {v: k for k, v in restype_1to3.items()}

# Define a restype name for all unknown residues.
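The sequence_to_onehot docstring above describes a mapping contract without showing it in action. The helper below is a simplified stand-in written only to illustrate the documented behaviour and error condition; it is not the repository's implementation, and the name and toy mapping are made up:

    import numpy as np

    def sequence_to_onehot_sketch(sequence: str, mapping: dict, map_unknown_to_x: bool = False) -> np.ndarray:
        """Toy illustration of the documented contract of sequence_to_onehot."""
        num_entries = max(mapping.values()) + 1
        if sorted(set(mapping.values())) != list(range(num_entries)):
            raise ValueError("The mapping must contain values from 0 to num_unique_aas - 1 without any gaps.")
        one_hot = np.zeros((len(sequence), num_entries), dtype=np.int32)
        for i, aa in enumerate(sequence):
            if map_unknown_to_x and aa not in mapping:
                aa = 'X'  # falls back to 'X'; raises a KeyError if 'X' is absent from the mapping
            one_hot[i, mapping[aa]] = 1
        return one_hot

    toy_mapping = {'A': 0, 'R': 1, 'N': 2, 'X': 3}
    print(sequence_to_onehot_sketch("ARZ", toy_mapping, map_unknown_to_x=True).tolist())
    # [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]]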
2 changes: 1 addition & 1 deletion src/data/components/protein_dataset.py
@@ -127,7 +127,7 @@ def __init__(
use_fraction : float, default 1
the fraction of the clusters to use (first N in alphabetic order)
load_to_ram : bool, default False
-if `True`, the data will be stored in RAM (use with caution! if RAM isn't big enough the machine might crash)
+if `True`, the data will be stored in RAM (use with caution! if RAM isn'timesteps big enough the machine might crash)
debug : bool, default False
only process 1000 files
interpolate : {"none", "only_middle", "all"}
1 change: 0 additions & 1 deletion src/diffusion/augmentation.py
@@ -1,5 +1,4 @@
"""Data augmentations applied prior to sampling from the diffusion trajectory."""
import torch
from src.utils.geometry.vector import Vec3Array
from src.utils.geometry.rotation_matrix import Rot3Array

10 changes: 5 additions & 5 deletions src/diffusion/conditioning.py
@@ -139,16 +139,16 @@ def __init__(

def forward(
self,
-t: torch.Tensor, # timestep (bs, 1)
+timesteps: torch.Tensor, # timestep (bs, 1)
features: Dict[str, torch.Tensor], # input feature dict
s_inputs: torch.Tensor, # (bs, n_tokens, c_token)
s_trunk: torch.Tensor, # (bs, n_tokens, c_token)
z_trunk: torch.Tensor, # (bs, n_tokens, n_tokens, c_pair)
-sd_data: torch.Tensor # standard dev of data (bs, 1)
+sd_data: float = 16.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Diffusion conditioning.
Args:
-t:
+timesteps:
[*, 1] timestep tensor
features:
input feature dictionary for the RelativePositionEncoding containing:
@@ -170,7 +170,7 @@ def forward(
z_trunk:
[*, n_tokens, n_tokens, c_pair] Pair conditioning from Pairformer trunk
sd_data:
-[*, 1] Standard deviation of the data
+Scaling factor for the timesteps before fourier embedding
"""
# Pair conditioning
pair_repr = torch.cat([z_trunk, self.relative_position_encoding(features)], dim=-1)
@@ -181,7 +181,7 @@ def forward(
# Single conditioning
token_repr = torch.cat([s_trunk, s_inputs], dim=-1)
token_repr = self.linear_single(self.single_layer_norm(token_repr))
-fourier_repr = self.fourier_embedding(torch.log(t / sd_data) / 4.0)
+fourier_repr = self.fourier_embedding(torch.log(torch.div(torch.div(timesteps, sd_data), 4.0)))
fourier_repr = self.linear_fourier(self.fourier_layer_norm(fourier_repr))
token_repr = token_repr + fourier_repr.unsqueeze(1)
for transition in self.single_transitions:
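For readers following the conditioning change above: the timestep tensor is now scaled by a float sd_data before the Fourier embedding. Below is a self-contained sketch of that path; FourierEmbeddingSketch is a stand-in module with an assumed cosine form, not the repository's class, and the batch size, noise level, and embedding width are arbitrary:

    import math
    import torch

    class FourierEmbeddingSketch(torch.nn.Module):
        """Stand-in Fourier feature embedding: cos(2*pi*(t*w + b)) with fixed random w, b."""
        def __init__(self, embed_dim: int = 256):
            super().__init__()
            self.register_buffer("w", torch.randn(embed_dim))
            self.register_buffer("b", torch.rand(embed_dim))

        def forward(self, t: torch.Tensor) -> torch.Tensor:  # t: (bs, 1)
            return torch.cos(2 * math.pi * (t * self.w + self.b))  # (bs, embed_dim)

    timesteps = torch.full((2, 1), 160.0)  # noise levels for a batch of two
    sd_data = 16.0
    scaled = torch.log(torch.div(torch.div(timesteps, sd_data), 4.0))  # as in the new forward above
    fourier_repr = FourierEmbeddingSketch(256)(scaled)  # (2, 256)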
2 changes: 1 addition & 1 deletion src/eval.py
@@ -10,7 +10,7 @@
# ------------------------------------------------------------------------------------ #
# the setup_root above is equivalent to:
# - adding project root dir to PYTHONPATH
-# (so you don't need to force user to install project as a package)
+# (so you don'timesteps need to force user to install project as a package)
# (necessary before importing any local modules e.g. `from src import utils`)
# - setting up PROJECT_ROOT environment variable
# (which is used as a base for paths in "configs/paths/default.yaml")
54 changes: 31 additions & 23 deletions src/models/components/atom_attention.py
@@ -12,7 +12,7 @@
from src.models.components.primitives import AdaLN, Linear
from src.models.components.transition import ConditionedTransitionBlock
from src.utils.tensor_utils import partition_tensor
-from typing import Dict, Tuple
+from typing import Dict, Tuple, NamedTuple


def _split_heads(x, n_heads):
@@ -383,6 +383,14 @@ def aggregate_atom_to_token(
return token_representation


+class AtomAttentionEncoderOutput(NamedTuple):
+    """Structured output class for AtomAttentionEncoder."""
+    token_single: torch.Tensor  # (bs, n_tokens, c_token)
+    atom_single_skip_repr: torch.Tensor  # (bs, n_atoms, c_atom)
+    atom_single_skip_proj: torch.Tensor  # (bs, n_atoms, c_atom)
+    atom_pair_skip_repr: torch.Tensor  # (bs, n_atoms, n_atoms c_atompair)


class AtomAttentionEncoder(nn.Module):
"""AtomAttentionEncoder"""

@@ -492,9 +500,11 @@ def __init__(
def forward(
self,
features: Dict[str, torch.Tensor],
-pairformer_output: Dict[str, torch.tensor] = None,
+# pairformer_output: Dict[str, torch.tensor] = None,
+s_trunk: torch.Tensor = None, # (bs, n_tokens, c_token)
+z_trunk: torch.Tensor = None, # (bs, n_tokens, c_trunk_pair)
noisy_pos: torch.Tensor = None, # (bs, n_atoms, 3)
-) -> Dict[str, torch.Tensor]:
+) -> AtomAttentionEncoderOutput:
"""Forward pass for the AtomAttentionEncoder module.
Args:
features:
@@ -516,23 +526,21 @@
length 4.
"tok_idx":
[*, N_atoms] Token index for each atom in the flat atom representation.
-pairformer_output:
-Dictionary containing the output of the Pairformer model:
-"token_single":
-[*, N_tokens, c_token] single representation
-"token_pair":
-[*, N_tokens, N_tokens, c_pair] pair representation
+s_trunk:
+[*, N_tokens, c_token] single representation of the Pairformer trunk
+z_trunk:
+[*, N_tokens, N_tokens, c_pair] pair representation of the Pairformer trunk
noisy_pos:
[*, N_atoms, 3] Tensor containing the noisy positions. Defaults to None.
Returns:
-Dictionary containing the output features:
-"token_single":
+A named tuple containing the following fields:
+token_single:
[*, N_tokens, c_token] single representation
-"atom_single_skip_repr":
+atom_single_skip_repr:
[*, N_atoms, c_atom] atom single representation (denoted q_l in AF3 Supplement)
-"atom_single_skip_proj":
+atom_single_skip_proj:
[*, N_atoms, c_atom] atom single projection (denoted c_l in AF3 Supplement)
-"atom_pair_skip_repr":
+atom_pair_skip_repr:
[*, N_atoms, N_atoms, c_atompair] atom pair representation (denoted p_lm in AF3 Supplement)
"""
batch_size, n_atoms, _ = features['ref_pos'].size()
@@ -568,11 +576,11 @@ def forward(
# If provided, add trunk embeddings and noisy positions
if self.trunk_conditioning:
atom_single = atom_single + self.linear_trunk_single(
-self.trunk_single_layer_norm(gather_token_repr(pairformer_output['token_single'], features['tok_idx']))
+self.trunk_single_layer_norm(gather_token_repr(s_trunk, features['tok_idx']))
)
atom_pair = atom_pair + self.linear_trunk_pair(
self.trunk_pair_layer_norm(map_token_pairs_to_atom_pairs(
-pairformer_output['token_pair'], features['tok_idx'])
+z_trunk, features['tok_idx'])
)
)
# Add the noisy positions
@@ -592,13 +600,13 @@
token_repr = aggregate_atom_to_token(atom_representation=F.relu(self.linear_output(atom_single_conditioning)),
tok_idx=features['tok_idx'],
n_tokens=self.n_tokens)
-output_dict = {
-"token_single": token_repr,
-"atom_single_skip_repr": atom_single_conditioning,
-"atom_single_skip_proj": atom_single,
-"atom_pair_skip_repr": atom_pair,
-}
-return output_dict
+output = AtomAttentionEncoderOutput(
+token_single=token_repr,
+atom_single_skip_repr=atom_single_conditioning,
+atom_single_skip_proj=atom_single,
+atom_pair_skip_repr=atom_pair,
+)
+return output


class AtomAttentionDecoder(nn.Module):
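The main effect of this commit is visible above: AtomAttentionEncoder now returns a NamedTuple instead of a dict. Here is a minimal, self-contained sketch of what that means at call sites; the field names and shapes are copied from the class above, while the mirrored class and dummy tensors below are for illustration only and are not repository code:

    from typing import NamedTuple
    import torch

    class AtomAttentionEncoderOutput(NamedTuple):
        token_single: torch.Tensor             # (bs, n_tokens, c_token)
        atom_single_skip_repr: torch.Tensor    # (bs, n_atoms, c_atom)
        atom_single_skip_proj: torch.Tensor    # (bs, n_atoms, c_atom)
        atom_pair_skip_repr: torch.Tensor      # (bs, n_atoms, n_atoms, c_atompair)

    bs, n_tokens, n_atoms, c_token, c_atom, c_atompair = 1, 4, 16, 768, 128, 16
    out = AtomAttentionEncoderOutput(
        token_single=torch.zeros(bs, n_tokens, c_token),
        atom_single_skip_repr=torch.zeros(bs, n_atoms, c_atom),
        atom_single_skip_proj=torch.zeros(bs, n_atoms, c_atom),
        atom_pair_skip_repr=torch.zeros(bs, n_atoms, n_atoms, c_atompair),
    )
    token_repr = out.token_single     # attribute access replaces out["token_single"]
    q_skip, c_skip, p_skip = out[1:]  # NamedTuples still unpack like plain tuples

Compared with the previous dict return, the NamedTuple gives attribute access, tuple unpacking, and a fixed field order without changing the tensors that are stored.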
2 changes: 1 addition & 1 deletion src/models/components/triangular_attention.py
@@ -102,7 +102,7 @@ def forward(self,
chunk_size:
The number of sub-batches per chunk. If multiple batch
dimensions are specified, a "sub-batch" is defined as a single
-indexing of all batch dimensions simultaneously (s.t. the
+indexing of all batch dimensions simultaneously (s.timesteps. the
number of sub-batches is the product of the batch dimensions).
use_deepspeed_evo_attention:
whether to use DeepSpeed's EvoFormer attention
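A quick numeric illustration of the "sub-batch" counting described in the chunk_size docstring above; the shapes are arbitrary and the snippet is not the chunking utility itself:

    import math
    import torch

    x = torch.randn(2, 3, 16, 16, 8)         # two batch dims (2, 3) ahead of the (N, N, C) dims
    batch_dims = x.shape[:-3]
    num_sub_batches = math.prod(batch_dims)  # 6: the product of the batch dimensions
    chunk_size = 4
    num_chunks = -(-num_sub_batches // chunk_size)  # ceil(6 / 4) = 2 chunks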
103 changes: 99 additions & 4 deletions src/models/diffusion_module.py
@@ -9,11 +9,106 @@
- A two-level architecture, working first on atoms, then tokens, then atoms again.
"""
import torch
+from torch import nn
+from typing import Dict, Tuple
+from src.utils.geometry.vector import Vec3Array
+from src.diffusion.conditioning import DiffusionConditioning
+from src.diffusion.attention import DiffusionTransformer
+from src.models.components.atom_attention import AtomAttentionEncoder, AtomAttentionDecoder
+from src.models.components.primitives import Linear


-class DiffusionTransformer(torch.nn.Module):
-    pass
+class DiffusionModule(torch.nn.Module):
+    def __init__(
+            self,
+            c_atom: int = 128,
+            c_atompair=16,
+            c_token: int = 768,
+            c_tokenpair: int = 128,
+            n_tokens: int = 384,
+            atom_encoder_blocks: int = 3,
+            atom_encoder_heads: int = 16,
+            dropout: float = 0.0,
+            atom_attention_n_queries: int = 32,
+            atom_attention_n_keys: int = 128,
+            atom_decoder_blocks: int = 3,
+            atom_decoder_heads: int = 16,
+            token_transformer_blocks: int = 24,
+            token_transformer_heads: int = 16,
+    ):
+        super(DiffusionModule, self).__init__()
+        self.c_atom = c_atom
+        self.c_atompair = c_atompair
+        self.c_token = c_token
+        self.c_tokenpair = c_tokenpair
+        self.n_tokens = n_tokens
+        self.atom_encoder_blocks = atom_encoder_blocks
+        self.atom_encoder_heads = atom_encoder_heads
+        self.dropout = dropout
+        self.atom_attention_n_queries = atom_attention_n_queries
+        self.atom_attention_n_keys = atom_attention_n_keys
+        self.token_transformer_blocks = token_transformer_blocks
+        self.token_transformer_heads = token_transformer_heads

+        # Conditioning
+        self.diffusion_conditioning = DiffusionConditioning(c_token=c_token, c_pair=c_tokenpair)

-class DiffusionConditioning(torch.nn.Module):
-    pass
+        # Sequence-local atom attention and aggregation to coarse-grained tokens
+        self.atom_attention_encoder = AtomAttentionEncoder(
+            n_tokens=n_tokens,
+            c_token=c_token,
+            c_atom=c_atom,
+            c_atompair=c_atompair,
+            c_trunk_pair=c_tokenpair,
+            num_blocks=atom_decoder_blocks,
+            num_heads=atom_encoder_heads,
+            dropout=dropout,
+            n_queries=atom_attention_n_queries,
+            n_keys=atom_attention_n_keys,
+            trunk_conditioning=True
+        )

+        # Full self-attention on token level
+        self.linear_token_residual = Linear(c_token, c_token, bias=False, init='final')
+        self.token_residual_layer_norm = nn.LayerNorm(c_token)
+        self.diffusion_transformer = DiffusionTransformer(
+            c_token=c_token,
+            c_pair=c_tokenpair,
+            num_blocks=token_transformer_blocks,
+            num_heads=token_transformer_heads,
+            dropout=dropout,
+        )

+        # Broadcast token activations to atoms and run sequence-local atom attention
+        self.atom_attention_decoder = AtomAttentionDecoder(
+            c_token=c_token,
+            c_atom=c_atom,
+            c_atompair=c_atompair,
+            num_blocks=atom_decoder_blocks,
+            num_heads=atom_decoder_heads,
+            dropout=dropout,
+            n_queries=atom_attention_n_queries,
+            n_keys=atom_attention_n_keys,
+        )

+    def forward(
+            self,
+            noisy_atoms: Vec3Array,  # (bs, n_atoms)
+            timesteps: torch.Tensor,  # (bs, 1)
+            features: Dict[str, torch.Tensor],  # input feature dict
+            s_inputs: torch.Tensor,  # (bs, n_tokens, c_token)
+            s_trunk: torch.Tensor,  # (bs, n_tokens, c_token)
+            z_trunk: torch.Tensor,  # (bs, n_tokens, n_tokens, c_pair)
+            sd_data: float = 16.0
+    ) -> Vec3Array:
+        """Diffusion module that denoises atomic coordinates based on conditioning"""
+        # Conditioning
+        token_repr, pair_repr = self.diffusion_conditioning(timesteps, features, s_inputs, s_trunk, z_trunk)

+        # Scale positions to dimensionless vectors with approximately unit variance
+        scale_factor = torch.reciprocal((torch.sqrt(torch.add(timesteps ** 2, sd_data ** 2))))
+        r_noisy = noisy_atoms / scale_factor

+        # Sequence local atom attention and aggregation to coarse-grained tokens
+        atom_encoder_out = self.atom_attention_encoder(features=features, )
+        pass
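The DiffusionModule forward method above ends after the encoder call at this commit. As a standalone numeric illustration of the coordinate-scaling step described by its comment (positions divided by sqrt(timesteps^2 + sd_data^2) to reach roughly unit variance); the shapes and noise level are toy values, not repository code:

    import torch

    timesteps = torch.full((2, 1), 80.0)        # (bs, 1) noise levels
    sd_data = 16.0
    noisy_atoms = torch.randn(2, 10, 3) * 80.0  # (bs, n_atoms, 3) toy coordinates at noise level 80

    scale = torch.sqrt(timesteps ** 2 + sd_data ** 2).unsqueeze(-1)  # (bs, 1, 1)
    r_noisy = noisy_atoms / scale
    print(float(r_noisy.std()))  # roughly 1.0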
2 changes: 1 addition & 1 deletion src/models/kestrel_module.py
@@ -123,7 +123,7 @@ def forward(
def on_train_start(self) -> None:
"""Lightning hook that is called when training begins."""
# by default lightning executes validation step sanity checks before training starts,
-# so it's worth to make sure validation metrics don't store results from these checks
+# so it's worth to make sure validation metrics don'timesteps store results from these checks
self.val_loss.reset()

def model_step(
2 changes: 1 addition & 1 deletion src/train.py
@@ -12,7 +12,7 @@
# ------------------------------------------------------------------------------------ #
# the setup_root above is equivalent to:
# - adding project root dir to PYTHONPATH
-# (so you don't need to force user to install project as a package)
+# (so you don'timesteps need to force user to install project as a package)
# (necessary before importing any local modules e.g. `from src import utils`)
# - setting up PROJECT_ROOT environment variable
# (which is used as a base for paths in "configs/paths/default.yaml")