Merge pull request #277 from KosinskiLab/fix-9
Fix 9
dingquanyu authored Mar 1, 2024
2 parents cee4f42 + 655fc3c commit 2e88c42
Showing 1 changed file with 129 additions and 44 deletions.
alphapulldown/objects.py: 173 changes (129 additions, 44 deletions)
@@ -5,6 +5,7 @@
import logging
import tempfile
import os
+import subprocess
import contextlib
import numpy as np
from alphafold.data import parsers
@@ -13,7 +14,9 @@
from alphafold.data import msa_pairing
from alphafold.data import feature_processing
from pathlib import Path as plPath
-from colabfold.batch import unserialize_msa, get_msa_and_templates, msa_to_str, build_monomer_feature
+from colabfold.batch import (unserialize_msa,
+                             get_msa_and_templates,
+                             msa_to_str, build_monomer_feature)


@contextlib.contextmanager
@@ -50,13 +53,49 @@ def uniprot_runner(self):
def uniprot_runner(self, uniprot_runner):
self._uniprot_runner = uniprot_runner

+@staticmethod
+def zip_msa_files(msa_output_path: str):
+    """
+    A static method that gzips the individual MSA files within the given msa_output_path folder
+    """
+    def zip_individual_file(msa_file: plPath):
+        assert os.path.exists(msa_file)
+        cmd = f"gzip {msa_file}"
+        _ = subprocess.run(cmd, shell=True, capture_output=True, text=True)
+
+    msa_file_endings = ['.a3m', '.fasta', '.sto', '.hmm']
+    msa_files = [i for i in plPath(
+        msa_output_path).iterdir() if i.suffix in msa_file_endings]
+    if len(msa_files) > 0:
+        for msa_file in msa_files:
+            zip_individual_file(msa_file)
+
+@staticmethod
+def unzip_msa_files(msa_output_path: str):
+    """
+    A static method that unzips the MSA files in a folder if any exist
+    """
+    def unzip_individual_file(msa_file: plPath):
+        assert os.path.exists(msa_file)
+        cmd = f"gunzip {msa_file}"
+        _ = subprocess.run(cmd, shell=True, capture_output=True, text=True)
+
+    zipped_files = [i for i in plPath(
+        msa_output_path).iterdir() if i.suffix == '.gz']
+    if len(zipped_files) > 0:
+        for zipped_file in zipped_files:
+            unzip_individual_file(zipped_file)
+        return True  # signals that zipped MSA files were found and unzipped
+    else:
+        return False
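
These helpers shell out to gzip and gunzip through subprocess. For orientation, a minimal pure-Python sketch of the same compression step using only the standard library; the helper name is illustrative and not part of this commit:

import gzip
import shutil
from pathlib import Path

def gzip_in_place(path: Path) -> None:
    # Write <path>.gz next to the original, then delete the original,
    # mirroring what the `gzip <file>` command does.
    with open(path, "rb") as src, gzip.open(f"{path}.gz", "wb") as dst:
        shutil.copyfileobj(src, dst)
    path.unlink()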

def all_seq_msa_features(
self,
input_fasta_path,
uniprot_msa_runner,
save_msa,
output_dir=None,
-use_precomputed_msa=False,
+use_precomputed_msa=False
):
"""Get MSA features for unclustered uniprot, for pairing later on."""
if not use_precomputed_msa:
@@ -102,9 +141,15 @@ def all_seq_msa_features(
return feats

def make_features(
-self, pipeline, output_dir=None, use_precomputed_msa=False, save_msa=True
+self, pipeline, output_dir=None,
+use_precomputed_msa=False,
+save_msa=True, compress_msa_files=False
):
"""A method that makes MSA and template features"""
+# first check for zipped MSA files and unzip them if present
+using_zipped_msa_files = MonomericObject.unzip_msa_files(
+    os.path.join(output_dir, self.description))
+
if not use_precomputed_msa:
if not save_msa:
"""this means no msa files are going to be saved"""
@@ -125,14 +170,19 @@
"""this means no precomputed msa available and will save output msa files"""
msa_output_dir = os.path.join(output_dir, self.description)
sequence_str = f">{self.description}\n{self.sequence}"
-logging.info("will save msa files in :{}".format(msa_output_dir))
+logging.info(
+    "will save msa files in: {}".format(msa_output_dir))
plPath(msa_output_dir).mkdir(parents=True, exist_ok=True)
with temp_fasta_file(sequence_str) as fasta_file:
-self.feature_dict = pipeline.process(fasta_file, msa_output_dir)
+self.feature_dict = pipeline.process(
+    fasta_file, msa_output_dir)
pairing_results = self.all_seq_msa_features(
fasta_file, self._uniprot_runner, save_msa, msa_output_dir
)
self.feature_dict.update(pairing_results)

+if compress_msa_files:
+    MonomericObject.zip_msa_files(msa_output_dir)
else:
"""This means precomputed msa files are available"""
msa_output_dir = os.path.join(output_dir, self.description)
@@ -144,7 +194,8 @@
)
sequence_str = f">{self.description}\n{self.sequence}"
with temp_fasta_file(sequence_str) as fasta_file:
-self.feature_dict = pipeline.process(fasta_file, msa_output_dir)
+self.feature_dict = pipeline.process(
+    fasta_file, msa_output_dir)
pairing_results = self.all_seq_msa_features(
fasta_file,
self._uniprot_runner,
@@ -154,7 +205,11 @@
)
self.feature_dict.update(pairing_results)

+if using_zipped_msa_files:
+    MonomericObject.zip_msa_files(
+        os.path.join(output_dir, self.description))
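
With that, make_features gains a compress_msa_files flag. A hypothetical end-to-end call; the constructor arguments and the preconfigured pipeline and uniprot_runner objects are assumptions for illustration, not part of this diff:

# Hypothetical usage sketch: assumes an AlphaFold DataPipeline as `pipeline`
# and a jackhmmer runner as `uniprot_runner`, both configured elsewhere.
monomer = MonomericObject("my_protein", "MSEQENCE")
monomer.uniprot_runner = uniprot_runner
monomer.make_features(
    pipeline,
    output_dir="features",
    use_precomputed_msa=False,
    save_msa=True,
    compress_msa_files=True,  # new flag: gzip the saved MSA files afterwards
)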

def mk_template(self, a3m_lines,
pipeline, query_sequence):
"""
Overwrites ColabFold's original mk_template to incorporate the max_template_date argument
@@ -169,21 +224,26 @@
hmm_build_runner = pipeline.template_searcher.hmmbuild_runner
hmm_profile = hmm_build_runner.build_profile_from_a3m(a3m_lines)
query_result = pipeline.template_searcher.query_with_hmm(hmm_profile)
-template_hits = pipeline.template_searcher.get_template_hits(query_result,query_sequence)
+template_hits = pipeline.template_searcher.get_template_hits(
+    query_result, query_sequence)
templates_result = template_featuriser.get_templates(
query_sequence=query_sequence, hits=template_hits
)
return dict(templates_result.features)

def make_mmseq_features(
self, DEFAULT_API_SERVER,
pipeline=None, output_dir=None,
+compress_msa_files=False
):
"""
A method that uses remote MMseqs2 to calculate MSAs
Modified from ColabFold: https://github.com/sokrypton/ColabFold
"""

+# first check if there are zipped a3m files
+using_zipped_msa_files = MonomericObject.unzip_msa_files(
+    os.path.join(output_dir, self.description))

logging.info("You chose to calculate MSA with mmseq2.\nPlease also cite: Mirdita M, Schütze K, Moriwaki Y, Heo L, Ovchinnikov S and Steinegger M. ColabFold: Making protein folding accessible to all. Nature Methods (2022) doi: 10.1038/s41592-022-01488-1")
msa_mode = "MMseqs2 (UniRef+Environmental)"
keep_existing_results = True
@@ -193,12 +253,15 @@
if keep_existing_results and plPath(result_zip).is_file():
logging.info(f"Skipping {self.description} (result.zip)")

-logging.info(f"looking for possible precomputed a3m at {os.path.join(result_dir, self.description + '.a3m')}")
+logging.info(
+    f"looking for possible precomputed a3m at {os.path.join(result_dir, self.description + '.a3m')}")
try:
-logging.info(f"input is {os.path.join(result_dir, self.description + '.a3m')}")
+logging.info(
+    f"input is {os.path.join(result_dir, self.description + '.a3m')}")
input_path = os.path.join(result_dir, self.description + '.a3m')
a3m_lines = [plPath(input_path).read_text()]
-logging.info(f"Finished parsing the precalculated a3m_file\nNow will search for template")
+logging.info(
+    "Finished parsing the precalculated a3m file\nNow will search for templates")
except:
a3m_lines = None

@@ -219,9 +282,9 @@
query_seqs_cardinality,
template_features,
) = get_msa_and_templates(
-jobname = self.description,
+jobname=self.description,
query_sequences=self.sequence,
-a3m_lines = None,
+a3m_lines=None,
result_dir=plPath(result_dir),
msa_mode=msa_mode,
use_templates=use_templates,
@@ -233,15 +296,22 @@
msa = msa_to_str(
unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality
)
-plPath(os.path.join(result_dir, self.description + ".a3m")).write_text(msa)
-a3m_lines = [plPath(os.path.join(result_dir, self.description + ".a3m")).read_text()]
+plPath(os.path.join(result_dir, self.description + ".a3m")
+       ).write_text(msa)
+a3m_lines = [
+    plPath(os.path.join(result_dir, self.description + ".a3m")).read_text()]
+
+if compress_msa_files:
+    MonomericObject.zip_msa_files(
+        os.path.join(result_dir, self.description))
# unserialize_msa is from colabfold.batch and originally only creates mock template features;
# below we search against the pdb70 database using hhsearch to create real template features
logging.info("will search for templates in local template database")
-a3m_lines[0] = "\n".join([line for line in a3m_lines[0].splitlines() if not line.startswith("#")])
+a3m_lines[0] = "\n".join(
+    [line for line in a3m_lines[0].splitlines() if not line.startswith("#")])
template_features = self.mk_template(a3m_lines[0],
    pipeline, query_sequence=self.sequence)
self.feature_dict = build_monomer_feature(self.sequence, unpaired_msa[0],
template_features)

# update feature_dict with
@@ -254,6 +324,9 @@
}
self.feature_dict.update(feats)

+if using_zipped_msa_files:
+    MonomericObject.zip_msa_files(
+        os.path.join(output_dir, self.description))
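
A hypothetical invocation of this remote-MMseqs2 path with the new flag; the endpoint URL and the preconfigured objects are assumptions for illustration (ColabFold exposes the public endpoint as a DEFAULT_API_SERVER constant):

# Hypothetical usage sketch, reusing `monomer` and `pipeline` from above.
monomer.make_mmseq_features(
    "https://api.colabfold.com",  # public ColabFold API endpoint (assumed)
    pipeline=pipeline,
    output_dir="features",
    compress_msa_files=True,      # gzip the generated MSA files afterwards
)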

class ChoppedObject(MonomericObject):
"""chopped monomeric objects"""
@@ -284,15 +357,17 @@ def prepare_new_msa_feature(self, msa_feature, start_point, end_point):
new_seq_length = np.array([length] * length)
new_aa_type = msa_feature["aatype"][start_point:end_point, :]
new_between_segment_residue = msa_feature["between_segment_residues"][
    start_point:end_point
]
new_domain_name = msa_feature["domain_name"]
new_residue_index = msa_feature["residue_index"][start_point:end_point]
-new_sequence = np.array([msa_feature["sequence"][0][start_point:end_point]])
-new_deletion_mtx = msa_feature["deletion_matrix_int"][:, start_point:end_point]
+new_sequence = np.array(
+    [msa_feature["sequence"][0][start_point:end_point]])
+new_deletion_mtx = msa_feature["deletion_matrix_int"][:,
+                                                      start_point:end_point]
new_deletion_mtx_all_seq = msa_feature["deletion_matrix_int_all_seq"][
    :, start_point:end_point
]
new_msa = msa_feature["msa"][:, start_point:end_point]
new_msa_all_seq = msa_feature["msa_all_seq"][:, start_point:end_point]
new_num_alignments = np.array([msa_feature["msa"].shape[0]] * length)
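
These helpers carve the [start_point, end_point] region out of every per-residue feature. A toy illustration of the slice arithmetic, assuming 1-based inclusive bounds as in prepare_new_template_feature_dict below:

import numpy as np

msa = np.array([[1, 2, 3, 4, 5],
                [6, 7, 8, 9, 10]])
start_point, end_point = 2, 4                # 1-based, inclusive region
chopped = msa[:, start_point - 1:end_point]  # Python slices are 0-based, end-exclusive
print(chopped)                               # [[2 3 4] [7 8 9]]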
@@ -328,14 +403,14 @@ def prepare_new_template_feature_dict(
"""
start_point = start_point - 1
new_template_aatype = template_feature["template_aatype"][
    :, start_point:end_point, :
]
new_template_all_atom_masks = template_feature["template_all_atom_masks"][
    :, start_point:end_point, :
]
new_template_all_atom_positions = template_feature[
    "template_all_atom_positions"
][:, start_point:end_point, :, :]
new_template_domain_names = template_feature["template_domain_names"]
new_template_sequence = template_feature["template_sequence"]
new_template_sum_probs = template_feature["template_sum_probs"]
@@ -470,10 +545,11 @@ def create_chain_id_map(self):
multimer_sequence_str = ""
for interactor in self.interactors:
multimer_sequence_str = (
    multimer_sequence_str
    + f">{interactor.description}\n{interactor.sequence}\n"
)
-self.input_seqs, input_descs = parsers.parse_fasta(multimer_sequence_str)
+self.input_seqs, input_descs = parsers.parse_fasta(
+    multimer_sequence_str)
self.chain_id_map = pipeline_multimer._make_chain_id_map(
sequences=self.input_seqs, descriptions=input_descs
)
@@ -500,7 +576,8 @@ def save_binary_matrix(self, matrix, file_path):
text_width, text_height = draw.textsize(text, font=font)
x = (col + 0.5) * image.width / width - text_width / 2
y = image.height - text_height
-draw.text((x, y), text, font=font, fill=(0, 0, 0))  # Set text fill color to black
+# Set text fill color to black
+draw.text((x, y), text, font=font, fill=(0, 0, 0))

image.save(file_path)
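
One caveat worth flagging: ImageDraw.textsize was deprecated in Pillow 9.2 and removed in Pillow 10, so the textsize call above fails on current Pillow releases. A sketch of the drop-in replacement using textbbox, with the same draw, text, and font variables as above:

# Pillow >= 10: textbbox replaces the removed textsize API.
left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
text_width, text_height = right - left, bottom - top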

@@ -510,18 +587,22 @@ def create_multichain_mask(self):
no_gap_map = []
for interactor in self.interactors:
temp_length = len(interactor.sequence)
-pdb_map.extend([interactor.feature_dict['template_domain_names'][0]] * temp_length)
+pdb_map.extend(
+    [interactor.feature_dict['template_domain_names'][0]] * temp_length)
has_no_gaps = [True] * temp_length
# for each template in the interactor, check for gaps in sequence
for template_sequence in interactor.feature_dict['template_sequence']:
-is_not_gap = [s != '-' for s in template_sequence.decode("utf-8").strip()]
+is_not_gap = [
+    s != '-' for s in template_sequence.decode("utf-8").strip()]
# False if any of the templates has a gap in this position
-has_no_gaps = [a and b for a, b in zip(has_no_gaps, is_not_gap)]
+has_no_gaps = [a and b for a,
+               b in zip(has_no_gaps, is_not_gap)]
no_gap_map.extend(has_no_gaps)
multichain_mask = np.zeros((len(pdb_map), len(pdb_map)), dtype=int)
for index1, id1 in enumerate(pdb_map):
for index2, id2 in enumerate(pdb_map):
-if (id1[:4] == id2[:4]): # and (no_gap_map[index1] and no_gap_map[index2]):
+# and (no_gap_map[index1] and no_gap_map[index2]):
+if (id1[:4] == id2[:4]):
multichain_mask[index1, index2] = 1
# DEBUG
self.save_binary_matrix(multichain_mask, "multichain_mask.png")
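
The mask marks residue pairs whose templates come from the same PDB entry, matching on the first four characters of the template domain name. A toy sketch with made-up IDs:

import numpy as np

# Per-residue template IDs for a 3-residue and a 2-residue chain (hypothetical).
pdb_map = [b"1abc_A"] * 3 + [b"2xyz_B"] * 2
n = len(pdb_map)
mask = np.zeros((n, n), dtype=int)
for i, id1 in enumerate(pdb_map):
    for j, id2 in enumerate(pdb_map):
        if id1[:4] == id2[:4]:
            mask[i, j] = 1
# mask is block-diagonal: a 3x3 block of ones, then a 2x2 block of ones.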
@@ -533,10 +614,13 @@ def pair_and_merge(self, all_chain_features):
MSA_CROP_SIZE = 2048
feature_processing.process_unmerged_features(all_chain_features)
np_chains_list = list(all_chain_features.values())
-pair_msa_sequences = self.pair_msa and not feature_processing._is_homomer_or_monomer(np_chains_list)
+pair_msa_sequences = self.pair_msa and not feature_processing._is_homomer_or_monomer(
+    np_chains_list)
if pair_msa_sequences:
-np_chains_list = msa_pairing.create_paired_features(chains=np_chains_list)
-np_chains_list = msa_pairing.deduplicate_unpaired_sequences(np_chains_list)
+np_chains_list = msa_pairing.create_paired_features(
+    chains=np_chains_list)
+np_chains_list = msa_pairing.deduplicate_unpaired_sequences(
+    np_chains_list)
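
The guard above pairs MSAs only for heteromers. A toy sketch of that decision, assuming feature_processing._is_homomer_or_monomer simply reports whether all chains share a single unique sequence:

# Assumed semantics: one unique sequence across chains -> homomer/monomer.
pair_msa = True
chain_sequences = ["MSEQA", "MSEQA"]  # a homodimer
is_homomer_or_monomer = len(set(chain_sequences)) == 1
pair_msa_sequences = pair_msa and not is_homomer_or_monomer
print(pair_msa_sequences)  # False: homomers skip MSA pairing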
np_chains_list = feature_processing.crop_chains(
np_chains_list,
msa_crop_size=MSA_CROP_SIZE,
@@ -587,5 +671,6 @@ def create_all_chain_features(self):
self.feature_dict['multichain_mask'] = self.multichain_mask
# save used templates
for i in self.interactors:
-logging.info("Used multimeric templates for protein {}".format(i.description))
+logging.info(
+    "Used multimeric templates for protein {}".format(i.description))
logging.info(i.feature_dict['template_domain_names'])
