From 9ad5da4b206d767c41cca4f3dd4d125ffbc6eed4 Mon Sep 17 00:00:00 2001
From: ardagoreci <62720042+ardagoreci@users.noreply.github.com>
Date: Thu, 30 May 2024 19:06:49 +0100
Subject: [PATCH] Structured AtomAttentionEncoder outputs

---
 .gitignore | 2 +-
 README.md | 2 +-
 configs/callbacks/early_stopping.yaml | 2 +-
 configs/debug/default.yaml | 6 +-
 configs/train.yaml | 2 +-
 scripts/utils.py | 2 +-
 src/common/residue_constants.py | 8 +-
 src/data/components/protein_dataset.py | 2 +-
 src/diffusion/augmentation.py | 1 -
 src/diffusion/conditioning.py | 10 +-
 src/eval.py | 2 +-
 src/models/components/atom_attention.py | 54 +++++----
 src/models/components/triangular_attention.py | 2 +-
 src/models/diffusion_module.py | 103 +++++++++++++++++-
 src/models/kestrel_module.py | 2 +-
 src/train.py | 2 +-
 src/utils/chunk_utils.py | 4 +-
 src/utils/losses.py | 2 +-
 src/utils/rigid_utils.py | 4 +-
 src/utils/tensor_utils.py | 2 +-
 src/utils/utils.py | 2 +-
 tests/test_atom_attention.py | 26 +++--
 tests/test_conditioning.py | 2 +-
 23 files changed, 174 insertions(+), 70 deletions(-)

diff --git a/.gitignore b/.gitignore
index 04a0648..ca7b649 100644
--- a/.gitignore
+++ b/.gitignore
@@ -87,7 +87,7 @@ ipython_config.py
 # pipenv
 # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
 # However, in case of collaboration, if having platform-specific dependencies or dependencies
-# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
 # install all needed dependencies.
 #Pipfile.lock
diff --git a/README.md b/README.md
index ec1a60c..da3cbd7 100644
--- a/README.md
+++ b/README.md
@@ -573,7 +573,7 @@ defaults:
   - hparams_search: null

   # optional local config for machine/user specific settings
-  # it's optional since it doesn't need to exist and is excluded from version control
+  # it's optional since it doesn't need to exist and is excluded from version control
   - optional local: default.yaml

   # debugging config (enable through command line, e.g. `python train.py debug=default)
diff --git a/configs/callbacks/early_stopping.yaml b/configs/callbacks/early_stopping.yaml
index c826c8d..befb45e 100644
--- a/configs/callbacks/early_stopping.yaml
+++ b/configs/callbacks/early_stopping.yaml
@@ -12,4 +12,4 @@ early_stopping:
   stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold
   divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold
   check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch
-  # log_rank_zero_only: False # this keyword argument isn't available in stable version
+  # log_rank_zero_only: False # this keyword argument isn't available in stable version
diff --git a/configs/debug/default.yaml b/configs/debug/default.yaml
index 1886902..559779e 100644
--- a/configs/debug/default.yaml
+++ b/configs/debug/default.yaml
@@ -26,10 +26,10 @@ hydra:

 trainer:
   max_epochs: 1
-  accelerator: cpu # debuggers don't like gpus
-  devices: 1 # debuggers don't like multiprocessing
+  accelerator: cpu # debuggers don't like gpus
+  devices: 1 # debuggers don't like multiprocessing
   detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor

 data:
-  num_workers: 0 # debuggers don't like multiprocessing
+  num_workers: 0 # debuggers don't like multiprocessing
   pin_memory: False # disable gpu memory pin
diff --git a/configs/train.yaml b/configs/train.yaml
index 7096a37..cbcf0b8 100644
--- a/configs/train.yaml
+++ b/configs/train.yaml
@@ -22,7 +22,7 @@ defaults:
   - hparams_search: null

   # optional local config for machine/user specific settings
-  # it's optional since it doesn't need to exist and is excluded from version control
+  # it's optional since it doesn't need to exist and is excluded from version control
   - optional local: default

   # debugging config (enable through command line, e.g. `python train.py debug=default)
diff --git a/scripts/utils.py b/scripts/utils.py
index 24af53c..ea89f07 100644
--- a/scripts/utils.py
+++ b/scripts/utils.py
@@ -10,7 +10,7 @@ def get_nvidia_cc():
     installed in the system (formatted as a tuple of strings) and an error message.
     When the former is provided, the latter is None, and vice versa.

-    Adapted from script by Jan Schlüte t
+    Adapted from script by Jan Schlüter at
     https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
     """
     CUDA_SUCCESS = 0
diff --git a/src/common/residue_constants.py b/src/common/residue_constants.py
index f80fff3..b1755b3 100644
--- a/src/common/residue_constants.py
+++ b/src/common/residue_constants.py
@@ -29,7 +29,7 @@ ca_ca = 3.80209737096

 # Format: The list for each AA type contains chi1, chi2, chi3, chi4 in
-# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have
+# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have
 # chi angles so their chi angle lists are empty.
 chi_angles_atoms = {
     'ALA': [],
@@ -553,7 +553,7 @@ def sequence_to_onehot(
       sequence: An amino acid sequence.
       mapping: A dictionary mapping amino acids to integers.
       map_unknown_to_x: If True, any amino acid that is not in the mapping will be
-        mapped to the unknown amino acid 'X'. If the mapping doesn't contain
+        mapped to the unknown amino acid 'X'. If the mapping doesn't contain
         amino acid 'X', an error will be thrown. If False, any amino acid not in the
         mapping will throw an error.

     Returns:
       A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of
       the sequence.
     Raises:
-      ValueError: If the mapping doesn't contain values from 0 to
+      ValueError: If the mapping doesn't contain values from 0 to
         num_unique_aas - 1 without any gaps.
     """
     num_entries = max(mapping.values()) + 1
@@ -642,7 +642,7 @@ def atom_id_to_type(atom_id: str) -> str:
 # NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple
 # 1-to-1 mapping of 3 letter names to one letter names. The latter contains
 # many more, and less common, three letter names as keys and maps many of these
-# to the same one letter name (including 'X' and 'U' which we don't use here).
+# to the same one letter name (including 'X' and 'U' which we don't use here).
 restype_3to1 = {v: k for k, v in restype_1to3.items()}

 # Define a restype name for all unknown residues.
diff --git a/src/data/components/protein_dataset.py b/src/data/components/protein_dataset.py
index 68f9df8..b0d6d25 100644
--- a/src/data/components/protein_dataset.py
+++ b/src/data/components/protein_dataset.py
@@ -127,7 +127,7 @@ def __init__(
         use_fraction : float, default 1
             the fraction of the clusters to use (first N in alphabetic order)
         load_to_ram : bool, default False
-            if `True`, the data will be stored in RAM (use with caution! if RAM isn't big enough the machine might crash)
+            if `True`, the data will be stored in RAM (use with caution! if RAM isn't big enough the machine might crash)
         debug : bool, default False
             only process 1000 files
         interpolate : {"none", "only_middle", "all"}
diff --git a/src/diffusion/augmentation.py b/src/diffusion/augmentation.py
index 1429fc9..9c9393b 100644
--- a/src/diffusion/augmentation.py
+++ b/src/diffusion/augmentation.py
@@ -1,5 +1,4 @@
 """Data augmentations applied prior to sampling from the diffusion trajectory."""
-import torch

 from src.utils.geometry.vector import Vec3Array
 from src.utils.geometry.rotation_matrix import Rot3Array
diff --git a/src/diffusion/conditioning.py b/src/diffusion/conditioning.py
index 0863ad1..1a91b54 100644
--- a/src/diffusion/conditioning.py
+++ b/src/diffusion/conditioning.py
@@ -139,16 +139,16 @@ def __init__(

     def forward(
             self,
-            t: torch.Tensor,  # timestep (bs, 1)
+            timesteps: torch.Tensor,  # timestep (bs, 1)
             features: Dict[str, torch.Tensor],  # input feature dict
             s_inputs: torch.Tensor,  # (bs, n_tokens, c_token)
             s_trunk: torch.Tensor,  # (bs, n_tokens, c_token)
             z_trunk: torch.Tensor,  # (bs, n_tokens, n_tokens, c_pair)
-            sd_data: torch.Tensor  # standard dev of data (bs, 1)
+            sd_data: float = 16.0,  # standard deviation of the data (sigma_data)
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Diffusion conditioning.
         Args:
-            t:
+            timesteps:
                 [*, 1] timestep tensor
             features:
                 input feature dictionary for the RelativePositionEncoding containing:
@@ -170,7 +170,7 @@ def forward(
             z_trunk:
                 [*, n_tokens, n_tokens, c_pair] Pair conditioning from Pairformer trunk
             sd_data:
-                [*, 1] Standard deviation of the data
+                Standard deviation of the data distribution (sigma_data), used to scale the timesteps before Fourier embedding
         """
         # Pair conditioning
         pair_repr = torch.cat([z_trunk, self.relative_position_encoding(features)], dim=-1)
@@ -181,7 +181,7 @@ def forward(
         # Single conditioning
         token_repr = torch.cat([s_trunk, s_inputs], dim=-1)
         token_repr = self.linear_single(self.single_layer_norm(token_repr))
-        fourier_repr = self.fourier_embedding(torch.log(t / sd_data) / 4.0)
+        fourier_repr = self.fourier_embedding(torch.log(timesteps / sd_data) / 4.0)
         fourier_repr = self.linear_fourier(self.fourier_layer_norm(fourier_repr))
         token_repr = token_repr + fourier_repr.unsqueeze(1)
         for transition in self.single_transitions:
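Note on the signature change above: `DiffusionConditioning.forward` now takes a `(bs, 1)` tensor of `timesteps` and a plain float `sd_data` (sigma_data, defaulting to 16.0), and the quantity passed to the Fourier embedding is `log(timesteps / sd_data) / 4`, exactly as in the line being edited. A minimal standalone sketch of just that scaling (the shapes and values below are illustrative, not taken from the repository):

```python
import torch

sd_data = 16.0                        # sigma_data, now a plain float
timesteps = torch.full((2, 1), 80.0)  # (bs, 1) noise levels

# Quarter-log scaling: take log(t / sigma_data) first, then divide by 4.
fourier_input = torch.log(timesteps / sd_data) / 4.0
print(fourier_input.shape)            # torch.Size([2, 1])
```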
diff --git a/src/eval.py b/src/eval.py
index b70faae..0144b45 100644
--- a/src/eval.py
+++ b/src/eval.py
@@ -10,7 +10,7 @@
 # ------------------------------------------------------------------------------------ #
 # the setup_root above is equivalent to:
 # - adding project root dir to PYTHONPATH
-# (so you don't need to force user to install project as a package)
+# (so you don't need to force user to install project as a package)
 # (necessary before importing any local modules e.g. `from src import utils`)
 # - setting up PROJECT_ROOT environment variable
 # (which is used as a base for paths in "configs/paths/default.yaml")
diff --git a/src/models/components/atom_attention.py b/src/models/components/atom_attention.py
index 9c2b6c0..2555130 100644
--- a/src/models/components/atom_attention.py
+++ b/src/models/components/atom_attention.py
@@ -12,7 +12,7 @@
 from src.models.components.primitives import AdaLN, Linear
 from src.models.components.transition import ConditionedTransitionBlock
 from src.utils.tensor_utils import partition_tensor
-from typing import Dict, Tuple
+from typing import Dict, Tuple, NamedTuple


 def _split_heads(x, n_heads):
@@ -383,6 +383,14 @@ def aggregate_atom_to_token(
     return token_representation


+class AtomAttentionEncoderOutput(NamedTuple):
+    """Structured output class for AtomAttentionEncoder."""
+    token_single: torch.Tensor  # (bs, n_tokens, c_token)
+    atom_single_skip_repr: torch.Tensor  # (bs, n_atoms, c_atom)
+    atom_single_skip_proj: torch.Tensor  # (bs, n_atoms, c_atom)
+    atom_pair_skip_repr: torch.Tensor  # (bs, n_atoms, n_atoms, c_atompair)
+
+
 class AtomAttentionEncoder(nn.Module):
     """AtomAttentionEncoder"""

@@ -492,9 +500,11 @@ def __init__(
     def forward(
             self,
             features: Dict[str, torch.Tensor],
-            pairformer_output: Dict[str, torch.tensor] = None,
+            # pairformer_output: Dict[str, torch.tensor] = None,
+            s_trunk: torch.Tensor = None,  # (bs, n_tokens, c_token)
+            z_trunk: torch.Tensor = None,  # (bs, n_tokens, c_trunk_pair)
             noisy_pos: torch.Tensor = None,  # (bs, n_atoms, 3)
-    ) -> Dict[str, torch.Tensor]:
+    ) -> AtomAttentionEncoderOutput:
         """Forward pass for the AtomAttentionEncoder module.
         Args:
             features:
                 "ref_pos":
                     [*, N_atoms, 3] atom positions in the reference conformers, with a random
                     rotation and translation applied. Atom positions in Angstroms.
                 "ref_charge":
                     [*, N_atoms] Charge for each atom in the reference conformer.
                 "ref_mask":
                     [*, N_atoms] Mask indicating which atom slots are used in the reference conformer.
                 "ref_element":
                     [*, N_atoms, 4] One-hot encoding of the element type for each atom in the reference conformer.
                 "ref_atom_name_chars":
                     [*, N_atom, 4, 64] One-hot encoding of the unique atom names in the reference conformer.
                     Each character is encoded as ord(c) - 32, and names are padded to length 4.
                 "tok_idx":
                     [*, N_atoms] Token index for each atom in the flat atom representation.
-            pairformer_output:
-                Dictionary containing the output of the Pairformer model:
-                    "token_single":
-                        [*, N_tokens, c_token] single representation
-                    "token_pair":
-                        [*, N_tokens, N_tokens, c_pair] pair representation
+            s_trunk:
+                [*, N_tokens, c_token] single representation of the Pairformer trunk
+            z_trunk:
+                [*, N_tokens, N_tokens, c_pair] pair representation of the Pairformer trunk
             noisy_pos:
                 [*, N_atoms, 3] Tensor containing the noisy positions. Defaults to None.

         Returns:
-            Dictionary containing the output features:
-                "token_single":
+            A named tuple containing the following fields:
+                token_single:
                     [*, N_tokens, c_token] single representation
-                "atom_single_skip_repr":
+                atom_single_skip_repr:
                     [*, N_atoms, c_atom] atom single representation (denoted q_l in AF3 Supplement)
-                "atom_single_skip_proj":
+                atom_single_skip_proj:
                     [*, N_atoms, c_atom] atom single projection (denoted c_l in AF3 Supplement)
-                "atom_pair_skip_repr":
+                atom_pair_skip_repr:
                     [*, N_atoms, N_atoms, c_atompair] atom pair representation (denoted p_lm in AF3 Supplement)
         """
         batch_size, n_atoms, _ = features['ref_pos'].size()
@@ -568,11 +576,11 @@ def forward(
         # If provided, add trunk embeddings and noisy positions
         if self.trunk_conditioning:
             atom_single = atom_single + self.linear_trunk_single(
-                self.trunk_single_layer_norm(gather_token_repr(pairformer_output['token_single'], features['tok_idx']))
+                self.trunk_single_layer_norm(gather_token_repr(s_trunk, features['tok_idx']))
             )
             atom_pair = atom_pair + self.linear_trunk_pair(
                 self.trunk_pair_layer_norm(map_token_pairs_to_atom_pairs(
-                    pairformer_output['token_pair'], features['tok_idx'])
+                    z_trunk, features['tok_idx'])
                 )
             )
             # Add the noisy positions
@@ -592,13 +600,13 @@ def forward(
         token_repr = aggregate_atom_to_token(atom_representation=F.relu(self.linear_output(atom_single_conditioning)),
                                              tok_idx=features['tok_idx'],
                                              n_tokens=self.n_tokens)
-        output_dict = {
-            "token_single": token_repr,
-            "atom_single_skip_repr": atom_single_conditioning,
-            "atom_single_skip_proj": atom_single,
-            "atom_pair_skip_repr": atom_pair,
-        }
-        return output_dict
+        output = AtomAttentionEncoderOutput(
+            token_single=token_repr,
+            atom_single_skip_repr=atom_single_conditioning,
+            atom_single_skip_proj=atom_single,
+            atom_pair_skip_repr=atom_pair,
+        )
+        return output


 class AtomAttentionDecoder(nn.Module):
diff --git a/src/models/components/triangular_attention.py b/src/models/components/triangular_attention.py
index 71f2b4e..fc9fa91 100644
--- a/src/models/components/triangular_attention.py
+++ b/src/models/components/triangular_attention.py
@@ -102,7 +102,7 @@ def forward(self,
             chunk_size:
                 The number of sub-batches per chunk. If multiple batch
                 dimensions are specified, a "sub-batch" is defined as a single
-                indexing of all batch dimensions simultaneously (s.t. the
+                indexing of all batch dimensions simultaneously (s.t. the
                 number of sub-batches is the product of the batch dimensions).
             use_deepspeed_evo_attention:
                 whether to use DeepSpeed's EvoFormer attention
diff --git a/src/models/diffusion_module.py b/src/models/diffusion_module.py
index e803643..f4cda32 100644
--- a/src/models/diffusion_module.py
+++ b/src/models/diffusion_module.py
@@ -9,11 +9,106 @@
 - A two-level architecture, working first on atoms, then tokens, then atoms again.
 """
 import torch
+from torch import nn
+from typing import Dict, Tuple
+from src.utils.geometry.vector import Vec3Array
+from src.diffusion.conditioning import DiffusionConditioning
+from src.diffusion.attention import DiffusionTransformer
+from src.models.components.atom_attention import AtomAttentionEncoder, AtomAttentionDecoder
+from src.models.components.primitives import Linear


-class DiffusionTransformer(torch.nn.Module):
-    pass
+class DiffusionModule(torch.nn.Module):
+    def __init__(
+            self,
+            c_atom: int = 128,
+            c_atompair: int = 16,
+            c_token: int = 768,
+            c_tokenpair: int = 128,
+            n_tokens: int = 384,
+            atom_encoder_blocks: int = 3,
+            atom_encoder_heads: int = 16,
+            dropout: float = 0.0,
+            atom_attention_n_queries: int = 32,
+            atom_attention_n_keys: int = 128,
+            atom_decoder_blocks: int = 3,
+            atom_decoder_heads: int = 16,
+            token_transformer_blocks: int = 24,
+            token_transformer_heads: int = 16,
+    ):
+        super(DiffusionModule, self).__init__()
+        self.c_atom = c_atom
+        self.c_atompair = c_atompair
+        self.c_token = c_token
+        self.c_tokenpair = c_tokenpair
+        self.n_tokens = n_tokens
+        self.atom_encoder_blocks = atom_encoder_blocks
+        self.atom_encoder_heads = atom_encoder_heads
+        self.dropout = dropout
+        self.atom_attention_n_queries = atom_attention_n_queries
+        self.atom_attention_n_keys = atom_attention_n_keys
+        self.token_transformer_blocks = token_transformer_blocks
+        self.token_transformer_heads = token_transformer_heads
+
+        # Conditioning
+        self.diffusion_conditioning = DiffusionConditioning(c_token=c_token, c_pair=c_tokenpair)

-class DiffusionConditioning(torch.nn.Module):
-    pass
+        # Sequence-local atom attention and aggregation to coarse-grained tokens
+        self.atom_attention_encoder = AtomAttentionEncoder(
+            n_tokens=n_tokens,
+            c_token=c_token,
+            c_atom=c_atom,
+            c_atompair=c_atompair,
+            c_trunk_pair=c_tokenpair,
+            num_blocks=atom_encoder_blocks,
+            num_heads=atom_encoder_heads,
+            dropout=dropout,
+            n_queries=atom_attention_n_queries,
+            n_keys=atom_attention_n_keys,
+            trunk_conditioning=True
+        )
+
+        # Full self-attention on token level
+        self.linear_token_residual = Linear(c_token, c_token, bias=False, init='final')
+        self.token_residual_layer_norm = nn.LayerNorm(c_token)
+        self.diffusion_transformer = DiffusionTransformer(
+            c_token=c_token,
+            c_pair=c_tokenpair,
+            num_blocks=token_transformer_blocks,
+            num_heads=token_transformer_heads,
+            dropout=dropout,
+        )
+
+        # Broadcast token activations to atoms and run sequence-local atom attention
+        self.atom_attention_decoder = AtomAttentionDecoder(
+            c_token=c_token,
+            c_atom=c_atom,
+            c_atompair=c_atompair,
+            num_blocks=atom_decoder_blocks,
+            num_heads=atom_decoder_heads,
+            dropout=dropout,
+            n_queries=atom_attention_n_queries,
+            n_keys=atom_attention_n_keys,
+        )
+
+    def forward(
+            self,
+            noisy_atoms: Vec3Array,  # (bs, n_atoms)
+            timesteps: torch.Tensor,  # (bs, 1)
+            features: Dict[str, torch.Tensor],  # input feature dict
+            s_inputs: torch.Tensor,  # (bs, n_tokens, c_token)
+            s_trunk: torch.Tensor,  # (bs, n_tokens, c_token)
+            z_trunk: torch.Tensor,  # (bs, n_tokens, n_tokens, c_pair)
+            sd_data: float = 16.0
+    ) -> Vec3Array:
+        """Diffusion module that denoises atomic coordinates based on conditioning."""
+        # Conditioning
+        token_repr, pair_repr = self.diffusion_conditioning(timesteps, features, s_inputs, s_trunk, z_trunk)
+
+        # Scale positions to dimensionless vectors with approximately unit variance
+        scale_factor = torch.reciprocal(torch.sqrt(timesteps ** 2 + sd_data ** 2))
+        r_noisy = noisy_atoms * scale_factor
+
+        # Sequence-local atom attention and aggregation to coarse-grained tokens
+        atom_encoder_out = self.atom_attention_encoder(features=features)
+        pass
diff --git a/src/models/kestrel_module.py b/src/models/kestrel_module.py
index 94032f5..84b73af 100644
--- a/src/models/kestrel_module.py
+++ b/src/models/kestrel_module.py
@@ -123,7 +123,7 @@ def forward(
     def on_train_start(self) -> None:
         """Lightning hook that is called when training begins."""
         # by default lightning executes validation step sanity checks before training starts,
-        # so it's worth to make sure validation metrics don't store results from these checks
+        # so it's worth making sure validation metrics don't store results from these checks
         self.val_loss.reset()

     def model_step(
diff --git a/src/train.py b/src/train.py
index 4adbcf4..e7fb7d6 100644
--- a/src/train.py
+++ b/src/train.py
@@ -12,7 +12,7 @@
 # ------------------------------------------------------------------------------------ #
 # the setup_root above is equivalent to:
 # - adding project root dir to PYTHONPATH
-# (so you don't need to force user to install project as a package)
+# (so you don't need to force user to install project as a package)
 # (necessary before importing any local modules e.g. `from src import utils`)
 # - setting up PROJECT_ROOT environment variable
 # (which is used as a base for paths in "configs/paths/default.yaml")
diff --git a/src/utils/chunk_utils.py b/src/utils/chunk_utils.py
index 6f9fec4..45db895 100644
--- a/src/utils/chunk_utils.py
+++ b/src/utils/chunk_utils.py
@@ -183,7 +183,7 @@ def _chunk_slice(
     """
     Equivalent to

-        t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]
+        t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]

     but without the need for the initial reshape call, which can be
     memory-intensive in certain situations. The only reshape operations
@@ -235,7 +235,7 @@ def chunk_layer(
         chunk_size:
             The number of sub-batches per chunk. If multiple batch
             dimensions are specified, a "sub-batch" is defined as a single
-            indexing of all batch dimensions simultaneously (s.t. the
+            indexing of all batch dimensions simultaneously (s.t. the
             number of sub-batches is the product of the batch dimensions).
         no_batch_dims:
             How many of the initial dimensions of each input tensor can
diff --git a/src/utils/losses.py b/src/utils/losses.py
index aa032d3..3848d8e 100644
--- a/src/utils/losses.py
+++ b/src/utils/losses.py
@@ -139,7 +139,7 @@ def lddt(
 ) -> torch.Tensor:
     """Compute the local distance difference test (LDDT) score.
     TODO: there is something off with per_residue=True, it seems to be the other way around
-    TODO: the lddt values seem far too high as I noise the protein, it doesn't go below 0.8.
+    TODO: the lddt values seem far too high as I noise the protein, it doesn't go below 0.8.
     """
     n = all_atom_mask.shape[-2]
     dmat_true = torch.sqrt(
diff --git a/src/utils/rigid_utils.py b/src/utils/rigid_utils.py
index 22baaeb..46ba2d8 100644
--- a/src/utils/rigid_utils.py
+++ b/src/utils/rigid_utils.py
@@ -907,8 +907,8 @@ def __getitem__(self,
             E.g.::

                 r = Rotations(rot_mats=torch.rand(10, 10, 3, 3), quats=None)
-                t = Rigids(r, torch.rand(10, 10, 3))
-                indexed = t[3, 4:6]
+                t = Rigids(r, torch.rand(10, 10, 3))
+                indexed = t[3, 4:6]
                 assert(indexed.shape == (2,))
                 assert(indexed.get_rots().shape == (2,))
                 assert(indexed.get_trans().shape == (2, 3))
diff --git a/src/utils/tensor_utils.py b/src/utils/tensor_utils.py
index fb93049..b448107 100644
--- a/src/utils/tensor_utils.py
+++ b/src/utils/tensor_utils.py
@@ -22,7 +22,7 @@

 def add(m1, m2, inplace):
-    # The first operation in a checkpoint can't be in-place, but it's
+    # The first operation in a checkpoint can't be in-place, but it's
     # nice to have in-place addition during inference. Thus...
     if not inplace:
         m1 = m1 + m2
diff --git a/src/utils/utils.py b/src/utils/utils.py
index 02b5576..2b5f75c 100644
--- a/src/utils/utils.py
+++ b/src/utils/utils.py
@@ -82,7 +82,7 @@ def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
             # display output dir path in terminal
             log.info(f"Output dir: {cfg.paths.output_dir}")

-            # always close wandb run (even if exception occurs so multirun won't fail)
+            # always close wandb run (even if exception occurs so multirun won't fail)
             if find_spec("wandb"):  # check if wandb is installed
                 import wandb
diff --git a/tests/test_atom_attention.py b/tests/test_atom_attention.py
index 685665a..2794bed 100644
--- a/tests/test_atom_attention.py
+++ b/tests/test_atom_attention.py
@@ -83,18 +83,20 @@ def test_forward_dimensions(self):
         }
         noisy_pos = torch.rand(self.batch_size, self.n_atoms, 3)

-        # Pairformer output mock (adjust as per actual module expectations)
-        pairformer_output = {
-            'token_single': torch.rand(self.batch_size, self.n_tokens, self.c_token),
-            'token_pair': torch.rand(self.batch_size, self.n_tokens, self.n_tokens, self.c_trunk_pair)
-        }
-
-        output = self.encoder(features, pairformer_output, noisy_pos)
-        self.assertEqual(output['token_single'].shape, torch.Size([self.batch_size, self.n_tokens, self.c_token]))
-        self.assertEqual(output['atom_single_skip_repr'].shape, torch.Size([self.batch_size, self.n_atoms, self.c_atom]))
-        self.assertEqual(output['atom_single_skip_proj'].shape, torch.Size([self.batch_size, self.n_atoms, self.c_atom]))
-        self.assertEqual(output['atom_pair_skip_repr'].shape, torch.Size([self.batch_size, self.n_atoms,
-                                                                          self.n_atoms, self.c_atompair]))
+        # Pairformer outputs (adjust as per actual module expectations)
+        s_trunk = torch.rand(self.batch_size, self.n_tokens, self.c_token)
+        z_trunk = torch.rand(self.batch_size, self.n_tokens, self.n_tokens, self.c_trunk_pair)
+        # pairformer_output = {
+        #     'token_single': torch.rand(self.batch_size, self.n_tokens, self.c_token),
+        #     'token_pair': torch.rand(self.batch_size, self.n_tokens, self.n_tokens, self.c_trunk_pair)
+        # }
+
+        output = self.encoder(features, s_trunk, z_trunk, noisy_pos)
+        self.assertEqual(output.token_single.shape, torch.Size([self.batch_size, self.n_tokens, self.c_token]))
+        self.assertEqual(output.atom_single_skip_repr.shape, torch.Size([self.batch_size, self.n_atoms, self.c_atom]))
+        self.assertEqual(output.atom_single_skip_proj.shape, torch.Size([self.batch_size, self.n_atoms, self.c_atom]))
+        self.assertEqual(output.atom_pair_skip_repr.shape, torch.Size([self.batch_size, self.n_atoms,
+                                                                       self.n_atoms, self.c_atompair]))


 class TestAtomAttentionDecoder(unittest.TestCase):
diff --git a/tests/test_conditioning.py b/tests/test_conditioning.py
index 1989ce0..96c0c57 100644
--- a/tests/test_conditioning.py
+++ b/tests/test_conditioning.py
@@ -86,7 +86,7 @@ def test_forward(self):
         s_inputs = torch.randn(self.batch_size, self.n_tokens, self.c_token)
         s_trunk = torch.randn(self.batch_size, self.n_tokens, self.c_token)
         z_trunk = torch.randn(self.batch_size, self.n_tokens, self.n_tokens, self.c_pair)
-        sd_data = torch.randn(self.batch_size, 1)  # standard dev of data (bs, 1)
+        sd_data = 16.0  # standard deviation of the data (sigma_data), now a plain float

         output = self.module(t, features, s_inputs, s_trunk, z_trunk, sd_data)
         self.assertEqual(output[0].shape, (self.batch_size, self.n_tokens, self.c_token))
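Returning to the scaling in `DiffusionModule.forward`: dividing the noisy positions by sqrt(timesteps**2 + sd_data**2) yields roughly unit-variance inputs when the clean coordinates have standard deviation sigma_data and the added noise has standard deviation t. A quick standalone check of that factor (the values are illustrative):

```python
import torch

sd_data, t, n = 16.0, 80.0, 100_000
x = torch.randn(n, 3) * sd_data + torch.randn(n, 3) * t  # clean signal + noise at level t
scale = 1.0 / (t ** 2 + sd_data ** 2) ** 0.5             # same factor as scale_factor in forward()
print(round(x.std().item(), 1), round((x * scale).std().item(), 2))  # about 81.6 -> about 1.0
```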