Skip to content

Commit

Permalink
Added stubs
Browse files Browse the repository at this point in the history
  • Loading branch information
ardagoreci committed Jul 18, 2024
1 parent bbf1ede commit 0a4f74e
Show file tree
Hide file tree
Showing 2 changed files with 248 additions and 49 deletions.
133 changes: 123 additions & 10 deletions src/common/biomolecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
import io
from typing import Any, Dict, List, Mapping, Optional, Tuple
from src.common import residue_constants
from Bio.PDB import MMCIFParser
from Bio.PDB import PDBParser
from Bio.PDB import MMCIFParser, PDBParser, Residue
from Bio.PDB.mmcifio import MMCIFIO
from Bio.PDB.Structure import Structure
import numpy as np
Expand All @@ -24,16 +23,22 @@
# Upper bound on the number of distinct chains: one per entry in PDB_CHAIN_IDS.
# Biomolecule.__post_init__ compares the number of unique chain indices to it.
PDB_MAX_CHAINS = len(PDB_CHAIN_IDS)  # := 62.


# Lightweight record bundling the arrays that tokenize_residue produces for a
# single residue (positions, mask, B-factors, and the type / residue / chain
# labels).
TokenizedResidue = collections.namedtuple(
    'TokenizedResidue',
    (
        'pos',
        'mask',
        'b_factors',
        'restype',
        'res_idx',
        'chain_idx',
    ),
)


@dataclasses.dataclass(frozen=True)
class Biomolecule:
# Cartesian coordinates of atoms in angstroms. The atom ordering
# corresponds to the order in residue_constants.atom_types for proteins.
# (coming soon: support for RNA and DNA)
atom_positions: np.ndarray # [num_token, num_max_token_atoms, 3]

# Amino-acid type for each residue represented as an integer between 0 and
# 20, where 20 is 'X'. Ligands are encoded the same as 'X' for unknown residue.
aatype: np.ndarray # [num_token]
# Residue type represented as an integer between 0 and 20, where 20 is 'X'.
# Ligands are encoded the same as 'X' for unknown residue.
restype: np.ndarray # [num_token]

# Binary float mask to indicate presence of a particular atom. 1.0 if an atom
# is present and 0.0 if not. This should be used for loss masking.
Expand All @@ -47,11 +52,11 @@ class Biomolecule:
chain_index: np.ndarray # [num_token]

# Unique integer for each distinct sequence.
entity_id: np.ndarray # [num_token]
# entity_id: np.ndarray # [num_token]

# Unique integer within chains of this sequence. e.g. if chains A, B, C share a sequence but
# D does not, their sym_ids would be [0, 1, 2, 0].
sym_id: np.ndarray # [num_token]
# sym_id: np.ndarray # [num_token]

# B-factors, or temperature factors, of each residue (in sq. angstroms units),
# representing the displacement of the residue from its ground truth mean
Expand All @@ -65,7 +70,7 @@ class Biomolecule:
# Chemical ID of each amino-acid, nucleotide, or ligand residue represented
# as a string. This is primarily used to record a ligand residue's name
# (e.g., when exporting an mmCIF file from a Biomolecule object).
chemid: np.ndarray # [num_res]
# chemid: np.ndarray # [num_res]

def __post_init__(self):
if len(np.unique(self.chain_index)) > PDB_MAX_CHAINS:
Expand Down Expand Up @@ -93,8 +98,116 @@ def _from_bio_structure(
Raises:
ValueError: If the number of models included in the structure is not 1.
ValueError: If insertion code is detected at a residue.
"""
pass
"""
models = list(structure.get_models())
if len(models) != 1:
raise ValueError(
'Only single model PDBs/mmCIFs are supported. Found'
f' {len(models)} models.'
)
model = models[0]

atom_positions = []
restype = []
atom_mask = []
residue_index = []
chain_ids = []
b_factors = []

for chain in model:
if chain_id is not None and chain.id != chain_id:
continue
for res in chain:
if res.id[2] != ' ':
raise ValueError(
f'PDB/mmCIF contains an insertion code at chain {chain.id} and'
f' residue index {res.id[1]}. These are not supported.'
)
res_shortname = residue_constants.restype_3to1.get(res.resname, 'X')
tokenized_res = tokenize_residue(res, chain_idx=chain.id)
if np.sum(tokenized_res.mask) < 0.5:
# If no known atom positions are reported for the residue then skip it.
continue
restype.append(tokenized_res.res_idx)
atom_positions.append(tokenized_res.pos)
atom_mask.append(tokenized_res.mask)
residue_index.append(res.id[1])
chain_ids.append(chain.id)
b_factors.append(tokenized_res.b_factors)

# Chain IDs are usually characters so map these to ints.
unique_chain_ids = np.unique(chain_ids)
chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}
chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])
# TODO: complete this
return Biomolecule(
atom_positions=np.stack(atom_positions, dim=0),
restype=np.stack(restype, dim=0),
atom_mask=np.stack(atom_mask, dim=0),
residue_index=np.stack(residue_index, dim=0),
chain_index=np.stack(),
b_factors=np.array(b_factors),
)


def tokenize_residue(
        res: Residue,
        chain_idx: int,
) -> TokenizedResidue:
    """Tokenize one residue into atom positions, atom masks, and B-factors.

    Based on the AlphaFold3 tokenization scheme:
      • A standard amino acid residue is represented as a single token.
      • A standard nucleotide residue is represented as a single token.
      • A modified amino acid or nucleotide residue is tokenized per-atom
        (i.e. N tokens for an N-atom residue).
      • All ligands are tokenized per-atom.
    Coming soon: support for RNA/DNA.

    Args:
        res: The residue to tokenize. Its `resname`, `id[1]` and atoms (via
            `get_atoms()`) are read.
        chain_idx: Chain label to attach to every token of this residue. The
            value is broadcast as-is, so whatever the caller passes (an int
            index or a chain ID string) is what ends up in the output array.

    Returns:
        A TokenizedResidue whose arrays all share the leading dimension
        `num_tokens` (1 for a standard residue, one per atom otherwise):
          pos:       [num_tokens, atom_type_num, 3]
          mask:      [num_tokens, atom_type_num]
          b_factors: [num_tokens, atom_type_num]
          restype:   [num_tokens]
          res_idx:   [num_tokens]
          chain_idx: [num_tokens]
    """
    res_shortname = residue_constants.restype_3to1.get(res.resname, 'X')
    restype_idx = residue_constants.restype_order.get(
        res_shortname, residue_constants.restype_num)
    atoms = list(res.get_atoms())
    num_slots = residue_constants.atom_type_num

    # Compare by value: `is not 'X'` tests object identity, which is not a
    # reliable string comparison (and triggers a SyntaxWarning).
    is_protein_res = res_shortname != 'X'
    # A protein residue counts as modified when it carries an atom that has no
    # slot in the per-residue atom table. The check uses the same mapping that
    # is indexed below, so the standard branch can never hit a KeyError.
    has_modification = is_protein_res and any(
        atom.name not in residue_constants.atom_order for atom in atoms)

    if is_protein_res and not has_modification:
        # Standard amino acid residue: a single token, atom37 representation.
        num_tokens = 1
        pos = np.zeros((num_tokens, num_slots, 3))
        mask = np.zeros((num_tokens, num_slots))
        b_factors = np.zeros((num_tokens, num_slots))
        for atom in atoms:
            slot = residue_constants.atom_order[atom.name]
            pos[0, slot] = atom.coord
            mask[0, slot] = 1.0
            b_factors[0, slot] = atom.bfactor
    else:
        # Modified amino acid or ligand: one token per atom, with the atom
        # stored in slot 0 of its own token. Modified residues use the same
        # layout as ligands because their extra atoms have no named slot.
        # NOTE(review): any atom name missing from the atom table (e.g.
        # explicit hydrogens, if those are absent from it) makes the whole
        # residue per-atom -- confirm this is the intended policy.
        num_tokens = len(atoms)
        pos = np.zeros((num_tokens, num_slots, 3))
        mask = np.zeros((num_tokens, num_slots))
        b_factors = np.zeros((num_tokens, num_slots))
        for i, atom in enumerate(atoms):
            pos[i, 0] = atom.coord
            mask[i, 0] = 1.0
            b_factors[i, 0] = atom.bfactor

    # Per-token labels: sized by the number of tokens (not the number of
    # atoms) so they always line up with the leading axis of pos / mask.
    restype = np.full((num_tokens,), fill_value=restype_idx)
    res_idx = np.full((num_tokens,), fill_value=res.id[1])
    # np.full infers the dtype from the fill value, so int and str chain
    # labels both work (np.full_like(mask, ...) forced a float dtype).
    chain_idx_per_token = np.full((num_tokens,), fill_value=chain_idx)
    return TokenizedResidue(pos=pos,
                            mask=mask,
                            restype=restype,
                            b_factors=b_factors,
                            res_idx=res_idx,
                            chain_idx=chain_idx_per_token)


def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Biomolecule:
Expand Down
Loading

0 comments on commit 0a4f74e

Please sign in to comment.