Code improvements, implemented Structure net
ardagoreci committed Nov 23, 2023
1 parent 67151bb commit e73447f
Showing 18 changed files with 395 additions and 576 deletions.
1 change: 1 addition & 0 deletions configs/logger/wandb.yaml
@@ -8,6 +8,7 @@ wandb:
   id: null # pass correct id to resume experiment!
   anonymous: null # enable anonymous logging
   project: "lightning-hydra-template"
+  entity: "ligo-technologies"
   log_model: False # upload lightning ckpts
   prefix: "" # a string to put at the beginning of metric keys
   # entity: "" # set to name of your wandb team
7 changes: 3 additions & 4 deletions src/data/protein_datamodule.py
@@ -31,10 +31,9 @@ class Reorder(torch.nn.Module):
     """A transformation that reorders the 3D coordinates of backbone atoms
     from N, C, Ca, O -> N, Ca, C, O."""
     def forward(self, protein_dict):
-        if 'reordered' not in protein_dict.keys():
-            # If not already reordered, switch to N, Ca, C, ordering.
-            reordered_X = protein_dict['X'].index_select(1, torch.tensor([0, 2, 1, 3]))
-            protein_dict['X'] = reordered_X
+        # Switch to N, Ca, C, ordering.
+        reordered_X = protein_dict['X'].index_select(1, torch.tensor([0, 2, 1, 3]))
+        protein_dict['X'] = reordered_X
         return protein_dict


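As a quick illustration of the simplified transform, the sketch below applies the same index_select permutation to a dummy coordinate tensor. The [n_res, 4, 3] shape and the N, C, Ca, O input order are assumptions taken from the class docstring, not from the dataset code itself.

import torch

# Hypothetical toy input: 2 residues, 4 backbone atoms stored as N, C, Ca, O, 3D coords.
X = torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3)

# Same permutation as Reorder.forward: swap columns 1 and 2 -> N, Ca, C, O.
reordered_X = X.index_select(1, torch.tensor([0, 2, 1, 3]))

print(torch.equal(reordered_X[0, 2], X[0, 1]))  # True: the original C row now sits at index 2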
6 changes: 3 additions & 3 deletions src/models/components/backbone_update.py
@@ -18,7 +18,7 @@
 from torch import nn
 
 from src.models.components.primitives import Linear
-from src.utils.rigid_utils import quat_to_rot, Rigid
+from src.utils.rigid_utils import Rigids, Rotations
 
 
 class BackboneUpdate(nn.Module):
@@ -62,6 +62,6 @@ def forward(self, s):
         quats = quats / norm_denominator.unsqueeze(-1)
 
         # [*, 3, 3]
-        rots = quat_to_rot(quats)
+        rots = Rotations(quats=quats)
 
-        return Rigid(rots, trans)
+        return Rigids(rots, trans)
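For context on what the new Rotations(quats=quats) object has to represent, here is a minimal, self-contained sketch of the standard unit-quaternion-to-rotation-matrix conversion that the removed quat_to_rot presumably performed. The function name and shape conventions are illustrative only, not the repository's actual rigid_utils API.

import torch

def quaternion_to_rotation_matrix(quats: torch.Tensor) -> torch.Tensor:
    """Convert unit quaternions [*, 4] in (w, x, y, z) order to rotation matrices [*, 3, 3]."""
    w, x, y, z = quats.unbind(dim=-1)
    entries = [
        1 - 2 * (y * y + z * z), 2 * (x * y - w * z),     2 * (x * z + w * y),
        2 * (x * y + w * z),     1 - 2 * (x * x + z * z), 2 * (y * z - w * x),
        2 * (x * z - w * y),     2 * (y * z + w * x),     1 - 2 * (x * x + y * y),
    ]
    # Stack the nine entries row-major and reshape into 3x3 matrices.
    return torch.stack(entries, dim=-1).reshape(quats.shape[:-1] + (3, 3))

# The identity quaternion maps to the identity matrix.
print(quaternion_to_rotation_matrix(torch.tensor([1.0, 0.0, 0.0, 0.0])))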
4 changes: 2 additions & 2 deletions src/models/components/invariant_point_attention.py
@@ -24,7 +24,7 @@
 from typing import Optional, Tuple, Sequence
 
 from src.utils.precision_utils import is_fp16_enabled
-from src.utils.rigid_utils import Rotation, Rigid
+from src.utils.rigid_utils import Rotations, Rigids
 
 from src.models.components.primitives import Linear, ipa_point_weights_init_
 from src.utils.tensor_utils import (
@@ -108,7 +108,7 @@ def forward(
         self,
         s: torch.Tensor,
         z: Optional[torch.Tensor],
-        r: Rigid,
+        r: Rigids,
         mask: torch.Tensor,
         inplace_safe: bool = False,
         _offload_inference: bool = False,
6 changes: 4 additions & 2 deletions src/models/components/structure_transition.py
@@ -13,7 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# TODO: I suspect that this module can easily be deleted. It is a simple feedforward nonlinear transition.
 
 import torch.nn as nn
 from src.models.components.primitives import Linear
@@ -45,7 +44,10 @@ def forward(self, s):
 
 
 class StructureTransition(nn.Module):
-    def __init__(self, c, num_layers, dropout_rate):
+    def __init__(self,
+                 c,
+                 num_layers: int = 1,
+                 dropout_rate: float = 0.1):
         super(StructureTransition, self).__init__()
 
         self.c = c
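With the new defaults, callers only need to supply the channel dimension. The snippet below is a hypothetical instantiation (the channel size 384 is arbitrary) showing that num_layers and dropout_rate now fall back to 1 and 0.1.

from src.models.components.structure_transition import StructureTransition

# Equivalent to StructureTransition(c=384, num_layers=1, dropout_rate=0.1).
transition = StructureTransition(c=384)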
7 changes: 5 additions & 2 deletions src/models/components/triangular_attention.py
@@ -148,8 +148,11 @@ def forward(self,
         return x
 
 
-# Implements Algorithm 13
-TriangleAttentionStartingNode = TriangleAttention
+class TriangleAttentionStartingNode(TriangleAttention):
+    """
+    Implements Algorithm 13.
+    """
+    __init__ = partialmethod(TriangleAttention.__init__, starting=True)
 
 
 class TriangleAttentionEndingNode(TriangleAttention):
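The switch from a bare alias to a subclass keeps the Algorithm 13 docstring attached to its own class while still reusing TriangleAttention.__init__ through functools.partialmethod. The toy classes below are not from the repository; they only show the same pattern in isolation.

from functools import partialmethod


class Attention:
    def __init__(self, dim: int, starting: bool):
        self.dim = dim
        self.starting = starting


class StartingNodeAttention(Attention):
    """Pre-binds starting=True, mirroring TriangleAttentionStartingNode."""
    __init__ = partialmethod(Attention.__init__, starting=True)


print(StartingNodeAttention(32).starting)  # True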
@@ -70,15 +70,6 @@ def __init__(
             dropout_rate: float = 0.25
     ):
         super().__init__()
-        """
-        self.blocks = nn.Sequential(
-            OrderedDict(
-                [(f'evoformer_pair_stack_block_{i}', EvoformerPairStackBlock(c_s=c_s, n_heads=n_heads,
-                                                                             c_hidden=c_hidden,
-                                                                             dropout_rate=dropout_rate))
-                 for i in range(n_blocks)]
-            )
-        )"""
         self.blocks = nn.ModuleList([EvoformerPairStackBlock(c_s=c_s,
                                                              n_heads=n_heads,
                                                              c_hidden=c_hidden,
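Dropping the commented-out nn.Sequential(OrderedDict(...)) construction in favour of nn.ModuleList means the blocks must be iterated explicitly in forward, which is the usual choice when each block takes more than one argument (for example a mask). The stand-in classes below are hypothetical and only illustrate that wiring.

import torch
from torch import nn


class ToyPairBlock(nn.Module):
    """Stand-in for EvoformerPairStackBlock: takes the pair tensor plus a mask."""
    def __init__(self, c: int):
        super().__init__()
        self.linear = nn.Linear(c, c)

    def forward(self, z: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        return z + self.linear(z) * mask.unsqueeze(-1)


class ToyPairStack(nn.Module):
    def __init__(self, c: int, n_blocks: int):
        super().__init__()
        self.blocks = nn.ModuleList([ToyPairBlock(c) for _ in range(n_blocks)])

    def forward(self, z: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # nn.Sequential cannot forward two arguments; ModuleList plus a loop can.
        for block in self.blocks:
            z = block(z, mask)
        return z


z = torch.randn(1, 16, 16, 32)
mask = torch.ones(1, 16, 16)
print(ToyPairStack(c=32, n_blocks=2)(z, mask).shape)  # torch.Size([1, 16, 16, 32])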
173 changes: 173 additions & 0 deletions src/models/structure_net.py
@@ -0,0 +1,173 @@
import torch
from torch import nn

from src.models.components.invariant_point_attention import InvariantPointAttention
from src.models.components.structure_transition import StructureTransition
from src.models.components.backbone_update import BackboneUpdate
from src.utils.rigid_utils import Rigids
import collections

# Define the output structure to avoid clutter
Structure = collections.namedtuple('Structure', ['single_rep', 'pair_rep', 'transforms', 'mask'])


class StructureLayer(nn.Module):

def __init__(self,
c_s,
c_z,
c_hidden_ipa,
n_head,
n_qk_point,
n_v_point,
ipa_dropout,
n_structure_transition_layer,
structure_transition_dropout
):
"""Initialize a Structure Layer.
:param c_s:
Single representation channel dimension
:param c_z:
Pair representation channel dimension
:param c_hidden_ipa:
Hidden IPA channel dimension
:param n_head:
Number of attention heads
:param n_qk_point:
Number of query/key points to generate
:param n_v_point:
Number of value points to generate
:param ipa_dropout:
IPA dropout rate
:param n_structure_transition_layer:
Number of structure transition layers
:param structure_transition_dropout:
structure transition dropout rate
"""
super(StructureLayer, self).__init__()

self.ipa = InvariantPointAttention(
c_s,
c_z,
c_hidden_ipa,
n_head,
n_qk_point,
n_v_point
)
self.ipa_dropout = nn.Dropout(ipa_dropout)
self.ipa_layer_norm = nn.LayerNorm(c_s)

# Built-in dropout and layer norm
self.transition = StructureTransition(
c_s,
n_structure_transition_layer,
structure_transition_dropout
)

# backbone update TODO: it might be useful to zero the gradients on rotations.
self.bb_update = BackboneUpdate(c_s)

def forward(self, inputs: Structure) -> Structure:
"""Updates a structure by explicitly attending the 3D frames."""
s, z, t, mask = inputs.single_rep, inputs.pair_rep, \
inputs.transforms, inputs.mask
s = s + self.ipa(s, z, t, mask)
s = self.ipa_dropout(s)
s = self.ipa_layer_norm(s)
s = self.transition(s)
t = t.compose(self.bb_update(s))
updated_structure = Structure(s, z, t, mask)
return updated_structure


class StructureNet(nn.Module):

def __init__(self,
c_s: int,
c_z: int,
n_structure_layer: int = 4,
n_structure_block: int = 1,
c_hidden_ipa: int = 16,
n_head_ipa: int = 12,
n_qk_point: int = 4,
n_v_point: int = 8,
ipa_dropout: float = 0.1,
n_structure_transition_layer: int = 1,
structure_transition_dropout: float = 0.1,
):
"""Initializes a structure network.
:param c_s:
Single representation channel dimension
:param c_z:
Pair representation channel dimension
:param n_structure_layer:
Number of structure layers
:param c_hidden_ipa:
Hidden IPA channel dimension (multiplied by the number of heads)
:param n_head_ipa:
Number of attention heads in the IPA
:param n_qk_point:
Number of query/key points to generate
:param n_v_point:
Number of value points to generate
:param ipa_dropout:
IPA dropout rate
:param n_structure_transition_layer:
Number of structure transition layers
:param structure_transition_dropout:
structure transition dropout rate
"""
super(StructureNet, self).__init__()

self.n_structure_block = n_structure_block

# Initial projection and layer norms
self.pair_rep_layer_norm = nn.LayerNorm(c_z)
self.single_rep_layer_norm = nn.LayerNorm(c_s)
self.single_rep_linear = nn.Linear(c_s, c_s)

layers = [
StructureLayer(
c_s, c_z,
c_hidden_ipa, n_head_ipa, n_qk_point, n_v_point, ipa_dropout,
n_structure_transition_layer, structure_transition_dropout
)
for _ in range(n_structure_layer)
]
self.net = nn.Sequential(*layers)

def forward(
self,
single_rep: torch.Tensor,
pair_rep: torch.Tensor,
transforms: Rigids,
mask: torch.Tensor = None
) -> Rigids:
"""Applies the structure module on the current transforms given single and pair representations.
:param single_rep:
[*, N_res, C_s] single representation
:param pair_rep:
[*, N_res, N_res, C_z] pair representation
:param transforms:
[*, N_res] transformation object
:param mask:
[*, N_res] mask
:returns
[*, N_res] updated transforms
"""

# Initial projection and layer norms
single_rep = self.single_rep_layer_norm(single_rep)
single_rep = self.single_rep_linear(single_rep)
pair_rep = self.pair_rep_layer_norm(pair_rep)

# Initial structure
structure = Structure(single_rep, pair_rep, transforms, mask)

# Updates with shared weights
for _ in range(self.n_structure_block):
structure = self.net(structure)

# Return updated transforms
return structure.transforms
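A minimal smoke test of the new module might look like the following. The tensor sizes are arbitrary, and the identity-frame construction assumes Rotations and Rigids accept batched unit quaternions and zero translations in the way the constructors in backbone_update.py suggest; it is a sketch, not the repository's test code.

import torch

from src.models.structure_net import StructureNet
from src.utils.rigid_utils import Rigids, Rotations

batch, n_res, c_s, c_z = 1, 64, 384, 128
single_rep = torch.randn(batch, n_res, c_s)
pair_rep = torch.randn(batch, n_res, n_res, c_z)
mask = torch.ones(batch, n_res)

# Identity frames: unit quaternion (1, 0, 0, 0) and zero translation per residue.
quats = torch.zeros(batch, n_res, 4)
quats[..., 0] = 1.0
transforms = Rigids(Rotations(quats=quats), torch.zeros(batch, n_res, 3))

model = StructureNet(c_s=c_s, c_z=c_z)
updated_transforms = model(single_rep, pair_rep, transforms, mask)  # [*, N_res] Rigids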
103 changes: 0 additions & 103 deletions src/utils/kernel/attention_core.py

This file was deleted.

11 changes: 0 additions & 11 deletions src/utils/kernel/csrc/compat.h

This file was deleted.

Diffs for the remaining changed files were not loaded.
