Skip to content

Commit

Permalink
delete proteus
Browse files Browse the repository at this point in the history
  • Loading branch information
ardagoreci committed Sep 3, 2024
1 parent c2e6cbf commit b78a889
Show file tree
Hide file tree
Showing 4 changed files with 0 additions and 578 deletions.
81 changes: 0 additions & 81 deletions src/models/embedders.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,84 +386,3 @@ def forward(
u = torch.div(u, n_templ) # average
u = self.output_proj(u)
return u


class ProteusFeatureEmbedder(nn.Module):
    """Convenience feature embedder for the Proteus experiment.

    Wraps an ``InputFeatureEmbedder`` and derives the trunk single
    representation (``s_trunk``) and trunk pair representation
    (``z_trunk``) from the resulting per-token embedding.
    """

    def __init__(
        self,
        c_token: int = 384,
        c_atom: int = 128,
        c_atompair: int = 16,
        c_trunk_pair: int = 16,
        no_blocks: int = 3,
        no_heads: int = 4,
        dropout: float = 0.0,
        n_queries: int = 32,
        n_keys: int = 128,
    ):
        """
        Args:
            c_token: channel width of per-token embeddings.
            c_atom: channel width of per-atom embeddings.
            c_atompair: channel width of atom-pair embeddings.
            c_trunk_pair: channel width of the trunk pair representation.
            no_blocks: number of blocks in the input feature embedder.
            no_heads: number of attention heads.
            dropout: dropout probability.
            n_queries: local-attention query window size.
            n_keys: local-attention key window size.
        """
        super().__init__()
        # Keep the hyperparameters on the instance for introspection.
        self.c_token = c_token
        self.c_atom = c_atom
        self.c_atompair = c_atompair
        self.c_trunk_pair = c_trunk_pair
        self.no_blocks = no_blocks
        self.no_heads = no_heads
        self.dropout = dropout
        self.n_queries = n_queries
        self.n_keys = n_keys

        # Atom-level features -> per-token embedding.
        self.input_feature_embedder = InputFeatureEmbedder(
            no_blocks=no_blocks,
            c_token=c_token,
            c_atom=c_atom,
            c_atompair=c_atompair,
            c_trunk_pair=c_trunk_pair,
            no_heads=no_heads,
            dropout=dropout,
            n_queries=n_queries,
            n_keys=n_keys
        )
        # Projections into the trunk single and pair representations.
        self.linear_s_init = LinearNoBias(c_token, c_token)
        self.linear_z_col = LinearNoBias(c_token, c_trunk_pair)
        self.linear_z_row = LinearNoBias(c_token, c_trunk_pair)
        self.relative_pos_encoder = RelativePositionEncoding(c_trunk_pair)

    def forward(
        self,
        features: Dict[str, torch.Tensor],
        atom_mask: torch.Tensor = None,
        token_mask: torch.Tensor = None,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Embed input features and build the trunk representations.

        Args:
            features:
                Dictionary containing the input features
            atom_mask:
                [*, N_atoms] mask indicating which atoms are valid (non-padding).
            token_mask:
                [*, N_tokens] mask indicating which tokens are valid (non-padding).
        Returns:
            Tuple of (per-token embedding, s_trunk, z_trunk); the first is the
            [*, N_tokens, c_token] embedding of the input features.
        """
        # Token count comes from the trailing dim of the token index feature.
        n_tokens = features["token_index"].shape[-1]

        # Embed atom-level features into a per-token representation.
        token_repr = self.input_feature_embedder(
            features=features,
            n_tokens=n_tokens,
            mask=atom_mask,
        )
        # f_restype, f_profile, and f_deletion_mean do not exist for design.

        # Trunk single representation.
        s_trunk = self.linear_s_init(token_repr)

        # Trunk pair representation: outer sum of the column- and
        # row-projected token embeddings, plus relative position encoding.
        col_term = self.linear_z_col(token_repr[:, :, None, :])
        row_term = self.linear_z_row(token_repr[:, None, :, :])
        z_trunk = col_term + row_term + self.relative_pos_encoder(features, token_mask)

        return token_repr, s_trunk, z_trunk
313 changes: 0 additions & 313 deletions src/models/proteus_module.py

This file was deleted.

Loading

0 comments on commit b78a889

Please sign in to comment.