From e73447f44ef9c94904dc90058b0e459cd323153b Mon Sep 17 00:00:00 2001 From: ardagoreci <62720042+ardagoreci@users.noreply.github.com> Date: Thu, 23 Nov 2023 13:48:21 +0000 Subject: [PATCH] Code improvements, implemented Structure net --- configs/logger/wandb.yaml | 1 + src/data/protein_datamodule.py | 7 +- src/models/components/backbone_update.py | 6 +- .../components/invariant_point_attention.py | 4 +- src/models/components/structure_transition.py | 6 +- src/models/components/triangular_attention.py | 7 +- .../{components => }/evoformer_pair_stack.py | 9 - src/models/structure_net.py | 173 +++++++++++++ src/utils/kernel/attention_core.py | 103 -------- src/utils/kernel/csrc/compat.h | 11 - src/utils/kernel/csrc/softmax_cuda.cpp | 44 ---- src/utils/kernel/csrc/softmax_cuda_kernel.cu | 241 ------------------ src/utils/kernel/csrc/softmax_cuda_stub.cpp | 36 --- src/utils/rigid_utils.py | 228 ++++++++--------- tests/config.py | 34 +++ tests/test_evoformer_pair_stack.py | 3 +- tests/test_invariant_point_attention.py | 6 +- tests/test_structure_net.py | 52 ++++ 18 files changed, 395 insertions(+), 576 deletions(-) rename src/models/{components => }/evoformer_pair_stack.py (89%) create mode 100644 src/models/structure_net.py delete mode 100644 src/utils/kernel/attention_core.py delete mode 100644 src/utils/kernel/csrc/compat.h delete mode 100644 src/utils/kernel/csrc/softmax_cuda.cpp delete mode 100644 src/utils/kernel/csrc/softmax_cuda_kernel.cu delete mode 100644 src/utils/kernel/csrc/softmax_cuda_stub.cpp create mode 100644 tests/config.py create mode 100644 tests/test_structure_net.py diff --git a/configs/logger/wandb.yaml b/configs/logger/wandb.yaml index ece1658..ebcdff9 100644 --- a/configs/logger/wandb.yaml +++ b/configs/logger/wandb.yaml @@ -8,6 +8,7 @@ wandb: id: null # pass correct id to resume experiment! anonymous: null # enable anonymous logging project: "lightning-hydra-template" + entity: "ligo-technologies" log_model: False # upload lightning ckpts prefix: "" # a string to put at the beginning of metric keys # entity: "" # set to name of your wandb team diff --git a/src/data/protein_datamodule.py b/src/data/protein_datamodule.py index 5e26343..451a2cf 100644 --- a/src/data/protein_datamodule.py +++ b/src/data/protein_datamodule.py @@ -31,10 +31,9 @@ class Reorder(torch.nn.Module): """A transformation that reorders the 3D coordinates of backbone atoms from N, C, Ca, O -> N, Ca, C, O.""" def forward(self, protein_dict): - if 'reordered' not in protein_dict.keys(): - # If not already reordered, switch to N, Ca, C, ordering. - reordered_X = protein_dict['X'].index_select(1, torch.tensor([0, 2, 1, 3])) - protein_dict['X'] = reordered_X + # Switch to N, Ca, C, ordering. + reordered_X = protein_dict['X'].index_select(1, torch.tensor([0, 2, 1, 3])) + protein_dict['X'] = reordered_X return protein_dict diff --git a/src/models/components/backbone_update.py b/src/models/components/backbone_update.py index 8622295..91dd636 100644 --- a/src/models/components/backbone_update.py +++ b/src/models/components/backbone_update.py @@ -18,7 +18,7 @@ from torch import nn from src.models.components.primitives import Linear -from src.utils.rigid_utils import quat_to_rot, Rigid +from src.utils.rigid_utils import Rigids, Rotations class BackboneUpdate(nn.Module): @@ -62,6 +62,6 @@ def forward(self, s): quats = quats / norm_denominator.unsqueeze(-1) # [*, 3, 3] - rots = quat_to_rot(quats) + rots = Rotations(quats=quats) - return Rigid(rots, trans) + return Rigids(rots, trans) diff --git a/src/models/components/invariant_point_attention.py b/src/models/components/invariant_point_attention.py index 5421068..e2de2c7 100644 --- a/src/models/components/invariant_point_attention.py +++ b/src/models/components/invariant_point_attention.py @@ -24,7 +24,7 @@ from typing import Optional, Tuple, Sequence from src.utils.precision_utils import is_fp16_enabled -from src.utils.rigid_utils import Rotation, Rigid +from src.utils.rigid_utils import Rotations, Rigids from src.models.components.primitives import Linear, ipa_point_weights_init_ from src.utils.tensor_utils import ( @@ -108,7 +108,7 @@ def forward( self, s: torch.Tensor, z: Optional[torch.Tensor], - r: Rigid, + r: Rigids, mask: torch.Tensor, inplace_safe: bool = False, _offload_inference: bool = False, diff --git a/src/models/components/structure_transition.py b/src/models/components/structure_transition.py index 0b37e0e..3e1b3e1 100644 --- a/src/models/components/structure_transition.py +++ b/src/models/components/structure_transition.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# TODO: I suspect that this module can easily be deleted. It is a simple feedforward nonlinear transition. import torch.nn as nn from src.models.components.primitives import Linear @@ -45,7 +44,10 @@ def forward(self, s): class StructureTransition(nn.Module): - def __init__(self, c, num_layers, dropout_rate): + def __init__(self, + c, + num_layers: int = 1, + dropout_rate: float = 0.1): super(StructureTransition, self).__init__() self.c = c diff --git a/src/models/components/triangular_attention.py b/src/models/components/triangular_attention.py index fb486c7..db7e330 100644 --- a/src/models/components/triangular_attention.py +++ b/src/models/components/triangular_attention.py @@ -148,8 +148,11 @@ def forward(self, return x -# Implements Algorithm 13 -TriangleAttentionStartingNode = TriangleAttention +class TriangleAttentionStartingNode(TriangleAttention): + """ + Implements Algorithm 13. + """ + __init__ = partialmethod(TriangleAttention.__init__, starting=True) class TriangleAttentionEndingNode(TriangleAttention): diff --git a/src/models/components/evoformer_pair_stack.py b/src/models/evoformer_pair_stack.py similarity index 89% rename from src/models/components/evoformer_pair_stack.py rename to src/models/evoformer_pair_stack.py index 6bf5dc4..a3548cb 100644 --- a/src/models/components/evoformer_pair_stack.py +++ b/src/models/evoformer_pair_stack.py @@ -70,15 +70,6 @@ def __init__( dropout_rate: float = 0.25 ): super().__init__() - """ - self.blocks = nn.Sequential( - OrderedDict( - [(f'evoformer_pair_stack_block_{i}', EvoformerPairStackBlock(c_s=c_s, n_heads=n_heads, - c_hidden=c_hidden, - dropout_rate=dropout_rate)) - for i in range(n_blocks)] - ) - )""" self.blocks = nn.ModuleList([EvoformerPairStackBlock(c_s=c_s, n_heads=n_heads, c_hidden=c_hidden, diff --git a/src/models/structure_net.py b/src/models/structure_net.py new file mode 100644 index 0000000..2fee8f1 --- /dev/null +++ b/src/models/structure_net.py @@ -0,0 +1,173 @@ +import torch +from torch import nn + +from src.models.components.invariant_point_attention import InvariantPointAttention +from src.models.components.structure_transition import StructureTransition +from src.models.components.backbone_update import BackboneUpdate +from src.utils.rigid_utils import Rigids +import collections + +# Define the output structure to avoid clutter +Structure = collections.namedtuple('Structure', ['single_rep', 'pair_rep', 'transforms', 'mask']) + + +class StructureLayer(nn.Module): + + def __init__(self, + c_s, + c_z, + c_hidden_ipa, + n_head, + n_qk_point, + n_v_point, + ipa_dropout, + n_structure_transition_layer, + structure_transition_dropout + ): + """Initialize a Structure Layer. + :param c_s: + Single representation channel dimension + :param c_z: + Pair representation channel dimension + :param c_hidden_ipa: + Hidden IPA channel dimension + :param n_head: + Number of attention heads + :param n_qk_point: + Number of query/key points to generate + :param n_v_point: + Number of value points to generate + :param ipa_dropout: + IPA dropout rate + :param n_structure_transition_layer: + Number of structure transition layers + :param structure_transition_dropout: + structure transition dropout rate + """ + super(StructureLayer, self).__init__() + + self.ipa = InvariantPointAttention( + c_s, + c_z, + c_hidden_ipa, + n_head, + n_qk_point, + n_v_point + ) + self.ipa_dropout = nn.Dropout(ipa_dropout) + self.ipa_layer_norm = nn.LayerNorm(c_s) + + # Built-in dropout and layer norm + self.transition = StructureTransition( + c_s, + n_structure_transition_layer, + structure_transition_dropout + ) + + # backbone update TODO: it might be useful to zero the gradients on rotations. + self.bb_update = BackboneUpdate(c_s) + + def forward(self, inputs: Structure) -> Structure: + """Updates a structure by explicitly attending the 3D frames.""" + s, z, t, mask = inputs.single_rep, inputs.pair_rep, \ + inputs.transforms, inputs.mask + s = s + self.ipa(s, z, t, mask) + s = self.ipa_dropout(s) + s = self.ipa_layer_norm(s) + s = self.transition(s) + t = t.compose(self.bb_update(s)) + updated_structure = Structure(s, z, t, mask) + return updated_structure + + +class StructureNet(nn.Module): + + def __init__(self, + c_s: int, + c_z: int, + n_structure_layer: int = 4, + n_structure_block: int = 1, + c_hidden_ipa: int = 16, + n_head_ipa: int = 12, + n_qk_point: int = 4, + n_v_point: int = 8, + ipa_dropout: float = 0.1, + n_structure_transition_layer: int = 1, + structure_transition_dropout: float = 0.1, + ): + """Initializes a structure network. + :param c_s: + Single representation channel dimension + :param c_z: + Pair representation channel dimension + :param n_structure_layer: + Number of structure layers + :param c_hidden_ipa: + Hidden IPA channel dimension (multiplied by the number of heads) + :param n_head_ipa: + Number of attention heads in the IPA + :param n_qk_point: + Number of query/key points to generate + :param n_v_point: + Number of value points to generate + :param ipa_dropout: + IPA dropout rate + :param n_structure_transition_layer: + Number of structure transition layers + :param structure_transition_dropout: + structure transition dropout rate + """ + super(StructureNet, self).__init__() + + self.n_structure_block = n_structure_block + + # Initial projection and layer norms + self.pair_rep_layer_norm = nn.LayerNorm(c_z) + self.single_rep_layer_norm = nn.LayerNorm(c_s) + self.single_rep_linear = nn.Linear(c_s, c_s) + + layers = [ + StructureLayer( + c_s, c_z, + c_hidden_ipa, n_head_ipa, n_qk_point, n_v_point, ipa_dropout, + n_structure_transition_layer, structure_transition_dropout + ) + for _ in range(n_structure_layer) + ] + self.net = nn.Sequential(*layers) + + def forward( + self, + single_rep: torch.Tensor, + pair_rep: torch.Tensor, + transforms: Rigids, + mask: torch.Tensor = None + ) -> Rigids: + """Applies the structure module on the current transforms given single and pair representations. + :param single_rep: + [*, N_res, C_s] single representation + :param pair_rep: + [*, N_res, N_res, C_z] pair representation + :param transforms: + [*, N_res] transformation object + :param mask: + [*, N_res] mask + + :returns + [*, N_res] updated transforms + """ + + # Initial projection and layer norms + single_rep = self.single_rep_layer_norm(single_rep) + single_rep = self.single_rep_linear(single_rep) + pair_rep = self.pair_rep_layer_norm(pair_rep) + + # Initial structure + structure = Structure(single_rep, pair_rep, transforms, mask) + + # Updates with shared weights + for _ in range(self.n_structure_block): + structure = self.net(structure) + + # Return updated transforms + return structure.transforms diff --git a/src/utils/kernel/attention_core.py b/src/utils/kernel/attention_core.py deleted file mode 100644 index 5655a37..0000000 --- a/src/utils/kernel/attention_core.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import importlib -from functools import reduce -from operator import mul - -import torch - -attn_core_inplace_cuda = importlib.import_module("attn_core_inplace_cuda") - -SUPPORTED_DTYPES = [torch.float32, torch.bfloat16] - - -class AttentionCoreFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, q, k, v, bias_1=None, bias_2=None): - if bias_1 is None and bias_2 is not None: - raise ValueError("bias_1 must be specified before bias_2") - if q.dtype not in SUPPORTED_DTYPES: - raise ValueError("Unsupported datatype") - - q = q.contiguous() - k = k.contiguous() - - # [*, H, Q, K] - attention_logits = torch.matmul( - q, k.transpose(-1, -2), - ) - - if bias_1 is not None: - attention_logits += bias_1 - if bias_2 is not None: - attention_logits += bias_2 - - attn_core_inplace_cuda.forward_( - attention_logits, - reduce(mul, attention_logits.shape[:-1]), - attention_logits.shape[-1], - ) - - o = torch.matmul(attention_logits, v) - - ctx.bias_1_shape = bias_1.shape if bias_1 is not None else None - ctx.bias_2_shape = bias_2.shape if bias_2 is not None else None - ctx.save_for_backward(q, k, v, attention_logits) - - return o - - @staticmethod - def backward(ctx, grad_output): - q, k, v, attention_logits = ctx.saved_tensors - grad_q = grad_k = grad_v = grad_bias_1 = grad_bias_2 = None - - grad_v = torch.matmul( - attention_logits.transpose(-1, -2), - grad_output - ) - - attn_core_inplace_cuda.backward_( - attention_logits, - grad_output.contiguous(), - v.contiguous(), # v is implicitly transposed in the kernel - reduce(mul, attention_logits.shape[:-1]), - attention_logits.shape[-1], - grad_output.shape[-1], - ) - - if ctx.bias_1_shape is not None: - grad_bias_1 = torch.sum( - attention_logits, - dim=tuple(i for i, d in enumerate(ctx.bias_1_shape) if d == 1), - keepdim=True, - ) - - if ctx.bias_2_shape is not None: - grad_bias_2 = torch.sum( - attention_logits, - dim=tuple(i for i, d in enumerate(ctx.bias_2_shape) if d == 1), - keepdim=True, - ) - - grad_q = torch.matmul( - attention_logits, k - ) - grad_k = torch.matmul( - q.transpose(-1, -2), attention_logits, - ).transpose(-1, -2) - - return grad_q, grad_k, grad_v, grad_bias_1, grad_bias_2 - - -attention_core = AttentionCoreFunction.apply diff --git a/src/utils/kernel/csrc/compat.h b/src/utils/kernel/csrc/compat.h deleted file mode 100644 index bfab6aa..0000000 --- a/src/utils/kernel/csrc/compat.h +++ /dev/null @@ -1,11 +0,0 @@ -// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h - -#ifndef TORCH_CHECK -#define TORCH_CHECK AT_CHECK -#endif - -#ifdef VERSION_GE_1_3 -#define DATA_PTR data_ptr -#else -#define DATA_PTR data -#endif diff --git a/src/utils/kernel/csrc/softmax_cuda.cpp b/src/utils/kernel/csrc/softmax_cuda.cpp deleted file mode 100644 index f31eeec..0000000 --- a/src/utils/kernel/csrc/softmax_cuda.cpp +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2021 AlQuraishi Laboratory -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// modified from fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda.cpp - -#include - -void attn_softmax_inplace_forward_( - at::Tensor input, - long long rows, int cols -); -void attn_softmax_inplace_backward_( - at::Tensor output, - at::Tensor d_ov, - at::Tensor values, - long long rows, - int cols_output, - int cols_values -); - - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def( - "forward_", - &attn_softmax_inplace_forward_, - "Softmax forward (CUDA)" - ); - m.def( - "backward_", - &attn_softmax_inplace_backward_, - "Softmax backward (CUDA)" - ); -} diff --git a/src/utils/kernel/csrc/softmax_cuda_kernel.cu b/src/utils/kernel/csrc/softmax_cuda_kernel.cu deleted file mode 100644 index 9850936..0000000 --- a/src/utils/kernel/csrc/softmax_cuda_kernel.cu +++ /dev/null @@ -1,241 +0,0 @@ -// Copyright 2021 AlQuraishi Laboratory -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// modified from fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu - -#include -#include -#include - -#include - -#include "ATen/ATen.h" -#include "ATen/cuda/CUDAContext.h" -#include "compat.h" - -#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CONTIGUOUS(x) - -__inline__ __device__ float WarpAllReduceMax(float val) { - for (int mask = 1; mask < 32; mask *= 2) { - val = max(val, __shfl_xor_sync(0xffffffff, val, mask)); - } - return val; -} - -__inline__ __device__ float WarpAllReduceSum(float val) { - for (int mask = 1; mask < 32; mask *= 2) { - val += __shfl_xor_sync(0xffffffff, val, mask); - } - return val; -} - - -template -__global__ void attn_softmax_inplace_( - T *input, - long long rows, int cols -) { - int threadidx_x = threadIdx.x / 32; - int threadidx_y = threadIdx.x % 32; - long long row_offset = (long long)(blockIdx.x * 4 + threadidx_x); - int cols_per_thread = (cols + 31) / 32; - int cols_this_thread = cols_per_thread; - - int last_y = (cols / cols_per_thread); - - if (threadidx_y == last_y) { - cols_this_thread = cols - cols_per_thread * last_y; - } - else if (threadidx_y > last_y) { - cols_this_thread = 0; - } - - float buf[32]; - - int lane_id = threadidx_y; - - if (row_offset < rows) { - T *row_input = input + row_offset * cols; - T *row_output = row_input; - - #pragma unroll - for (int i = 0; i < cols_this_thread; i++) { - int idx = lane_id * cols_per_thread + i; - buf[i] = static_cast(row_input[idx]); - } - - float thread_max = -1 * CUDART_INF_F; - #pragma unroll - for (int i = 0; i < cols_this_thread; i++) { - thread_max = max(thread_max, buf[i]); - } - - float warp_max = WarpAllReduceMax(thread_max); - - float thread_sum = 0.f; - #pragma unroll - for (int i = 0; i < cols_this_thread; i++) { - buf[i] = __expf(buf[i] - warp_max); - thread_sum += buf[i]; - } - - float warp_sum = WarpAllReduceSum(thread_sum); - #pragma unroll - for (int i = 0; i < cols_this_thread; i++) { - row_output[lane_id * cols_per_thread + i] = - static_cast(__fdividef(buf[i], warp_sum)); - } - } -} - - -void attn_softmax_inplace_forward_( - at::Tensor input, - long long rows, int cols -) { - CHECK_INPUT(input); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - - int grid = (rows + 3) / 4; - dim3 block(128); - - if (input.dtype() == torch::kFloat32) { - attn_softmax_inplace_<<>>( - (float *)input.data_ptr(), - rows, cols - ); - } - else { - attn_softmax_inplace_<<>>( - (at::BFloat16 *)input.data_ptr(), - rows, cols - ); - } -} - - -template -__global__ void attn_softmax_inplace_grad_( - T *output, - T *d_ov, - T *values, - long long rows, - int cols_output, - int cols_values -) { - int threadidx_x = threadIdx.x / 32; - int threadidx_y = threadIdx.x % 32; - long long row_offset = (long long)(blockIdx.x * 4 + threadidx_x); - int cols_per_thread = (cols_output + 31) / 32; - int cols_this_thread = cols_per_thread; - int rows_values = cols_output; - // values are set to the beginning of the current - // rows_values x cols_values leaf matrix - long long value_row_offset = row_offset - row_offset % rows_values; - int last_y = (cols_output / cols_per_thread); - - if (threadidx_y == last_y) { - cols_this_thread = cols_output - cols_per_thread * last_y; - } - else if (threadidx_y > last_y) { - cols_this_thread = 0; - } - - float y_buf[32]; - float dy_buf[32]; - - int lane_id = threadidx_y; - - if (row_offset < rows) { - T *row_output = output + row_offset * cols_output; - T *row_d_ov = d_ov + row_offset * cols_values; - T *row_values = values + value_row_offset * cols_values; - - float thread_max = -1 * CUDART_INF_F; - - // Compute a chunk of the output gradient on the fly - int value_row_idx = 0; - int value_idx = 0; - #pragma unroll - for (int i = 0; i < cols_this_thread; i++) { - T sum = 0.; - #pragma unroll - for (int j = 0; j < cols_values; j++) { - value_row_idx = ((lane_id * cols_per_thread) + i); - value_idx = value_row_idx * cols_values + j; - sum += row_d_ov[j] * row_values[value_idx]; - } - dy_buf[i] = static_cast(sum); - } - - #pragma unroll - for (int i = 0; i < cols_this_thread; i++) { - y_buf[i] = static_cast(row_output[lane_id * cols_per_thread + i]); - } - - float thread_sum = 0.; - - #pragma unroll - for (int i = 0; i < cols_this_thread; i++) { - thread_sum += y_buf[i] * dy_buf[i]; - } - - float warp_sum = WarpAllReduceSum(thread_sum); - - #pragma unroll - for (int i = 0; i < cols_this_thread; i++) { - row_output[lane_id * cols_per_thread + i] = static_cast( - (dy_buf[i] - warp_sum) * y_buf[i] - ); - } - } -} - - -void attn_softmax_inplace_backward_( - at::Tensor output, - at::Tensor d_ov, - at::Tensor values, - long long rows, - int cols_output, - int cols_values -) { - CHECK_INPUT(output); - CHECK_INPUT(d_ov); - CHECK_INPUT(values); - const at::cuda::OptionalCUDAGuard device_guard(device_of(output)); - - int grid = (rows + 3) / 4; - dim3 block(128); - - if (output.dtype() == torch::kFloat32) { - attn_softmax_inplace_grad_<<>>( - (float *)output.data_ptr(), - (float *)d_ov.data_ptr(), - (float *)values.data_ptr(), - rows, cols_output, cols_values - ); - } else { - attn_softmax_inplace_grad_<<>>( - (at::BFloat16 *)output.data_ptr(), - (at::BFloat16 *)d_ov.data_ptr(), - (at::BFloat16 *)values.data_ptr(), - rows, cols_output, cols_values - ); - } -} diff --git a/src/utils/kernel/csrc/softmax_cuda_stub.cpp b/src/utils/kernel/csrc/softmax_cuda_stub.cpp deleted file mode 100644 index 4539c19..0000000 --- a/src/utils/kernel/csrc/softmax_cuda_stub.cpp +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2021 AlQuraishi Laboratory -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// modified from fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda.cpp - -#include - -void attn_softmax_inplace_forward_( - at::Tensor input, - long long rows, int cols -) -{ - throw std::runtime_error("attn_softmax_inplace_forward_ not implemented on CPU"); -}; -void attn_softmax_inplace_backward_( - at::Tensor output, - at::Tensor d_ov, - at::Tensor values, - long long rows, - int cols_output, - int cols_values -) -{ - throw std::runtime_error("attn_softmax_inplace_backward_ not implemented on CPU"); -}; \ No newline at end of file diff --git a/src/utils/rigid_utils.py b/src/utils/rigid_utils.py index 0a760e6..c98778b 100644 --- a/src/utils/rigid_utils.py +++ b/src/utils/rigid_utils.py @@ -282,14 +282,14 @@ def invert_quat(quat: torch.Tensor): return inv -class Rotation: +class Rotations: """ A 3D rotation. Depending on how the object is initialized, the rotation is represented by either a rotation matrix or a quaternion, though both formats are made available by helper functions. To simplify gradient computation, the underlying format of the - rotation cannot be changed in-place. Like Rigid, the class is designed - to mimic the behavior of a torch Tensor, almost as if each Rotation + rotation cannot be changed in-place. Like Rigids, the class is designed + to mimic the behavior of a torch Tensor, almost as if each Rotations object were a tensor of rotations, in one format or another. """ @@ -338,13 +338,13 @@ def identity( device: Optional[torch.device] = None, requires_grad: bool = True, fmt: str = "quat", - ) -> Rotation: + ) -> Rotations: """ - Returns an identity Rotation. + Returns an identity Rotations. Args: shape: - The "shape" of the resulting Rotation object. See documentation + The "shape" of the resulting Rotations object. See documentation for the shape property dtype: The torch dtype for the rotation @@ -363,16 +363,16 @@ def identity( rot_mats = identity_rot_mats( shape, dtype, device, requires_grad, ) - return Rotation(rot_mats=rot_mats, quats=None) + return Rotations(rot_mats=rot_mats, quats=None) elif fmt == "quat": quats = identity_quats(shape, dtype, device, requires_grad) - return Rotation(rot_mats=None, quats=quats, normalize_quats=False) + return Rotations(rot_mats=None, quats=quats, normalize_quats=False) else: raise ValueError(f"Invalid format: f{fmt}") # Magic methods - def __getitem__(self, index: Any) -> Rotation: + def __getitem__(self, index: Any) -> Rotations: """ Allows torch-style indexing over the virtual shape of the rotation object. See documentation for the shape property. @@ -388,19 +388,19 @@ def __getitem__(self, index: Any) -> Rotation: if self._rot_mats is not None: rot_mats = self._rot_mats[index + (slice(None), slice(None))] - return Rotation(rot_mats=rot_mats) + return Rotations(rot_mats=rot_mats) elif self._quats is not None: quats = self._quats[index + (slice(None),)] - return Rotation(quats=quats, normalize_quats=False) + return Rotations(quats=quats, normalize_quats=False) else: raise ValueError("Both rotations are None") def __mul__(self, right: torch.Tensor, - ) -> Rotation: + ) -> Rotations: """ Pointwise left multiplication of the rotation with a tensor. Can be - used to e.g. mask the Rotation. + used to e.g. mask the Rotations. Args: right: @@ -413,16 +413,16 @@ def __mul__(self, if self._rot_mats is not None: rot_mats = self._rot_mats * right[..., None, None] - return Rotation(rot_mats=rot_mats, quats=None) + return Rotations(rot_mats=rot_mats, quats=None) elif self._quats is not None: quats = self._quats * right[..., None] - return Rotation(rot_mats=None, quats=quats, normalize_quats=False) + return Rotations(rot_mats=None, quats=quats, normalize_quats=False) else: raise ValueError("Both rotations are None") def __rmul__(self, left: torch.Tensor, - ) -> Rotation: + ) -> Rotations: """ Reverse pointwise multiplication of the rotation with a tensor. @@ -441,7 +441,7 @@ def shape(self) -> torch.Size: """ Returns the virtual shape of the rotation object. This shape is defined as the batch dimensions of the underlying rotation matrix - or quaternion. If the Rotation was initialized with a [10, 3, 3] + or quaternion. If the Rotations was initialized with a [10, 3, 3] rotation matrix tensor, for example, the resulting shape would be [10]. @@ -521,7 +521,7 @@ def get_quats(self) -> torch.Tensor: """ Returns the underlying rotation as a quaternion tensor. - Depending on whether the Rotation was initialized with a + Depending on whether the Rotations was initialized with a quaternion, this function may call torch.linalg.eigh. Returns: @@ -550,14 +550,14 @@ def get_cur_rot(self) -> torch.Tensor: else: raise ValueError("Both rotations are None") - # Rotation functions + # Rotations functions def compose_q_update_vec(self, q_update_vec: torch.Tensor, normalize_quats: bool = True - ) -> Rotation: + ) -> Rotations: """ - Returns a new quaternion Rotation after updating the current + Returns a new quaternion Rotations after updating the current object's underlying rotation with a quaternion update, formatted as a [*, 3] tensor whose final three columns represent x, y, z such that (1, x, y, z) is the desired (not necessarily unit) quaternion @@ -569,19 +569,19 @@ def compose_q_update_vec(self, normalize_quats: Whether to normalize the output quaternion Returns: - An updated Rotation + An updated Rotations """ quats = self.get_quats() new_quats = quats + quat_multiply_by_vec(quats, q_update_vec) - return Rotation( + return Rotations( rot_mats=None, quats=new_quats, normalize_quats=normalize_quats, ) - def compose_r(self, r: Rotation) -> Rotation: + def compose_r(self, r: Rotations) -> Rotations: """ - Compose the rotation matrices of the current Rotation object with + Compose the rotation matrices of the current Rotations object with those of another. Args: @@ -593,14 +593,14 @@ def compose_r(self, r: Rotation) -> Rotation: r1 = self.get_rot_mats() r2 = r.get_rot_mats() new_rot_mats = rot_matmul(r1, r2) - return Rotation(rot_mats=new_rot_mats, quats=None) + return Rotations(rot_mats=new_rot_mats, quats=None) - def compose_q(self, r: Rotation, normalize_quats: bool = True) -> Rotation: + def compose_q(self, r: Rotations, normalize_quats: bool = True) -> Rotations: """ - Compose the quaternions of the current Rotation object with those + Compose the quaternions of the current Rotations object with those of another. - Depending on whether either Rotation was initialized with + Depending on whether either Rotations was initialized with quaternions, this function may call torch.linalg.eigh. Args: @@ -612,13 +612,13 @@ def compose_q(self, r: Rotation, normalize_quats: bool = True) -> Rotation: q1 = self.get_quats() q2 = r.get_quats() new_quats = quat_multiply(q1, q2) - return Rotation( + return Rotations( rot_mats=None, quats=new_quats, normalize_quats=normalize_quats ) def apply(self, pts: torch.Tensor) -> torch.Tensor: """ - Apply the current Rotation as a rotation matrix to a set of 3D + Apply the current Rotations as a rotation matrix to a set of 3D coordinates. Args: @@ -644,20 +644,20 @@ def invert_apply(self, pts: torch.Tensor) -> torch.Tensor: inv_rot_mats = invert_rot_mat(rot_mats) return rot_vec_mul(inv_rot_mats, pts) - def invert(self) -> Rotation: + def invert(self) -> Rotations: """ - Returns the inverse of the current Rotation. + Returns the inverse of the current Rotations. Returns: - The inverse of the current Rotation + The inverse of the current Rotations """ if self._rot_mats is not None: - return Rotation( + return Rotations( rot_mats=invert_rot_mat(self._rot_mats), quats=None ) elif self._quats is not None: - return Rotation( + return Rotations( rot_mats=None, quats=invert_quat(self._quats), normalize_quats=False, @@ -669,33 +669,33 @@ def invert(self) -> Rotation: def unsqueeze(self, dim: int, - ) -> Rigid: + ) -> Rigids: """ Analogous to torch.unsqueeze. The dimension is relative to the - shape of the Rotation object. + shape of the Rotations object. Args: dim: A positive or negative dimension index. Returns: - The unsqueezed Rotation. + The unsqueezed Rotations. """ if dim >= len(self.shape): raise ValueError("Invalid dimension") if self._rot_mats is not None: rot_mats = self._rot_mats.unsqueeze(dim if dim >= 0 else dim - 2) - return Rotation(rot_mats=rot_mats, quats=None) + return Rotations(rot_mats=rot_mats, quats=None) elif self._quats is not None: quats = self._quats.unsqueeze(dim if dim >= 0 else dim - 1) - return Rotation(rot_mats=None, quats=quats, normalize_quats=False) + return Rotations(rot_mats=None, quats=quats, normalize_quats=False) else: raise ValueError("Both rotations are None") @staticmethod def cat( - rs: Sequence[Rotation], + rs: Sequence[Rotations], dim: int, - ) -> Rigid: + ) -> Rigids: """ Concatenates rotations along one of the batch dimensions. Analogous to torch.cat(). @@ -710,16 +710,16 @@ def cat( The dimension along which the rotations should be concatenated Returns: - A concatenated Rotation object in rotation matrix format + A concatenated Rotations object in rotation matrix format """ rot_mats = [r.get_rot_mats() for r in rs] rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2) - return Rotation(rot_mats=rot_mats, quats=None) + return Rotations(rot_mats=rot_mats, quats=None) def map_tensor_fn(self, fn: Callable[torch.Tensor, torch.Tensor] - ) -> Rotation: + ) -> Rotations: """ Apply a Tensor -> Tensor function to underlying rotation tensors, mapping over the rotation dimension(s). Can be used e.g. to sum out @@ -727,9 +727,9 @@ def map_tensor_fn(self, Args: fn: - A Tensor -> Tensor function to be mapped over the Rotation + A Tensor -> Tensor function to be mapped over the Rotations Returns: - The transformed Rotation object + The transformed Rotations object """ if self._rot_mats is not None: rot_mats = self._rot_mats.view(self._rot_mats.shape[:-2] + (9,)) @@ -737,26 +737,26 @@ def map_tensor_fn(self, list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1 ) rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3)) - return Rotation(rot_mats=rot_mats, quats=None) + return Rotations(rot_mats=rot_mats, quats=None) elif self._quats is not None: quats = torch.stack( list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1 ) - return Rotation(rot_mats=None, quats=quats, normalize_quats=False) + return Rotations(rot_mats=None, quats=quats, normalize_quats=False) else: raise ValueError("Both rotations are None") - def cuda(self) -> Rotation: + def cuda(self) -> Rotations: """ Analogous to the cuda() method of torch Tensors Returns: - A copy of the Rotation in CUDA memory + A copy of the Rotations in CUDA memory """ if self._rot_mats is not None: - return Rotation(rot_mats=self._rot_mats.cuda(), quats=None) + return Rotations(rot_mats=self._rot_mats.cuda(), quats=None) elif self._quats is not None: - return Rotation( + return Rotations( rot_mats=None, quats=self._quats.cuda(), normalize_quats=False @@ -767,7 +767,7 @@ def cuda(self) -> Rotation: def to(self, device: Optional[torch.device], dtype: Optional[torch.dtype] - ) -> Rotation: + ) -> Rotations: """ Analogous to the to() method of torch Tensors @@ -777,15 +777,15 @@ def to(self, dtype: A torch dtype Returns: - A copy of the Rotation using the new device and dtype + A copy of the Rotations using the new device and dtype """ if self._rot_mats is not None: - return Rotation( + return Rotations( rot_mats=self._rot_mats.to(device=device, dtype=dtype), quats=None, ) elif self._quats is not None: - return Rotation( + return Rotations( rot_mats=None, quats=self._quats.to(device=device, dtype=dtype), normalize_quats=False, @@ -793,19 +793,19 @@ def to(self, else: raise ValueError("Both rotations are None") - def detach(self) -> Rotation: + def detach(self) -> Rotations: """ - Returns a copy of the Rotation whose underlying Tensor has been + Returns a copy of the Rotations whose underlying Tensor has been detached from its torch graph. Returns: - A copy of the Rotation whose underlying Tensor has been detached + A copy of the Rotations whose underlying Tensor has been detached from its torch graph """ if self._rot_mats is not None: - return Rotation(rot_mats=self._rot_mats.detach(), quats=None) + return Rotations(rot_mats=self._rot_mats.detach(), quats=None) elif self._quats is not None: - return Rotation( + return Rotations( rot_mats=None, quats=self._quats.detach(), normalize_quats=False, @@ -814,16 +814,16 @@ def detach(self) -> Rotation: raise ValueError("Both rotations are None") -class Rigid: +class Rigids: """ - A class representing a rigid transformation. Little more than a wrapper - around two objects: a Rotation object and a [*, 3] translation + A class representing rigid transformations. Little more than a wrapper + around two objects: a Rotations object and a [*, 3] translation Designed to behave approximately like a single torch tensor with the shape of the shared batch dimensions of its component parts. """ def __init__(self, - rots: Optional[Rotation], + rots: Optional[Rotations], trans: Optional[torch.Tensor], ): """ @@ -848,7 +848,7 @@ def __init__(self, raise ValueError("At least one input argument must be specified") if rots is None: - rots = Rotation.identity( + rots = Rotations.identity( batch_dims, dtype, device, requires_grad, ) elif trans is None: @@ -873,7 +873,7 @@ def identity( device: Optional[torch.device] = None, requires_grad: bool = True, fmt: str = "quat", - ) -> Rigid: + ) -> Rigids: """ Constructs an identity transformation. @@ -889,14 +889,14 @@ def identity( Returns: The identity transformation """ - return Rigid( - Rotation.identity(shape, dtype, device, requires_grad, fmt=fmt), + return Rigids( + Rotations.identity(shape, dtype, device, requires_grad, fmt=fmt), identity_trans(shape, dtype, device, requires_grad), ) def __getitem__(self, index: Any, - ) -> Rigid: + ) -> Rigids: """ Indexes the affine transformation with PyTorch-style indices. The index is applied to the shared dimensions of both the rotation @@ -904,7 +904,7 @@ def __getitem__(self, E.g.:: - r = Rotation(rot_mats=torch.rand(10, 10, 3, 3), quats=None) + r = Rotations(rot_mats=torch.rand(10, 10, 3, 3), quats=None) t = Rigid(r, torch.rand(10, 10, 3)) indexed = t[3, 4:6] assert(indexed.shape == (2,)) @@ -920,14 +920,14 @@ def __getitem__(self, if type(index) != tuple: index = (index,) - return Rigid( + return Rigids( self._rots[index], self._trans[index + (slice(None),)], ) def __mul__(self, right: torch.Tensor, - ) -> Rigid: + ) -> Rigids: """ Pointwise left multiplication of the transformation with a tensor. Can be used to e.g. mask the Rigid. @@ -944,11 +944,11 @@ def __mul__(self, new_rots = self._rots * right new_trans = self._trans * right[..., None] - return Rigid(new_rots, new_trans) + return Rigids(new_rots, new_trans) def __rmul__(self, left: torch.Tensor, - ) -> Rigid: + ) -> Rigids: """ Reverse pointwise multiplication of the transformation with a tensor. @@ -983,7 +983,7 @@ def device(self) -> torch.device: """ return self._trans.device - def get_rots(self) -> Rotation: + def get_rots(self) -> Rotations: """ Getter for the rotation. @@ -1003,7 +1003,7 @@ def get_trans(self) -> torch.Tensor: def compose_q_update_vec(self, q_update_vec: torch.Tensor, - ) -> Rigid: + ) -> Rigids: """ Composes the transformation with a quaternion update vector of shape [*, 6], where the final 6 columns represent the x, y, and @@ -1021,11 +1021,11 @@ def compose_q_update_vec(self, trans_update = self._rots.apply(t_vec) new_translation = self._trans + trans_update - return Rigid(new_rots, new_translation) + return Rigids(new_rots, new_translation) def compose(self, - r: Rigid, - ) -> Rigid: + r: Rigids, + ) -> Rigids: """ Composes the current rigid object with another. @@ -1037,7 +1037,7 @@ def compose(self, """ new_rot = self._rots.compose_r(r._rots) new_trans = self._rots.apply(r._trans) + self._trans - return Rigid(new_rot, new_trans) + return Rigids(new_rot, new_trans) def apply(self, pts: torch.Tensor, @@ -1067,7 +1067,7 @@ def invert_apply(self, pts = pts - self._trans return self._rots.invert_apply(pts) - def invert(self) -> Rigid: + def invert(self) -> Rigids: """ Inverts the transformation. @@ -1077,11 +1077,11 @@ def invert(self) -> Rigid: rot_inv = self._rots.invert() trn_inv = rot_inv.apply(self._trans) - return Rigid(rot_inv, -1 * trn_inv) + return Rigids(rot_inv, -1 * trn_inv) def map_tensor_fn(self, fn: Callable[torch.Tensor, torch.Tensor] - ) -> Rigid: + ) -> Rigids: """ Apply a Tensor -> Tensor function to underlying translation and rotation tensors, mapping over the translation/rotation dimensions @@ -1099,7 +1099,7 @@ def map_tensor_fn(self, dim=-1 ) - return Rigid(new_rots, new_trans) + return Rigids(new_rots, new_trans) def to_tensor_4x4(self) -> torch.Tensor: """ @@ -1117,7 +1117,7 @@ def to_tensor_4x4(self) -> torch.Tensor: @staticmethod def from_tensor_4x4( t: torch.Tensor - ) -> Rigid: + ) -> Rigids: """ Constructs a transformation from a homogenous transformation tensor. @@ -1130,10 +1130,10 @@ def from_tensor_4x4( if t.shape[-2:] != (4, 4): raise ValueError("Incorrectly shaped input tensor") - rots = Rotation(rot_mats=t[..., :3, :3], quats=None) + rots = Rotations(rot_mats=t[..., :3, :3], quats=None) trans = t[..., :3, 3] - return Rigid(rots, trans) + return Rigids(rots, trans) def to_tensor_7(self) -> torch.Tensor: """ @@ -1153,19 +1153,19 @@ def to_tensor_7(self) -> torch.Tensor: def from_tensor_7( t: torch.Tensor, normalize_quats: bool = False, - ) -> Rigid: + ) -> Rigids: if t.shape[-1] != 7: raise ValueError("Incorrectly shaped input tensor") quats, trans = t[..., :4], t[..., 4:] - rots = Rotation( + rots = Rotations( rot_mats=None, quats=quats, normalize_quats=normalize_quats ) - return Rigid(rots, trans) + return Rigids(rots, trans) @staticmethod def from_3_points( @@ -1173,7 +1173,7 @@ def from_3_points( origin: torch.Tensor, p_xy_plane: torch.Tensor, eps: float = 1e-8 - ) -> Rigid: + ) -> Rigids: """ Implements algorithm 21. Constructs transformations from sets of 3 points using the Gram-Schmidt algorithm. @@ -1208,13 +1208,13 @@ def from_3_points( rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1) rots = rots.reshape(rots.shape[:-1] + (3, 3)) - rot_obj = Rotation(rot_mats=rots, quats=None) + rot_obj = Rotations(rot_mats=rots, quats=None) - return Rigid(rot_obj, torch.stack(origin, dim=-1)) + return Rigids(rot_obj, torch.stack(origin, dim=-1)) def unsqueeze(self, dim: int, - ) -> Rigid: + ) -> Rigids: """ Analogous to torch.unsqueeze. The dimension is relative to the shared dimensions of the rotation/translation. @@ -1229,13 +1229,13 @@ def unsqueeze(self, rots = self._rots.unsqueeze(dim) trans = self._trans.unsqueeze(dim if dim >= 0 else dim - 1) - return Rigid(rots, trans) + return Rigids(rots, trans) @staticmethod def cat( - ts: Sequence[Rigid], + ts: Sequence[Rigids], dim: int, - ) -> Rigid: + ) -> Rigids: """ Concatenates transformations along a new dimension. @@ -1248,26 +1248,26 @@ def cat( Returns: A concatenated transformation object """ - rots = Rotation.cat([t._rots for t in ts], dim) + rots = Rotations.cat([t._rots for t in ts], dim) trans = torch.cat( [t._trans for t in ts], dim=dim if dim >= 0 else dim - 1 ) - return Rigid(rots, trans) + return Rigids(rots, trans) - def apply_rot_fn(self, fn: Callable[Rotation, Rotation]) -> Rigid: + def apply_rot_fn(self, fn: Callable[Rotations, Rotations]) -> Rigids: """ - Applies a Rotation -> Rotation function to the stored rotation + Applies a Rotations -> Rotations function to the stored rotation object. Args: - fn: A function of type Rotation -> Rotation + fn: A function of type Rotations -> Rotations Returns: A transformation object with a transformed rotation. """ - return Rigid(fn(self._rots), self._trans) + return Rigids(fn(self._rots), self._trans) - def apply_trans_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rigid: + def apply_trans_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rigids: """ Applies a Tensor -> Tensor function to the stored translation. @@ -1278,9 +1278,9 @@ def apply_trans_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rigid: Returns: A transformation object with a transformed translation. """ - return Rigid(self._rots, fn(self._trans)) + return Rigids(self._rots, fn(self._trans)) - def scale_translation(self, trans_scale_factor: float) -> Rigid: + def scale_translation(self, trans_scale_factor: float) -> Rigids: """ Scales the translation by a constant factor. @@ -1293,7 +1293,7 @@ def scale_translation(self, trans_scale_factor: float) -> Rigid: fn = lambda t: t * trans_scale_factor return self.apply_trans_fn(fn) - def stop_rot_gradient(self) -> Rigid: + def stop_rot_gradient(self) -> Rigids: """ Detaches the underlying rotation object @@ -1372,15 +1372,15 @@ def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20): rots = rots.transpose(-1, -2) translation = -1 * translation - rot_obj = Rotation(rot_mats=rots, quats=None) + rot_obj = Rotations(rot_mats=rots, quats=None) - return Rigid(rot_obj, translation) + return Rigids(rot_obj, translation) - def cuda(self) -> Rigid: + def cuda(self) -> Rigids: """ Moves the transformation object to GPU memory Returns: A version of the transformation on GPU """ - return Rigid(self._rots.cuda(), self._trans.cuda()) + return Rigids(self._rots.cuda(), self._trans.cuda()) diff --git a/tests/config.py b/tests/config.py new file mode 100644 index 0000000..d6bef73 --- /dev/null +++ b/tests/config.py @@ -0,0 +1,34 @@ +import ml_collections as mlc + +consts = mlc.ConfigDict( + { + "batch_size": 2, + "n_res": 11, + "n_seq": 13, + "n_templ": 3, + "n_extra": 17, + "n_heads_extra_msa": 8, + "eps": 5e-4, + # For compatibility with DeepMind's pretrained weights, it's easiest for + # everyone if these take their real values. + "c_m": 256, + "c_z": 128, + "c_s": 384, + "c_t": 64, + "c_e": 64, + } +) + +config = mlc.ConfigDict( + { + "data": { + "common": { + "masked_msa": { + "profile_prob": 0.1, + "same_prob": 0.1, + "uniform_prob": 0.1, + }, + } + } + } +) diff --git a/tests/test_evoformer_pair_stack.py b/tests/test_evoformer_pair_stack.py index 261c6fc..f42f63b 100644 --- a/tests/test_evoformer_pair_stack.py +++ b/tests/test_evoformer_pair_stack.py @@ -1,10 +1,9 @@ import unittest import torch -import numpy as np from prettytable import PrettyTable from tests.config import consts -from src.models.components.evoformer_pair_stack import ( +from src.models.evoformer_pair_stack import ( EvoformerPairStackBlock, EvoformerPairStack ) diff --git a/tests/test_invariant_point_attention.py b/tests/test_invariant_point_attention.py index 603f46d..84f5c48 100644 --- a/tests/test_invariant_point_attention.py +++ b/tests/test_invariant_point_attention.py @@ -2,7 +2,7 @@ import numpy as np import unittest from src.models.components.invariant_point_attention import InvariantPointAttention -from src.utils.rigid_utils import Rotation, Rigid +from src.utils.rigid_utils import Rotations, Rigids class TestInvariantPointAttention(unittest.TestCase): @@ -22,10 +22,10 @@ def test_shape(self): mask = torch.ones((batch_size, n_res)) rot_mats = torch.rand((batch_size, n_res, 3, 3)) - rots = Rotation(rot_mats=rot_mats, quats=None) + rots = Rotations(rot_mats=rot_mats, quats=None) trans = torch.rand((batch_size, n_res, 3)) - r = Rigid(rots, trans) + r = Rigids(rots, trans) ipa = InvariantPointAttention( c_m, c_z, c_hidden, no_heads, no_qp, no_vp diff --git a/tests/test_structure_net.py b/tests/test_structure_net.py new file mode 100644 index 0000000..d244536 --- /dev/null +++ b/tests/test_structure_net.py @@ -0,0 +1,52 @@ +import unittest +import torch +from prettytable import PrettyTable + +from tests.config import consts +from src.models.structure_net import StructureNet +from src.utils.rigid_utils import Rigids + + +class TestStructureNet(unittest.TestCase): + def test_shape(self): + c_s = 128 + n_heads = 4 + c_hidden = 128 + dropout_rate = 0.25 + n_blocks = 4 + net = StructureNet(c_s=128, c_z=128) + batch_size = consts.batch_size + n_res = consts.n_res + s = torch.rand((batch_size, n_res, c_s)) + z = torch.rand((batch_size, n_res, n_res, c_s)) + transforms = Rigids.identity((2, 11)) + + shape_before = transforms.shape + mask = torch.ones((batch_size, n_res)) + output = net(s, z, transforms, mask) + shape_after = output.shape + + self.assertTrue(shape_before == shape_after) + + def test_params(self): + """A method to check the number of parameters in the stack.""" + def count_parameters(model): + table = PrettyTable(["Modules", "Parameters"]) + total_params = 0 + for name, parameter in model.named_parameters(): + if not parameter.requires_grad: + continue + params = parameter.numel() + table.add_row([name, params]) + total_params += params + print(table) + print(f"Total Trainable Params: {total_params}") + return total_params + + net = StructureNet(c_s=128, c_z=128) + count_parameters(net) + self.assertTrue(True) + + +if __name__ == '__main__': + unittest.main()