diff --git a/src/common/biomolecule.py b/src/common/biomolecule.py index a251082..36be0a0 100644 --- a/src/common/biomolecule.py +++ b/src/common/biomolecule.py @@ -10,8 +10,7 @@ import io from typing import Any, Dict, List, Mapping, Optional, Tuple from src.common import residue_constants -from Bio.PDB import MMCIFParser -from Bio.PDB import PDBParser +from Bio.PDB import MMCIFParser, PDBParser, Residue from Bio.PDB.mmcifio import MMCIFIO from Bio.PDB.Structure import Structure import numpy as np @@ -24,6 +23,12 @@ PDB_MAX_CHAINS = len(PDB_CHAIN_IDS) # := 62. +# Convenience class for a tokenized residue. +TokenizedResidue = collections.namedtuple( + 'TokenizedResidue', ['pos', 'mask', 'b_factors', 'restype', 'res_idx', 'chain_idx'] +) + + @dataclasses.dataclass(frozen=True) class Biomolecule: # Cartesian coordinates of atoms in angstroms. The atom ordering @@ -31,9 +36,9 @@ class Biomolecule: # (coming soon: support for RNA and DNA) atom_positions: np.ndarray # [num_token, num_max_token_atoms, 3] - # Amino-acid type for each residue represented as an integer between 0 and - # 20, where 20 is 'X'. Ligands are encoded the same as 'X' for unknown residue. - aatype: np.ndarray # [num_token] + # Residue type represented as an integer between 0 and 20, where 20 is 'X'. + # Ligands are encoded the same as 'X' for unknown residue. + restype: np.ndarray # [num_token] # Binary float mask to indicate presence of a particular atom. 1.0 if an atom # is present and 0.0 if not. This should be used for loss masking. @@ -47,11 +52,11 @@ class Biomolecule: chain_index: np.ndarray # [num_token] # Unique integer for each distinct sequence. - entity_id: np.ndarray # [num_token] + # entity_id: np.ndarray # [num_token] # Unique integer within chains of this sequence. e.g. if chains A, B, C share a sequence but # D does not, their sym_ids would be [0, 1, 2, 0]. 
- sym_id: np.ndarray # [num_token] + # sym_id: np.ndarray # [num_token] # B-factors, or temperature factors, of each residue (in sq. angstroms units), # representing the displacement of the residue from its ground truth mean @@ -65,7 +70,7 @@ class Biomolecule: # Chemical ID of each amino-acid, nucleotide, or ligand residue represented # as a string. This is primarily used to record a ligand residue's name # (e.g., when exporting an mmCIF file from a Biomolecule object). - chemid: np.ndarray # [num_res] + # chemid: np.ndarray # [num_res] def __post_init__(self): if len(np.unique(self.chain_index)) > PDB_MAX_CHAINS: @@ -93,8 +98,116 @@ def _from_bio_structure( Raises: ValueError: If the number of models included in the structure is not 1. ValueError: If insertion code is detected at a residue. - """ - pass + """ + models = list(structure.get_models()) + if len(models) != 1: + raise ValueError( + 'Only single model PDBs/mmCIFs are supported. Found' + f' {len(models)} models.' + ) + model = models[0] + + atom_positions = [] + restype = [] + atom_mask = [] + residue_index = [] + chain_ids = [] + b_factors = [] + + for chain in model: + if chain_id is not None and chain.id != chain_id: + continue + for res in chain: + if res.id[2] != ' ': + raise ValueError( + f'PDB/mmCIF contains an insertion code at chain {chain.id} and' + f' residue index {res.id[1]}. These are not supported.' + ) + # tokenize_residue yields one token for a standard residue and one token per atom otherwise. + tokenized_res = tokenize_residue(res, chain_idx=chain.id) + if np.sum(tokenized_res.mask) < 0.5: + # If no known atom positions are reported for the residue then skip it. + continue + restype.append(tokenized_res.restype) + atom_positions.append(tokenized_res.pos) + atom_mask.append(tokenized_res.mask) + residue_index.append(tokenized_res.res_idx) + chain_ids.extend([chain.id] * len(tokenized_res.mask)) + b_factors.append(tokenized_res.b_factors) + + # Chain IDs are usually characters so map these to ints. 
+ unique_chain_ids = np.unique(chain_ids) + chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)} + chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids]) + # TODO: complete this. Tokens per residue vary, so concatenate along the token axis. + return Biomolecule( + atom_positions=np.concatenate(atom_positions, axis=0), + restype=np.concatenate(restype, axis=0), + atom_mask=np.concatenate(atom_mask, axis=0), + residue_index=np.concatenate(residue_index, axis=0), + chain_index=chain_index, + b_factors=np.concatenate(b_factors, axis=0), + ) + + +def tokenize_residue( + res: Residue, + chain_idx: int, +) -> TokenizedResidue: + """ + Tokenizes a residue into atom positions, atom masks, and B-factors. Based on the AlphaFold3 + tokenization scheme. + • A standard amino acid residue is represented as a single token. + • A standard nucleotide residue is represented as a single token. + • A modified amino acid or nucleotide residue is tokenized per-atom (i.e. N tokens for an N-atom residue) + • All ligands are tokenized per-atom. + + Coming soon: support for RNA/DNA. + """ + res_shortname = residue_constants.restype_3to1.get(res.resname, 'X') + restype_idx = residue_constants.restype_order.get( + res_shortname, residue_constants.restype_num) + atoms = list(res.get_atoms()) + # Check if is standard protein residue and has modifications. + is_protein_res = res_shortname != 'X' + has_modification = any(atom.name not in residue_constants.atom_types for atom in atoms if is_protein_res) + if is_protein_res and not has_modification: + # Standard amino acid residue. + pos = np.zeros((1, residue_constants.atom_type_num, 3)) # atom37 representation. 
+ mask = np.zeros((1, residue_constants.atom_type_num,)) + b_factors = np.zeros((1, residue_constants.atom_type_num,)) + for atom in atoms: + pos[0, residue_constants.atom_order[atom.name]] = atom.coord + mask[0, residue_constants.atom_order[atom.name]] = 1.0 + b_factors[0, residue_constants.atom_order[atom.name]] = atom.bfactor + elif is_protein_res and has_modification: + # Modified amino acid residue: atoms unknown to atom37 fall back to slot 0, as for ligands. + pos = np.zeros((len(atoms), residue_constants.atom_type_num, 3)) + mask = np.zeros((len(atoms), residue_constants.atom_type_num,)) + b_factors = np.zeros((len(atoms), residue_constants.atom_type_num,)) + for i, atom in enumerate(atoms): + pos[i, residue_constants.atom_order.get(atom.name, 0)] = atom.coord + mask[i, residue_constants.atom_order.get(atom.name, 0)] = 1.0 + b_factors[i, residue_constants.atom_order.get(atom.name, 0)] = atom.bfactor + else: + # Ligand residue. + pos = np.zeros((len(atoms), residue_constants.atom_type_num, 3)) + mask = np.zeros((len(atoms), residue_constants.atom_type_num,)) + b_factors = np.zeros((len(atoms), residue_constants.atom_type_num,)) + for i, atom in enumerate(atoms): + # If ligand, add the atom coordinate as the first atom. 
+ pos[i, 0] = atom.coord + mask[i, 0] = 1.0 + b_factors[i, 0] = atom.bfactor + restype = np.full((len(mask),), fill_value=restype_idx) + res_idx = np.full((len(mask),), fill_value=res.id[1]) + chain_idx = np.full((len(mask),), fill_value=chain_idx) + return TokenizedResidue(pos=pos, + mask=mask, + restype=restype, + b_factors=b_factors, + res_idx=res_idx, + chain_idx=chain_idx) def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Biomolecule: diff --git a/src/data/components/mmcif_parsing.py b/src/data/components/mmcif_parsing.py index 0de3dc2..f0209cb 100644 --- a/src/data/components/mmcif_parsing.py +++ b/src/data/components/mmcif_parsing.py @@ -28,8 +28,8 @@ import numpy as np from src.data.errors import MultipleChainsError -import src.common.residue_constants as residue_constants - +from src.common import residue_constants +from src.common import ligand_constants # Type aliases: ChainId = str @@ -45,18 +45,31 @@ class Monomer: num: int +@dataclasses.dataclass(frozen=True) +class Ligand: + ligand_id: str + ligand_num: int + + +@dataclasses.dataclass(frozen=True) +class Bond: + atom1: str + atom2: str + bond_type: str + + # Note - mmCIF format provides no guarantees on the type of author-assigned # sequence numbers. They need not be integers. @dataclasses.dataclass(frozen=True) class AtomSite: - residue_name: str - author_chain_id: str - mmcif_chain_id: str - author_seq_num: str - mmcif_seq_num: int - insertion_code: str - hetatm_atom: str - model_num: int + residue_name: str # label_comp_id + author_chain_id: str # auth_asym_id + mmcif_chain_id: str # label_asym_id + author_seq_num: str # auth_seq_id + mmcif_seq_num: int # label_seq_id + insertion_code: str # pdbx_PDB_ins_code + hetatm_atom: str # hetero-atom, considered not to be a part of the primary molecule (e.g. ligands) + model_num: int # pdbx_PDB_model_num # Used to map SEQRES index to a residue in the structure. 
@@ -99,6 +112,7 @@ class MmcifObject: chain_to_seqres: Mapping[ChainId, SeqRes] seqres_to_structure: Mapping[ChainId, Mapping[int, ResidueAtPosition]] raw_string: Any + # bonds: @dataclasses.dataclass(frozen=True) @@ -120,7 +134,7 @@ class ParseError(Exception): def mmcif_loop_to_list( - prefix: str, parsed_info: MmCIFDict + prefix: str, parsed_info: MmCIFDict ) -> Sequence[Mapping[str, str]]: """Extracts loop associated with a prefix from mmCIF data as a list. @@ -145,16 +159,16 @@ def mmcif_loop_to_list( data.append(value) assert all([len(xs) == len(data[0]) for xs in data]), ( - "mmCIF error: Not all loops are the same length: %s" % cols + "mmCIF error: Not all loops are the same length: %s" % cols ) return [dict(zip(cols, xs)) for xs in zip(*data)] def mmcif_loop_to_dict( - prefix: str, - index: str, - parsed_info: MmCIFDict, + prefix: str, + index: str, + parsed_info: MmCIFDict, ) -> Mapping[str, Mapping[str, str]]: """Extracts loop associated with a prefix from mmCIF data as a dictionary. @@ -176,9 +190,9 @@ def mmcif_loop_to_dict( @functools.lru_cache(16, typed=False) def parse( - *, file_id: str, mmcif_string: str, catch_all_errors: bool = True + *, file_id: str, mmcif_string: str, catch_all_errors: bool = True ) -> ParsingResult: - """Entry point, parses an mmcif_string. + """Entry point, parses a mmcif_string. Args: file_id: A string identifier for this file. Should be unique within the @@ -210,6 +224,7 @@ def parse( # Determine the protein chains, and their start numbers according to the # internal mmCIF numbering scheme (likely but not guaranteed to be 1). 
+ # TODO: change to proteins & ligands valid_chains = _get_protein_chains(parsed_info=parsed_info) if not valid_chains: return ParsingResult( @@ -252,7 +267,7 @@ def parse( insertion_code=insertion_code, ) seq_idx = ( - int(atom.mmcif_seq_num) - seq_start_num[atom.mmcif_chain_id] + int(atom.mmcif_seq_num) - seq_start_num[atom.mmcif_chain_id] ) current = seq_to_structure_mappings.get( atom.author_chain_id, {} @@ -339,9 +354,9 @@ def _get_header(parsed_info: MmCIFDict) -> PdbHeader: header["resolution"] = 0.00 for res_key in ( - "_refine.ls_d_res_high", - "_em_3d_reconstruction.resolution", - "_reflns.d_resolution_high", + "_refine.ls_d_res_high", + "_em_3d_reconstruction.resolution", + "_reflns.d_resolution_high", ): if res_key in parsed_info: try: @@ -374,7 +389,7 @@ def _get_atom_site_list(parsed_info: MmCIFDict) -> Sequence[AtomSite]: def _get_protein_chains( - *, parsed_info: Mapping[str, Any] + *, parsed_info: Mapping[str, Any] ) -> Mapping[ChainId, Sequence[Monomer]]: """Extracts polymer information for protein chains only. @@ -409,6 +424,7 @@ def _get_protein_chains( chain_id = struct_asym["_struct_asym.id"] entity_id = struct_asym["_struct_asym.entity_id"] entity_to_mmcif_chains[entity_id].append(chain_id) + # ligands are actually different entities and have different chain ids. # Identify and return the valid protein chains. valid_chains = {} @@ -416,26 +432,92 @@ def _get_protein_chains( chain_ids = entity_to_mmcif_chains[entity_id] # Reject polymers without any peptide-like components, such as DNA/RNA. 
- if any( - [ - "peptide" in chem_comps[monomer.id]["_chem_comp.type"] - for monomer in seq_info - ] - ): + has_protein = any( + ["peptide" in chem_comps[monomer.id]["_chem_comp.type"] + for monomer in seq_info]) + if has_protein: for chain_id in chain_ids: valid_chains[chain_id] = seq_info + return valid_chains +def _get_ligand_chains( + *, parsed_info: Mapping[str, Any] +) -> Mapping[ChainId, Sequence[Ligand]]: + """Extracts ligand information for ligand 'chains'. + + Args: + parsed_info: _mmcif_dict produced by the Biopython parser. + + Returns: + A dict mapping mmcif chain id to a list of Ligands. + """ + # Get non-polymer ligand information for each entity in the structure. + entity_nonpolymers = mmcif_loop_to_list("_pdbx_nonpoly_scheme.", parsed_info) + + nonpolymers = collections.defaultdict(list) + for entity_nonpoly in entity_nonpolymers: + nonpolymers[entity_nonpoly["_pdbx_nonpoly_scheme.entity_id"]].append( + Ligand( + ligand_id=entity_nonpoly["_pdbx_nonpoly_scheme.mon_id"], + ligand_num=int(entity_nonpoly["_pdbx_nonpoly_scheme.ndb_seq_num"]), # the category has no '.num' item; ndb_seq_num is its integer sequence number + ) + ) + + # Get chemical compositions. Not used by the filtering below yet; only the + # disabled draft at the end of this function refers to them. + chem_comps = mmcif_loop_to_dict("_chem_comp.", "_chem_comp.id", parsed_info) + + # Get chains information for each entity. Necessary so that we can return a + # dict keyed on chain id rather than entity. + struct_asyms = mmcif_loop_to_list("_struct_asym.", parsed_info) + + entity_to_mmcif_chains = collections.defaultdict(list) + for struct_asym in struct_asyms: + chain_id = struct_asym["_struct_asym.id"] + entity_id = struct_asym["_struct_asym.entity_id"] + entity_to_mmcif_chains[entity_id].append(chain_id) + # ligands actually have different chain ids. + # In the case of hemoglobin that has 4 heme groups, + # each heme group has a different chain id but the same entity id. + + # Identify and return the valid ligand chains. 
+ valid_ligands = {} + for entity_id, ligand_info in nonpolymers.items(): + chain_ids = entity_to_mmcif_chains[entity_id] + + # Reject ligands that are crystallization aids or in the exclusion list. + for monomer in ligand_info: + is_crystal_aid = monomer.ligand_id in ligand_constants.CRYSTALLIZATION_AIDS + is_valid_ligand = monomer.ligand_id not in ligand_constants.LIGAND_EXCLUSION_LIST and not is_crystal_aid + if is_valid_ligand: + valid_ligands.update(dict.fromkeys(chain_ids, ligand_info)) + # has_valid_ligand = any() + + pass + """ + ligand_chains = {} + for chem_comp in chem_comps.keys(): + # Check if there is a valid ligand in the structure + is_crystal_aid = chem_comp in ligand_constants.CRYSTALLIZATION_AIDS + is_valid_ligand = chem_comp not in ligand_constants.LIGAND_EXCLUSION_LIST and not is_crystal_aid + if is_valid_ligand: + pass + """ + # TODO: Add the ligand chains separately here. They are not polymers. + return valid_ligands + + def _is_set(data: str) -> bool: """Returns False if data is a special mmCIF character indicating 'unset'.""" return data not in (".", "?") def get_atom_coords( - mmcif_object: MmcifObject, - chain_id: str, - _zero_center_positions: bool = False + mmcif_object: MmcifObject, + chain_id: str, + _zero_center_positions: bool = False ) -> Tuple[np.ndarray, np.ndarray]: # Locate the right chain chains = list(mmcif_object.structure.get_chains()) @@ -447,7 +529,8 @@ def get_atom_coords( chain = relevant_chains[0] # Extract the coordinates - num_res = len(mmcif_object.chain_to_seqres[chain_id]) + num_res = len(mmcif_object.chain_to_seqres[chain_id]) # TODO: number of tokens + # the ligands and water are included as residues in the chain all_atom_positions = np.zeros( [num_res, residue_constants.atom_type_num, 3], dtype=np.float32 ) @@ -455,10 +538,10 @@ def get_atom_coords( [num_res, residue_constants.atom_type_num], dtype=np.float32 ) for res_index in range(num_res): - pos = np.zeros([residue_constants.atom_type_num, 3], dtype=np.float32) + pos = 
np.zeros([residue_constants.atom_type_num, 3], dtype=np.float32) # atom37 representation mask = np.zeros([residue_constants.atom_type_num], dtype=np.float32) res_at_position = mmcif_object.seqres_to_structure[chain_id][res_index] - if not res_at_position.is_missing: + if not res_at_position.is_missing: # TODO: switch to tokenize_residue from biomolecule.py res = chain[ ( res_at_position.hetflag, @@ -476,17 +559,19 @@ def get_atom_coords( # Put the coords of the selenium atom in the sulphur column pos[residue_constants.atom_order["SD"]] = [x, y, z] mask[residue_constants.atom_order["SD"]] = 1.0 + # TODO: put the coords of ligand atoms here as well # Fix naming errors in arginine residues where NH2 is incorrectly # assigned to be closer to CD than NH1 cd = residue_constants.atom_order['CD'] nh1 = residue_constants.atom_order['NH1'] nh2 = residue_constants.atom_order['NH2'] - if( - res.get_resname() == 'ARG' and - all(mask[atom_index] for atom_index in (cd, nh1, nh2)) and - (np.linalg.norm(pos[nh1] - pos[cd]) > - np.linalg.norm(pos[nh2] - pos[cd])) + # TODO: slight indexing change e.g. pos[:, nh1] + if ( + res.get_resname() == 'ARG' and + all(mask[atom_index] for atom_index in (cd, nh1, nh2)) and + (np.linalg.norm(pos[nh1] - pos[cd]) > + np.linalg.norm(pos[nh2] - pos[cd])) ): pos[nh1], pos[nh2] = pos[nh2].copy(), pos[nh1].copy() mask[nh1], mask[nh2] = mask[nh2].copy(), mask[nh1].copy() @@ -494,6 +579,7 @@ def get_atom_coords( all_atom_positions[res_index] = pos all_atom_mask[res_index] = mask + # TODO: stack all the tokenized residue elements into a single array if _zero_center_positions: binary_mask = all_atom_mask.astype(bool) translation_vec = all_atom_positions[binary_mask].mean(axis=0)