
0.2.0 Release

@KiddoZhu KiddoZhu released this 19 Sep 05:23
· 29 commits to master since this release

V0.2.0 is a major release with a new family member TorchProtein, a library for machine-learning-guided protein science. Aiming at simplifying the development of protein methods, TorchProtein encapsulates many complicated yet repetitive subroutines into functional modules, including widely-used datasets, flexible data processing operations, advanced encoding models, and diverse protein tasks.

Such comprehensive encapsulation lets users develop protein machine learning solutions with a single easy-to-use library, avoiding the awkwardness of gluing multiple libraries into a pipeline.

With TorchProtein, we can rapidly prototype machine learning solutions to various protein applications within 20 lines of code, and conduct ablation studies by substituting different parts of the solution with off-the-shelf modules. Furthermore, we can easily adapt these modules to our own needs, and make systematic analyses by comparing the new results to a benchmark provided in the library.

Additionally, TorchProtein is designed to be accessible to everyone. For inexperienced users, like beginners or biological researchers, TorchProtein provides user-friendly APIs to simplify the development of protein machine learning solutions. Meanwhile, for professional users, TorchProtein also preserves enough flexibility to satisfy their demands, supported by features like modular design of the library and on-the-fly graph construction.

Main Features

Simplify Data Processing

  • It is challenging to transform raw bioinformatic protein datasets into tensor formats for machine learning. To reduce tedious operations, TorchProtein provides us with a data structure data.Protein and its batched extension data.PackedProtein to automate the data processing step.

    • data.Protein and data.PackedProtein automatically gather protein data from various bio-sources and seamlessly switch between data formats like pdb files, RDKit objects and sequences. Please see the section data structures and operations for transforming from and to sequences and RDKit objects.

      from torchdrug import data

      # construct a data.Protein instance from a pdb file
      pdb_file = ...
      protein = data.Protein.from_pdb(pdb_file, atom_feature="position", bond_feature="length", residue_feature="symbol")
      print(protein)
      # write a data.Protein instance back to a pdb file
      new_pdb_file = ...
      protein.to_pdb(new_pdb_file)
      Protein(num_atom=445, num_bond=916, num_residue=57)
    • data.Protein and data.PackedProtein automatically pre-process all kinds of features of atoms, bonds and residues, by simply setting up several arguments.

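      pdb_file = ...
      protein = data.Protein.from_pdb(pdb_file, atom_feature="position", bond_feature="length", residue_feature="symbol")
      # print the shapes of the residue, atom and bond features
      print(protein.residue_feature.shape)
      print(protein.atom_feature.shape)
      print(protein.bond_feature.shape)
      torch.Size([57, 21])
      torch.Size([445, 3])
      torch.Size([916, 1])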
    • data.Protein and data.PackedProtein automatically keep track of numerous attributes associated with atoms, bonds, residues and the whole protein.

      • For example, reference offers a way to register new attributes as node, edge or graph properties, so that the new attributes automatically travel with the node, edge or graph itself. More built-in attributes are listed in the section data structures and operations.
      protein = ...
      with protein.node():
          protein.node_id = torch.tensor([i for i in range(0, protein.num_node)])
      with protein.edge():
          protein.edge_cost = torch.rand(protein.num_edge)
      with protein.graph():
          protein.graph_feature = torch.randn(128)
      • Furthermore, reference can be used to maintain the correspondence between two related objects. For example, the mapping atom2residue maintains the relationship between atoms and residues, and enables indexing on either of them.
      protein = ...
      # create mask indices for the atoms in glutamine (GLN) residues
      is_glutamine = protein.residue_type[protein.atom2residue] == protein.residue2id["GLN"]
      mask_indices = is_glutamine.nonzero().squeeze(-1)
      print(mask_indices)
      # map the masked atoms back to the glutamine residues
      residue_type = protein.residue_type[protein.atom2residue[mask_indices]]
      print([protein.id2residue[r] for r in residue_type.tolist()])
      tensor([ 26,  27,  28,  29,  30,  31,  32,  33,  34, 307, 308, 309, 310, 311,
              312, 313, 314, 315, 384, 385, 386, 387, 388, 389, 390, 391, 392])
      ['GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN']
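
The atom2residue lookup above is essentially an index gather. The following is a minimal pure-Python sketch of the same idea on a toy protein. The data and the helper `atoms_of_residue_type` are hypothetical and only mirror the attribute names used above; this is not the library's implementation.

```python
# Toy data (hypothetical): 3 residues and 7 atoms.
residue_type = ["ALA", "GLN", "GLY"]   # one entry per residue
atom2residue = [0, 0, 1, 1, 1, 2, 2]   # one entry per atom: index of the owning residue


def atoms_of_residue_type(residue_type, atom2residue, name):
    """Return the indices of atoms whose owning residue has the given type."""
    return [atom for atom, res in enumerate(atom2residue)
            if residue_type[res] == name]


mask_indices = atoms_of_residue_type(residue_type, atom2residue, "GLN")
# Map the selected atoms back to the type of the residue they belong to.
recovered = [residue_type[atom2residue[a]] for a in mask_indices]
print(mask_indices)  # [2, 3, 4]
print(recovered)     # ['GLN', 'GLN', 'GLN']
```

Because the mapping is stored per atom, any per-residue quantity can be broadcast to atoms with a single indexing step, which is what the tensor version above does.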
  • It is useful to augment protein data by modifying protein graphs or constructing new ones. With the protein operations and the graph construction layers provided in TorchProtein,

    • we can easily modify proteins on the fly by batching, slicing sequences, masking out side chains, etc. Please see the tutorials for more details on masking.

      pdb_file = ...
      protein = data.Protein.from_pdb(pdb_file, atom_feature="position", bond_feature="length", residue_feature="symbol")
      # batch
      proteins = data.Protein.pack([protein, protein, protein])
      # slice sequences
      # use indexing to extract specific residues (here residues 0 and 2) of a particular protein
      two_residues = protein[[0,2]]

      [Figure: two residues]

    • we can construct protein graphs on the fly with GPU acceleration, which gives users flexible choices instead of fixed pre-processed graphs. Below is an example that builds a graph with only alpha carbon atoms. Please check the tutorials for more cases, such as adding spatial / KNN / sequential edges.

      from torchdrug import data, layers
      from torchdrug.layers import geometry

      protein = ...
      # transfer from CPU to GPU
      protein = protein.cuda()
      print(protein)
      # build a graph with only alpha carbon (CA) atoms
      node_layers = [geometry.AlphaCarbonNode()]
      graph_construction_model = layers.GraphConstruction(node_layers=node_layers)
      # the graph construction model operates on a packed batch of proteins
      original_protein = data.Protein.pack([protein])
      CA_protein = graph_construction_model(original_protein)
      print("Graph before:", original_protein)
      print("Graph after:", CA_protein)
      Protein(num_atom=445, num_bond=916, num_residue=57, device='cuda:0')
      Graph before: PackedProtein(batch_size=1, num_atoms=[2639], num_bonds=[5368], num_residues=[350])
      Graph after: PackedProtein(batch_size=1, num_atoms=[350], num_bonds=[0], num_residues=[350])
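
The spatial and KNN edges mentioned above can be illustrated without the library. The sketch below is a plain-Python approximation under stated assumptions: `min_distance` is taken to mean a minimum separation along the sequence, edges are directed, and candidates are filtered before neighbours are selected. The functions `spatial_edges` and `knn_edges` are hypothetical helpers, and the library's layers may order or define these steps differently.

```python
import math


def spatial_edges(coords, radius, min_distance=0):
    """Directed edges (i, j) whose Euclidean distance is below radius,
    skipping pairs closer than min_distance in sequence index."""
    edges = []
    for i, a in enumerate(coords):
        for j, b in enumerate(coords):
            if i == j or abs(i - j) < min_distance:
                continue
            if math.dist(a, b) < radius:
                edges.append((i, j))
    return edges


def knn_edges(coords, k, min_distance=0):
    """Directed edges from each node i to its k nearest neighbours j,
    again skipping pairs that are too close in sequence index."""
    edges = []
    for i, a in enumerate(coords):
        candidates = [(math.dist(a, b), j) for j, b in enumerate(coords)
                      if j != i and abs(i - j) >= min_distance]
        for _, j in sorted(candidates)[:k]:
            edges.append((i, j))
    return edges


# Four hypothetical alpha-carbon positions on a line, 4 units apart.
ca_coords = [(0.0, 0.0, 0.0), (4.0, 0.0, 0.0), (8.0, 0.0, 0.0), (12.0, 0.0, 0.0)]
print(spatial_edges(ca_coords, radius=10.0, min_distance=2))
print(knn_edges(ca_coords, k=1))
```

The quadratic loops are only for readability; a practical implementation would use batched tensor operations on the GPU, as the layers above do.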

Easy to Prototype Solutions

With TorchProtein, common protein tasks, such as sequence-based protein property prediction, can be completed within 20 lines of code. An example is shown below; more examples of popular protein tasks and models can be found in Protein Tasks, Models and Tutorials.

import torch
from torchdrug import datasets, transforms, models, tasks, core

truncate_transform = transforms.TruncateProtein(max_length=200, random=False)
protein_view_transform = transforms.ProteinView(view="residue")
transform = transforms.Compose([truncate_transform, protein_view_transform])

dataset = datasets.BetaLactamase("~/protein-datasets/", residue_only=True, transform=transform)
train_set, valid_set, test_set = dataset.split()

model = models.ProteinCNN(input_dim=21,
                          hidden_dims=[1024, 1024],
                          kernel_size=5, padding=2, readout="max")

task = tasks.PropertyPrediction(model, task=dataset.tasks,
                                criterion="mse", metric=("mae", "rmse", "spearmanr"),
                                normalization=False, num_mlp_layer=2)

optimizer = torch.optim.Adam(task.parameters(), lr=1e-4)
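solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     gpus=[0], batch_size=64)
# train, then report metrics on the validation set (the epoch count here is illustrative)
solver.train(num_epoch=10)
solver.evaluate("valid")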
mean absolute error [scaled_effect1]: 0.249482
root mean squared error [scaled_effect1]: 0.304326
spearmanr [scaled_effect1]: 0.44572
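
For readers unfamiliar with the three metrics reported above, here is a plain-Python sketch of their standard definitions. It is an illustration with made-up numbers, not the library's implementation; the helper `ranks` is hypothetical and assigns average ranks to ties.

```python
import math


def mae(pred, target):
    """Mean absolute error."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def rmse(pred, target):
    """Root mean squared error."""
    return math.sqrt(sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred))


def ranks(values):
    """1-based ranks; tied values share the average of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    result = [0.0] * len(values)
    start = 0
    while start < len(order):
        end = start
        while end + 1 < len(order) and values[order[end + 1]] == values[order[start]]:
            end += 1
        average = (start + end) / 2 + 1
        for pos in range(start, end + 1):
            result[order[pos]] = average
        start = end + 1
    return result


def spearmanr(pred, target):
    """Spearman correlation: Pearson correlation computed on ranks.
    Undefined (division by zero) if either input is constant."""
    rp, rt = ranks(pred), ranks(target)
    mp, mt = sum(rp) / len(rp), sum(rt) / len(rt)
    cov = sum((a - mp) * (b - mt) for a, b in zip(rp, rt))
    var_p = sum((a - mp) ** 2 for a in rp)
    var_t = sum((b - mt) ** 2 for b in rt)
    return cov / math.sqrt(var_p * var_t)


# Made-up predictions and targets for illustration.
pred = [0.1, 0.4, 0.35, 0.8]
target = [0.0, 0.5, 0.3, 1.0]
print(mae(pred, target), rmse(pred, target), spearmanr(pred, target))
```

Spearman correlation depends only on the ordering of predictions, which is why it can be high even when MAE and RMSE are not small.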

Compatible with Existing Molecular Models in TorchDrug

  • TorchProtein follows the scientific fact that proteins are macromolecules. The core data structures data.Protein and data.PackedProtein inherit from data.Molecule and data.PackedMolecule respectively. Therefore, we can apply any existing molecule model in TorchDrug to proteins.

    import torch
    from torchdrug import layers, datasets, transforms, models, tasks, core
    from torchdrug.layers import geometry
    truncate_transform = transforms.TruncateProtein(max_length=200, random=False)
    protein_view_transform = transforms.ProteinView(view="residue")
    transform = transforms.Compose([truncate_transform, protein_view_transform])
    dataset = datasets.EnzymeCommission("~/protein-datasets/", transform=transform)
    train_set, valid_set, test_set = dataset.split()
    model = models.GIN(input_dim=21,
                        hidden_dims=[256, 256, 256, 256],
                        batch_norm=True, short_cut=True, concat_hidden=True)
    graph_construction_model = layers.GraphConstruction(
        edge_layers=[geometry.SpatialEdge(radius=10.0, min_distance=5),
                     geometry.KNNEdge(k=10, min_distance=5)])
    task = tasks.MultipleBinaryClassification(model, graph_construction_model=graph_construction_model, num_mlp_layer=3,
                                              task=list(range(len(dataset.tasks))), criterion="bce",
                                              metric=("auprc@micro", "f1_max"))
    optimizer = torch.optim.Adam(task.parameters(), lr=1e-4)
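    solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                         gpus=[0], batch_size=4)
    # train, then report metrics on the validation set (the epoch count here is illustrative)
    solver.train(num_epoch=10)
    solver.evaluate("valid")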
    auprc@micro: 0.187884
    f1_max: 0.231008
  • In the protein-ligand interaction (PLI) prediction task, we can use a molecular encoder module to extract the representations of molecules. Please check tutorial 2 for more details.

    train_set, valid_set, test_set = ...
    # protein encoder
    model = models.ProteinCNN(input_dim=21,
                              hidden_dims=[1024, 1024],
                              kernel_size=5, padding=2, readout="max")
    # molecule encoder
    model2 = models.GIN(input_dim=66,
                        hidden_dims=[256, 256, 256, 256],
                        batch_norm=True, short_cut=True, concat_hidden=True)
    task = tasks.InteractionPrediction(model, model2=model2, task=dataset.tasks,
                                       criterion="mse", metric=("mae", "rmse", "spearmanr"),
                                       normalization=False, num_mlp_layer=2)
    optimizer = torch.optim.Adam(task.parameters(), lr=1e-4)
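    solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                         gpus=[0], batch_size=16)
    # train, then report metrics on the validation set (the epoch count here is illustrative)
    solver.train(num_epoch=10)
    solver.evaluate("valid")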
    mean absolute error [scaled_effect1]: 0.249482
    root mean squared error [scaled_effect1]: 0.304326
    spearmanr [scaled_effect1]: 0.44572

Support From the Developer (@DeepGraphLearning/torchdrug-maintainers)

An active support team is always available to answer questions and provide help. Feedback on user experience and contributions to development are welcome.

New Modules

Data Structures and Operations


data.Protein

  • Representative attributes:
    • data.Protein.edge_list: list of edges and each edge is represented by a tuple (node_in, node_out, bond_type)
    • data.Protein.atom_type: atom types
    • data.Protein.bond_type: bond types
    • data.Protein.residue_type: residue types
    • data.Protein.view: default view for this protein. Can be “atom” or “residue”
    • data.Protein.atom_name: atom names in each residue
    • data.Protein.atom2residue: atom id to residue id mapping
    • data.Protein.is_hetero_atom: hetero atom indicator
    • data.Protein.occupancy: protein occupancy
    • data.Protein.b_factor: temperature factors
    • data.Protein.residue_number: residue numbers
    • data.Protein.insertion_code: insertion codes
    • data.Protein.chain_id: chain ids
  • Representative methods:
    • data.Protein.from_molecule: create a protein from an RDKit object.
    • data.Protein.from_sequence: create a protein from a sequence.
    • data.Protein.from_sequence_fast: a faster version of creating a protein from a sequence.
    • data.Protein.from_pdb: create a protein from a PDB file.
    • data.Protein.to_molecule: return an RDKit object of this protein.
    • data.Protein.to_sequence: return a sequence of this protein.
    • data.Protein.to_pdb: write this protein to a pdb file.
    • data.Protein.split: split this protein graph into multiple disconnected protein graphs.
    • data.Protein.pack: batch a list of data.Protein into data.PackedProtein.
    • data.Protein.repeat: repeat this protein.
    • data.Protein.residue2atom: map residue id to atom ids.
    • data.Protein.residue_mask: return a masked protein based on the specified residues.
    • data.Protein.subresidue: return a subgraph based on the specified residues.
    • data.Protein.residue2graph: residue id to protein id mapping.
    • data.Protein.node_mask: return a masked protein based on the specified nodes.
    • data.Protein.edge_mask: return a masked protein based on the specified edges.
    • data.Protein.compact: remove isolated nodes and compact node ids.


data.PackedProtein

  • Representative attributes:
    • data.PackedProtein.edge_list: list of edges and each edge is represented by a tuple (node_in, node_out, bond_type)
    • data.PackedProtein.atom_type: atom types
    • data.PackedProtein.bond_type: bond types
    • data.PackedProtein.residue_type: residue types
    • data.PackedProtein.view: default view for this protein. Can be “atom” or “residue”
    • data.PackedProtein.num_nodes: number of nodes in each protein graph
    • data.PackedProtein.num_edges: number of edges in each protein graph
    • data.PackedProtein.num_residues: number of residues in each protein graph
    • data.PackedProtein.offsets: node id offsets in different proteins
  • Representative methods:
    • data.PackedProtein.node_mask: return a masked packed protein based on the specified nodes.
    • data.PackedProtein.edge_mask: return a masked packed protein based on the specified edges.
    • data.PackedProtein.residue_mask: return a masked packed protein based on the specified residues.
    • data.PackedProtein.graph_mask: return a masked packed protein based on the specified protein graphs.
    • data.PackedProtein.from_molecule: create a protein from a list of RDKit objects.
    • data.PackedProtein.from_sequence: create a protein from a list of sequences.
    • data.PackedProtein.from_sequence_fast: a faster version of creating a protein from a list of sequences.
    • data.PackedProtein.from_pdb: create a protein from a list of PDB files.
    • data.PackedProtein.to_molecule: return a list of RDKit objects of this packed protein.
    • data.PackedProtein.to_sequence: return a list of sequences of this packed protein.
    • data.PackedProtein.to_pdb: write this packed protein to a list of pdb files.
    • data.PackedProtein.merge: merge multiple packed proteins into a single packed protein.
    • data.PackedProtein.repeat: repeat this packed protein.
    • data.PackedProtein.repeat_interleave: repeat this packed protein, behaving similarly to torch.repeat_interleave.
    • data.PackedProtein.residue2graph: residue id to graph id mapping.
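
To make the roles of num_nodes, offsets and the id-to-graph mappings concrete, here is a minimal pure-Python sketch of packing several graphs into one. The function `pack` and its tuple-based graph format are hypothetical; the sketch keeps one offset per graph, and the library may store offsets in a different layout.

```python
from itertools import accumulate


def pack(graphs):
    """Concatenate graphs given as (num_node, edge_list) with local node ids."""
    num_nodes = [n for n, _ in graphs]
    # offsets[g] = number of nodes in all graphs that come before graph g
    offsets = list(accumulate([0] + num_nodes[:-1]))
    # shift every local node id by the offset of its graph
    edge_list = [(u + off, v + off)
                 for (_, edges), off in zip(graphs, offsets)
                 for u, v in edges]
    # node2graph[i] = index of the graph that packed node i belongs to
    node2graph = [g for g, n in enumerate(num_nodes) for _ in range(n)]
    return num_nodes, offsets, edge_list, node2graph


graphs = [(3, [(0, 1), (1, 2)]), (2, [(0, 1)])]
num_nodes, offsets, edge_list, node2graph = pack(graphs)
print(offsets)     # [0, 3]
print(edge_list)   # [(0, 1), (1, 2), (3, 4)]
print(node2graph)  # [0, 0, 0, 1, 1]
```

Storing a batch as one large disconnected graph lets a model process all proteins in a single pass, while mappings such as residue2graph recover which protein each element came from.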


Models

  • GearNet: Geometry Aware Relational Graph Neural Network.
  • ESM: Evolutionary Scale Modeling (ESM).
  • ProteinCNN: protein shallow CNN.
  • ProteinResNet: protein ResNet.
  • ProteinLSTM: protein LSTM.
  • ProteinBERT: protein BERT.
  • Statistic: the statistic feature engineering for protein sequence.
  • Physicochemical: the physicochemical feature engineering for protein sequence.

Protein Tasks

Sequence-based Protein Property Prediction:

  • tasks.PropertyPrediction predicts some property of each protein, such as Beta-lactamase activity, stability and solubility for proteins.
  • tasks.NodePropertyPrediction predicts some property of each residue in proteins, such as the secondary structure (coil, strand or helix) of each residue.
  • tasks.ContactPrediction predicts whether each pair of residues is in contact in the folded structure.
  • tasks.InteractionPrediction predicts the binding affinity of two interacting proteins or of a protein and a ligand, i.e. performing PPI affinity prediction or PLI affinity prediction.

Structure-based Protein Property Prediction:

  • tasks.MultipleBinaryClassification predicts, with one binary label per function, whether a protein has each of several specific functions.

Pre-trained Protein Structure Representations:

  • Self-Supervised Protein Structure Pre-training: acquires informative protein representations from massive unlabeled protein structures, using tasks such as tasks.EdgePrediction, tasks.AttributeMasking, tasks.ContextPrediction, tasks.DistancePrediction, tasks.AnglePrediction and tasks.DihedralPrediction.
  • Fine-tuning on Downstream Task: fine-tunes the pre-trained protein encoder on downstream tasks, such as any property prediction task mentioned above.

Protein Datasets

Protein Property Prediction Datasets

  • BetaLactamase: protein sequences with activity labels
  • Fluorescence: protein sequences with fitness labels
  • Stability: protein sequences with stability labels
  • Solubility: protein sequences with solubility labels
  • BinaryLocalization: protein sequences with membrane-bound or soluble labels
  • SubcellularLocalization: protein sequences with natural cell location labels
  • EnzymeCommission: protein sequences and 3D structures with EC number labels for catalysis in biochemical reactions
  • GeneOntology: protein sequences and 3D structures with GO term labels, including molecular function (MF), biological process (BP) and cellular component (CC)
  • AlphaFoldDB: protein sequences and 3D structures predicted by AlphaFold

Protein Structure Prediction Datasets

  • Fold: protein sequences and 3D structures with fold labels determined by the global structural topology
  • SecondaryStructure: protein sequences and 3D structures with secondary structure labels determined by the local structures
  • ProteinNet: protein sequences and 3D structures for the contact prediction task

Protein-Protein Interaction Prediction Datasets

  • HumanPPI: protein sequences with binary interaction labels for human proteins
  • YeastPPI: protein sequences with binary interaction labels for yeast proteins
  • PPIAffinity: protein sequences with binding affinity values measured by $pK_d$

Protein-Ligand Interaction Prediction Datasets

  • BindingDB: protein sequences and molecule graphs with binding affinity between pairs of protein and ligand
  • PDBBind: protein sequences and molecule graphs with binding affinity between pairs of protein and ligand

Data Transform Modules

  • TruncateProtein: truncate overly long protein sequences to a fixed maximum length
  • ProteinView: convert proteins to a specific view
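
The truncation transform can be pictured as taking a window over the residue sequence. Below is a minimal sketch on plain strings, assuming that non-random truncation keeps the prefix and random truncation picks a random window; `truncate_sequence` is a hypothetical helper and the library's transform operates on protein objects rather than strings.

```python
import random


def truncate_sequence(sequence, max_length, use_random=False, rng=random):
    """Keep at most max_length residues: a random window if use_random, else the prefix."""
    if len(sequence) <= max_length:
        return sequence
    start = rng.randint(0, len(sequence) - max_length) if use_random else 0
    return sequence[start:start + max_length]


seq = "MKTAYIAKQRQISFVKSHFSRQ"  # made-up sequence for illustration
print(truncate_sequence(seq, 10))  # MKTAYIAKQR
# a seeded generator makes the random window reproducible
window = truncate_sequence(seq, 10, use_random=True, rng=random.Random(0))
print(window)
```

Random windows act as a light form of data augmentation during training, while deterministic truncation keeps evaluation reproducible.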

Graph Construction Layers

  • SubsequenceNode: take a protein subsequence of a specific length
  • SubspaceNode: extract a subgraph by only keeping neighboring nodes in a spatial ball for each centered node
  • RandomEdgeMask: mask out some edges randomly from the protein graph


Tutorials

To help users gain a comprehensive understanding of TorchProtein, we recommend the user-friendly tutorials on its basic usage and on examples of various protein-related tasks. These tutorials may also serve as boilerplate code for users developing their own applications.

Bug Fixes

  • Fix an error in the decorator @utils.cached (#118)
  • Fix an index error in data.Graph.split() (#115)
  • Fix setting attribute node_feature, edge_feature and graph_feature (#116)
  • Fix incorrect node feature shape for the synthon dataset USPTO50k (#116)
  • Fix a compatibility issue when adding node/edge/graph reference and changing node/edge to atom/bond (#116, #117)