Skip to content

Commit

Permalink
Merge pull request #268 from KosinskiLab/update-truemultimer
Browse files Browse the repository at this point in the history
Update truemultimer by making it run while predicting models
  • Loading branch information
dingquanyu authored Mar 7, 2024
2 parents 593e30b + 02b79fe commit 473fbfa
Show file tree
Hide file tree
Showing 7 changed files with 302 additions and 8 deletions.
1 change: 1 addition & 0 deletions .github/workflows/github_actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ jobs:
pip install -e .
pip install -e alphapulldown/ColabFold --no-deps
pip install -e alphafold --no-deps
python test/test_python_imports.py
- name: Install dependencies in AlphaLink2 setup.py
run: |
Expand Down
140 changes: 140 additions & 0 deletions alphapulldown/multimeric_template_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@

import os, logging, csv,sys
from pathlib import Path
from alphafold.data.templates import (
_extract_template_features,
_build_query_to_hit_index_mapping)
from alphafold.data.templates import SingleHitResult
from alphafold.data.mmcif_parsing import ParsingResult
from alphafold.data.parsers import TemplateHit
from alphapulldown.remove_clashes_low_plddt import MmcifChainFiltered
from typing import Optional
import shutil
import numpy as np

def prepare_multimeric_template_meta_info(csv_path:str, mmt_dir:str) -> dict:
    """
    Parse the multimeric template description file into a nested dictionary.

    Adapted from https://github.com/KosinskiLab/AlphaPulldown/blob/231863af7faa61fa04d45829c90a3bab9d9e2ff2/alphapulldown/create_individual_features_with_templates.py#L107C1-L159C38
    by @DimaMolod

    Args:
        csv_path: Path to a comma-separated text file with three columns per row:
            protein name, template file name, chain ID. Empty lines are skipped and
            surrounding whitespace in each column is stripped.
        mmt_dir: Path to the directory that holds the template files named in the
            second column.
    Returns:
        A dictionary keyed by protein name; each value is a one-entry dictionary
        mapping the template file name to the chain ID, e.g.
        {"protein_A": {"xxx.cif": "A"}, "protein_B": {"xxx.cif": "B"}}.
        Only the first row seen for a given protein name is kept.
    Raises:
        AssertionError: if a listed template file does not exist inside mmt_dir.
        SystemExit: with status 1 if a non-empty row does not have exactly three columns.
    """
    parsed_dict = {}
    with open(csv_path, newline="") as csvfile:
        for row in csv.reader(csvfile):
            # skip empty lines
            if not row:
                continue
            if len(row) != 3:
                logging.error(f"Invalid line found in the file {csv_path}: {row}")
                # A bare sys.exit() would report success (status 0) to the shell
                # even though the description file is malformed.
                sys.exit(1)
            protein, template, chain = (item.strip() for item in row)
            if not os.path.exists(os.path.join(mmt_dir, template)):
                # Raised explicitly (instead of using an assert statement) so the
                # check is still performed when Python runs with -O; the exception
                # type is unchanged for existing callers.
                raise AssertionError(
                    f"Provided {template} cannot be found in {mmt_dir}. Abort")
            if protein in parsed_dict:
                # First entry wins; make the previously silent drop visible.
                logging.warning(
                    f"Duplicate entry for {protein} in {csv_path}; ignoring {row}")
                continue
            parsed_dict[protein] = {template: chain}

    return parsed_dict

def obtain_kalign_binary_path() -> Optional[str]:
    """
    Locate the kalign executable on the current PATH.

    Returns:
        The path reported by shutil.which for 'kalign'.
    Raises:
        AssertionError: if no kalign executable can be found.
    """
    # Look the binary up once and reuse the result.
    kalign_path = shutil.which('kalign')
    if kalign_path is None:
        # Raised explicitly so the check is not stripped when Python runs with -O;
        # the exception type matches what the previous assert produced.
        raise AssertionError("Could not find kalign in your environment")
    return kalign_path


def parse_mmcif_file(file_id:str,mmcif_file:str) -> ParsingResult:
    """
    Parse an mmCIF file through MmcifChainFiltered and hand back its parsing result.

    Args:
        file_id: A string identifier for this file. Should be unique within the
            collection of files being processed.
        mmcif_file: path to the target mmcif file
    Returns:
        The ``parsing_result`` attribute of the MmcifChainFiltered object, or None
        when a FileNotFoundError is raised while building it (a message is printed
        in that case).
    """
    try:
        return MmcifChainFiltered(Path(mmcif_file), file_id).parsing_result
    except FileNotFoundError:
        print(f"{mmcif_file} could not be found")
        return None

def create_template_hit(index:int, name:str,query:str) -> TemplateHit:
    """
    Build a TemplateHit whose hit sequence is the query sequence itself.

    Only the case where the query and the template sequences are identical is
    supported, so every query position is paired with the same position in the hit.

    Args:
        index: index of the hit e.g. numberXX of the customised templates
        name: name of the hit e.g. pdbid_CHAIN
        query: query sequence
    Returns:
        A TemplateHit object in which hit and query sequences are identical,
        ``aligned_cols`` equals the query length and ``sum_probs`` is None.
    """
    n_residues = len(query)
    identity_positions = range(n_residues)
    return TemplateHit(
        index=index,
        name=name,
        aligned_cols=n_residues,
        sum_probs=None,
        query=query,
        hit_sequence=query,
        # two independent lists, one per side of the identity alignment
        indices_query=list(identity_positions),
        indices_hit=list(identity_positions),
    )

def extract_multimeric_template_features_for_single_chain(
        query_seq:str,
        pdb_id:str,
        chain_id:str,
        mmcif_file:str,
        index:int =1,

) -> SingleHitResult:
    """
    Build template features for one query chain from a user-supplied multimeric template.

    The query sequence is also passed as the template sequence, so the hit created
    here pairs every query position with the same position in the template.

    Args:
        query_seq: the sequence to be modelled, single chain
        pdb_id: the id of the PDB file or the name of the pdb file where the multimeric template structure is written
        chain_id: which chain of the multimeric template that this query sequence will be aligned to
        mmcif_file: path to the .cif file that is going to be parsed.
        index: index of the hit e.g. numberXX of the customised templates
    Returns:
        A SingleHitResult object. On success ``features`` holds the extracted template
        features and ``warning`` carries the realign warning returned by the extraction.
        On failure ``features`` is None and ``error`` contains a description of what
        went wrong.
    """
    hit = create_template_hit(index, name=f"{pdb_id}_{chain_id}", query=query_seq)
    mapping = _build_query_to_hit_index_mapping(hit.query, hit.hit_sequence,
                                                hit.indices_hit, hit.indices_query,
                                                query_seq)
    mmcif_parse_result = parse_mmcif_file(pdb_id, mmcif_file)
    if (mmcif_parse_result is None) or (mmcif_parse_result.mmcif_object is None):
        error_message = f"Could not parse {mmcif_file}; no template features were extracted"
        logging.error(error_message)
        return SingleHitResult(features=None, error=error_message, warning=None)

    try:
        features, realign_warning = _extract_template_features(
            mmcif_object = mmcif_parse_result.mmcif_object,
            pdb_id = pdb_id,
            mapping = mapping,
            template_sequence = query_seq,
            query_sequence = query_seq,
            template_chain_id = chain_id,
            kalign_binary_path = obtain_kalign_binary_path()
        )
        features['template_sum_probs'] = [0]*4
        # add 1 dimension to template_all_atom_positions and replicate 4 times
        features['template_all_atom_positions'] = np.tile(features['template_all_atom_positions'],(4,1,1,1))
        features['template_all_atom_position'] = features['template_all_atom_positions']
        # NOTE(review): the positions above are tiled into 4 copies whereas the mask
        # below only gains a leading axis of length 1 - confirm this is intended.
        features['template_all_atom_mask'] = features['template_all_atom_masks'][np.newaxis,:]
        for k in ['template_sequence','template_domain_names',
                  'template_aatype']:
            features[k] = [features[k]]*4
    except Exception as e:
        # Still best-effort (no exception escapes), but the cause is now logged and
        # returned to the caller instead of being discarded.
        error_message = (f"Failed to extract template features from {mmcif_file} "
                         f"(chain {chain_id}): {e!r}")
        logging.error(error_message)
        return SingleHitResult(features=None, error=error_message, warning=None)
    return SingleHitResult(features=features, error=None, warning=realign_warning)
43 changes: 39 additions & 4 deletions alphapulldown/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from alphafold.data import msa_pairing
from alphafold.data import feature_processing
from pathlib import Path as plPath
from colabfold.batch import (unserialize_msa,
get_msa_and_templates,
msa_to_str, build_monomer_feature)
from colabfold.batch import unserialize_msa, get_msa_and_templates, msa_to_str, build_monomer_feature
from alphapulldown.multimeric_template_utils import (extract_multimeric_template_features_for_single_chain,
prepare_multimeric_template_meta_info)


@contextlib.contextmanager
Expand Down Expand Up @@ -512,18 +512,38 @@ class MultimericObject:
Args
index: assign a unique index ranging from 0 just to identify different multimer jobs
interactors: individual interactors that are to be concatenated
pair_msa: boolean, tells the programme whether to pair MSA or not
multimeric_mode: boolean, tells the programme whether use multimeric templates or not
multimeric_template_meta_data: a csv with the format {"monomer_A":{"xxx.cif":"chainID"},"monomer_B":{"yyy.cif":"chainID"}}
multimeric_template_dir: a directory where all the multimeric templates mmcifs files are stored
"""

def __init__(self, interactors: list, pair_msa: bool = True, multimeric_mode: bool = False) -> None:
def __init__(self, interactors: list, pair_msa: bool = True,
multimeric_mode: bool = False,
multimeric_template_meta_data: str = None,
multimeric_template_dir:str = None) -> None:
self.description = ""
self.interactors = interactors
self.build_description_monomer_mapping()
self.pair_msa = pair_msa
self.multimeric_mode = multimeric_mode
self.chain_id_map = dict()
self.input_seqs = []
self.multimeric_template_dir = multimeric_template_dir
self.create_output_name()

if multimeric_template_meta_data is not None:
self.multimeric_template_meta_data = prepare_multimeric_template_meta_info(multimeric_template_meta_data,
self.multimeric_template_dir)

if self.multimeric_mode:
self.create_multimeric_template_features()
self.create_all_chain_features()
pass

def build_description_monomer_mapping(self):
    """Store in ``self.monomers_mapping`` a dictionary {description: monomer} built from ``self.interactors``."""
    description_to_monomer = {}
    for monomer in self.interactors:
        # a later interactor with the same description replaces the earlier one
        description_to_monomer[monomer.description] = monomer
    self.monomers_mapping = description_to_monomer

def get_all_residue_index(self):
"""get all residue indexes from subunits"""
Expand Down Expand Up @@ -607,6 +627,21 @@ def create_multichain_mask(self):
# DEBUG
self.save_binary_matrix(multichain_mask, "multichain_mask.png")
return multichain_mask

def create_multimeric_template_features(self):
    """
    Extract multimeric template features for every monomer listed in
    ``self.multimeric_template_meta_data`` and merge them into that monomer's
    ``feature_dict``.

    Raises:
        AssertionError: if the template meta data is missing, if a template file
            name does not end with .cif, or if a template file cannot be found in
            ``self.multimeric_template_dir``.
        KeyError: if a protein named in the meta data is not a key of
            ``self.monomers_mapping``.
    """
    # The attribute may never have been assigned, so a plain attribute access
    # would fail with AttributeError before the explanatory message below could
    # be shown.
    meta_data = getattr(self, "multimeric_template_meta_data", None)
    if meta_data is None:
        raise AssertionError("You chose to use multimeric template mode but multimric template information is missing. Abort")
    for monomer_name, templates in meta_data.items():
        for k, v in templates.items():
            curr_monomer = self.monomers_mapping[monomer_name]
            # Explicit raises keep these checks active when Python runs with -O.
            if not k.endswith(".cif"):
                raise AssertionError("The multimeric template file you provided does not seem to be a mmcif file. Please check your format and make sure it ends with .cif")
            template_path = os.path.join(self.multimeric_template_dir, k)
            if not os.path.exists(template_path):
                raise AssertionError(f"Your provided {k} cannot be found in: {self.multimeric_template_dir}. Abort")
            pdb_id = k.split('.cif')[0]
            multimeric_template_features = extract_multimeric_template_features_for_single_chain(
                query_seq=curr_monomer.sequence,
                pdb_id=pdb_id, chain_id=v,
                mmcif_file=template_path)
            # NOTE(review): ``features`` is None when extraction fails, in which case
            # this update raises TypeError - confirm whether that should abort or skip.
            curr_monomer.feature_dict.update(multimeric_template_features.features)


def pair_and_merge(self, all_chain_features):
"""merge all chain features"""
Expand Down
8 changes: 4 additions & 4 deletions alphapulldown/remove_clashes_low_plddt.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,10 @@ def __init__(self, input_file_path, code, chain_id=None):
self.input_file_path = input_file_path.with_suffix(".cif")
with open(self.input_file_path) as f:
mmcif = f.read()
parsing_result = parse(file_id=code, mmcif_string=mmcif)
if parsing_result.errors:
raise Exception(f"Can't parse mmcif file {self.input_file_path}: {parsing_result.errors}")
mmcif_object = parsing_result.mmcif_object
self.parsing_result = parse(file_id=code, mmcif_string=mmcif)
if self.parsing_result.errors:
raise Exception(f"Can't parse mmcif file {self.input_file_path}: {self.parsing_result.errors}")
mmcif_object = self.parsing_result.mmcif_object
self.seqres_to_structure = mmcif_object.seqres_to_structure[chain_id]
structure, sequence_atom = self.extract_chain(mmcif_object.structure, chain_id)
self.sequence_atom = sequence_atom
Expand Down
6 changes: 6 additions & 0 deletions alphapulldown/run_multimer_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from alphapulldown.utils import (create_interactors, read_all_proteins, read_custom, make_dir_monomer_dictionary,
load_monomer_objects, check_output_dir, create_model_runners_and_random_seed,
create_and_save_pae_plots, post_prediction_process)
from alphapulldown.multimeric_template_utils import prepare_multimeric_template_meta_info
from itertools import combinations
from alphapulldown.objects import MultimericObject
import os
Expand Down Expand Up @@ -91,6 +92,9 @@
"remove_result_pickles", False,
"Whether the result pickles that do not belong to the best model are going to be removed. Default is False"
)
flags.DEFINE_string("description_file", None,
"Path to the text file with multimeric template instructions")
flags.DEFINE_string("path_to_mmt", None, "Path to directory with multimeric template mmCIF files")
flags.DEFINE_enum("unifold_model_name", "multimer_af2",
["multimer_af2", "multimer_ft", "multimer", "multimer_af2_v3", "multimer_af2_model45_v3"],
"choose unifold model structure")
Expand Down Expand Up @@ -231,6 +235,8 @@ def create_multimer_objects(data, monomer_objects_dir, pair_msa=True):
interactors = create_interactors(data, monomer_objects_dir, job_idx)
if len(interactors) > 1:
multimer = MultimericObject(interactors=interactors, pair_msa=pair_msa,
multimeric_template_meta_data=FLAGS.description_file,
multimeric_template_dir=FLAGS.path_to_mmt,
multimeric_mode=FLAGS.multimeric_mode)
logging.info(f"done creating multimer {multimer.description}")
multimers.append(multimer)
Expand Down
46 changes: 46 additions & 0 deletions test/test_create_multimeric_template_features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import unittest
import numpy as np
import gzip,pickle,shutil
from alphapulldown.objects import MultimericObject
from alphafold.data.templates import _build_query_to_hit_index_mapping
from alphapulldown import multimeric_template_utils
class TestMultimericTemplateFeatures(unittest.TestCase):
    """Tests for the helpers in alphapulldown.multimeric_template_utils, run against the 3L4Q test data."""

    def setUp(self):
        self.mmcif_file = "./test/test_data/true_multimer/3L4Q.cif"
        # Context managers close the pickle files promptly instead of leaving the
        # handles open until garbage collection.
        with open("./test/test_data/true_multimer/features/3L4Q_A.pkl", 'rb') as handle:
            self.monomer1 = pickle.load(handle)
        with open("./test/test_data/true_multimer/features/3L4Q_C.pkl", 'rb') as handle:
            self.monomer2 = pickle.load(handle)
        self.kalign_binary_path = shutil.which('kalign')
        self.mmt_dir = './test/test_data/true_multimer/'
        self.instruction_file = "./test/test_data/true_multimer/description_file.csv"
        self.data_dir = '/scratch/AlphaFold_DBs/2.3.2'

    def test_1_create_template_hit(self):
        """The hit sequence of the created template hit equals the query sequence."""
        template_hit = multimeric_template_utils.create_template_hit(index=1, name='3l4q_A',query=self.monomer1.sequence)
        self.assertEqual(self.monomer1.sequence,template_hit.hit_sequence)

    def test_2_build_mapping(self):
        """The query-to-hit mapping built from the template hit is the identity mapping."""
        template_hit = multimeric_template_utils.create_template_hit(index=1, name='3l4q_A',query=self.monomer1.sequence)
        expected_mapping = {i:i for i in range(len(self.monomer1.sequence))}
        mapping = _build_query_to_hit_index_mapping(template_hit.query,
                                                    template_hit.hit_sequence,
                                                    template_hit.indices_hit,
                                                    template_hit.indices_query,
                                                    self.monomer1.sequence)
        self.assertEqual(expected_mapping, mapping)

    def test_3_extract_multimeric_template_features(self):
        """Feature extraction for chain A of 3L4Q returns a result with features."""
        single_hit_result = multimeric_template_utils.extract_multimeric_template_features_for_single_chain(self.monomer1.sequence,
                                                                                                          pdb_id='3L4Q',
                                                                                                          chain_id='A',
                                                                                                          mmcif_file=self.mmcif_file)
        self.assertIsNotNone(single_hit_result.features)

    def test_4_parse_instraction_file(self):
        """Test if the instruction csv table is parsed properly"""
        multimeric_template_meta = multimeric_template_utils.prepare_multimeric_template_meta_info(self.instruction_file,self.mmt_dir)
        self.assertIsInstance(multimeric_template_meta, dict)
        expected_dict = {"3L4Q_A":{"3L4Q.cif":"A"}, "3L4Q_C":{"3L4Q.cif":"C"}}
        self.assertEqual(multimeric_template_meta,expected_dict)

if __name__ == "__main__":
unittest.main()
66 changes: 66 additions & 0 deletions test/test_python_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import os, logging, csv,sys
from alphafold.data.templates import (_read_file,
_extract_template_features,
_build_query_to_hit_index_mapping)
from alphafold.data.templates import SingleHitResult
from alphafold.data import mmcif_parsing
from alphafold.data.mmcif_parsing import ParsingResult
from alphafold.data.parsers import TemplateHit
from typing import Optional
import shutil
import logging
import tempfile
import os
import contextlib
import numpy as np
from alphafold.data import parsers
from alphafold.data import pipeline_multimer
from alphafold.data import pipeline
from alphafold.data import msa_pairing
from alphafold.data import feature_processing
from pathlib import Path as plPath
from colabfold.batch import unserialize_msa, get_msa_and_templates, msa_to_str, build_monomer_feature
from alphapulldown.multimeric_template_utils import (extract_multimeric_template_features_for_single_chain,
prepare_multimeric_template_meta_info)

import itertools
from absl import app, logging
from alphapulldown.utils import (create_interactors, read_all_proteins, read_custom, make_dir_monomer_dictionary,
load_monomer_objects, check_output_dir, create_model_runners_and_random_seed,
create_and_save_pae_plots, post_prediction_process)
from alphapulldown.multimeric_template_utils import prepare_multimeric_template_meta_info
from itertools import combinations
from alphapulldown.objects import MultimericObject
import os
from pathlib import Path
from alphapulldown.predict_structure import predict, ModelsToRelax
from alphapulldown.utils import get_run_alphafold
from alphapulldown import __version__ as ap_version
from alphafold.version import __version__ as af_version
import json
from datetime import datetime
from alphafold.data.tools import jackhmmer
from alphapulldown.objects import ChoppedObject
from alphapulldown import __version__ as AP_VERSION
from alphafold.version import __version__ as AF_VERSION
import json
import os
import pickle
import logging
from alphapulldown.plot_pae import plot_pae
import alphafold
from alphafold.model import config
from alphafold.model import model
from alphafold.model import data
from alphafold.data import templates
import random
import subprocess
from alphafold.data import parsers
from pathlib import Path
import numpy as np
import sys
import datetime
import re
import hashlib
import glob
import importlib.util

0 comments on commit 473fbfa

Please sign in to comment.