Skip to content

Commit

Permalink
add options to zip msa files
Browse files Browse the repository at this point in the history
  • Loading branch information
dingquanyu committed Mar 1, 2024
1 parent c32d1e2 commit 0615e2e
Showing 1 changed file with 95 additions and 43 deletions.
138 changes: 95 additions & 43 deletions alphapulldown/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import tempfile
import os
from zipfile import ZipFile
import contextlib
import numpy as np
from alphafold.data import parsers
Expand All @@ -13,7 +14,9 @@
from alphafold.data import msa_pairing
from alphafold.data import feature_processing
from pathlib import Path as plPath
from colabfold.batch import unserialize_msa, get_msa_and_templates, msa_to_str, build_monomer_feature
from colabfold.batch import (unserialize_msa,
get_msa_and_templates,
msa_to_str, build_monomer_feature)


@contextlib.contextmanager
Expand Down Expand Up @@ -50,13 +53,32 @@ def uniprot_runner(self):
def uniprot_runner(self, uniprot_runner):
self._uniprot_runner = uniprot_runner

@staticmethod
def zip_msa_files(msa_output_path: str):
"""
A static method that zip individual msa files within the given msa_output_path folder
"""
def zip_individual_file(msa_file: plPath):
assert os.path.exists(msa_file)
with ZipFile(os.path.join(msa_file.parent, msa_files.name+".zip"), "w") as myzip:
myzip.write(msa_file)
myzip.close()
os.remove(msa_file)

msa_file_endings = ['.a3m', '.fasta', '.sto', '.hmm']
msa_files = [i for i in plPath(
msa_output_path).iterdir() if i.suffix in msa_file_endings]
if len(msa_files) > 0:
for msa_file in msa_files:
zip_individual_file(msa_file)

def all_seq_msa_features(
self,
input_fasta_path,
uniprot_msa_runner,
save_msa,
output_dir=None,
use_precomuted_msa=False,
use_precomuted_msa=False
):
"""Get MSA features for unclustered uniprot, for pairing later on."""
if not use_precomuted_msa:
Expand Down Expand Up @@ -102,7 +124,8 @@ def all_seq_msa_features(
return feats

def make_features(
self, pipeline, output_dir=None, use_precomputed_msa=False, save_msa=True
self, pipeline, output_dir=None,
use_precomputed_msa=False, save_msa=True, compress_msa_files=False
):
"""a method that make msa and template features"""
if not use_precomputed_msa:
Expand All @@ -125,14 +148,19 @@ def make_features(
"""this means no precomputed msa available and will save output msa files"""
msa_output_dir = os.path.join(output_dir, self.description)
sequence_str = f">{self.description}\n{self.sequence}"
logging.info("will save msa files in :{}".format(msa_output_dir))
logging.info(
"will save msa files in :{}".format(msa_output_dir))
plPath(msa_output_dir).mkdir(parents=True, exist_ok=True)
with temp_fasta_file(sequence_str) as fasta_file:
self.feature_dict = pipeline.process(fasta_file, msa_output_dir)
self.feature_dict = pipeline.process(
fasta_file, msa_output_dir)
pairing_results = self.all_seq_msa_features(
fasta_file, self._uniprot_runner, save_msa, msa_output_dir
)
self.feature_dict.update(pairing_results)

if compress_msa_files:
MonomericObject.zip_msa_files(msa_output_dir)
else:
"""This means precomputed msa files are available"""
msa_output_dir = os.path.join(output_dir, self.description)
Expand All @@ -144,7 +172,8 @@ def make_features(
)
sequence_str = f">{self.description}\n{self.sequence}"
with temp_fasta_file(sequence_str) as fasta_file:
self.feature_dict = pipeline.process(fasta_file, msa_output_dir)
self.feature_dict = pipeline.process(
fasta_file, msa_output_dir)
pairing_results = self.all_seq_msa_features(
fasta_file,
self._uniprot_runner,
Expand All @@ -154,7 +183,7 @@ def make_features(
)
self.feature_dict.update(pairing_results)

def mk_template(self, a3m_lines,
def mk_template(self, a3m_lines,
pipeline, query_sequence):
"""
Overwrite ColabFold's original mk_template to incorporate max_template data argument
Expand All @@ -169,15 +198,17 @@ def mk_template(self, a3m_lines,
hmm_build_runner = pipeline.template_searcher.hmmbuild_runner
hmm_profile = hmm_build_runner.build_profile_from_a3m(a3m_lines)
query_result = pipeline.template_searcher.query_with_hmm(hmm_profile)
template_hits = pipeline.template_searcher.get_template_hits(query_result,query_sequence)
template_hits = pipeline.template_searcher.get_template_hits(
query_result, query_sequence)
templates_result = template_featuriser.get_templates(
query_sequence=query_sequence, hits=template_hits
)
return dict(templates_result.features)

def make_mmseq_features(
self, DEFAULT_API_SERVER,
self, DEFAULT_API_SERVER,
pipeline=None, output_dir=None,
compress_msa_files=False
):
"""
A method to use mmseq_remote to calculate msa
Expand All @@ -193,12 +224,15 @@ def make_mmseq_features(
if keep_existing_results and plPath(result_zip).is_file():
logging.info(f"Skipping {self.description} (result.zip)")

logging.info(f"looking for possible precomputed a3m at {os.path.join(result_dir, self.description + '.a3m')}")
logging.info(
f"looking for possible precomputed a3m at {os.path.join(result_dir, self.description + '.a3m')}")
try:
logging.info(f"input is {os.path.join(result_dir, self.description + '.a3m')}")
logging.info(
f"input is {os.path.join(result_dir, self.description + '.a3m')}")
input_path = os.path.join(result_dir, self.description + '.a3m')
a3m_lines = [plPath(input_path).read_text()]
logging.info(f"Finished parsing the precalculated a3m_file\nNow will search for template")
logging.info(
f"Finished parsing the precalculated a3m_file\nNow will search for template")
except:
a3m_lines = None

Expand All @@ -219,9 +253,9 @@ def make_mmseq_features(
query_seqs_cardinality,
template_features,
) = get_msa_and_templates(
jobname = self.description,
jobname=self.description,
query_sequences=self.sequence,
a3m_lines = None,
a3m_lines=None,
result_dir=plPath(result_dir),
msa_mode=msa_mode,
use_templates=use_templates,
Expand All @@ -233,15 +267,21 @@ def make_mmseq_features(
msa = msa_to_str(
unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality
)
plPath(os.path.join(result_dir, self.description + ".a3m")).write_text(msa)
a3m_lines = [plPath(os.path.join(result_dir, self.description + ".a3m")).read_text()]
plPath(os.path.join(result_dir, self.description + ".a3m")
).write_text(msa)
a3m_lines = [
plPath(os.path.join(result_dir, self.description + ".a3m")).read_text()]

if compress_msa_files:
MonomericObject.zip_msa_files(os.path.join(result_dir, self.description))
# unserialize_msa was from colabfold.batch and originally will only create mock template features
# below will search against pdb70 database using hhsearch and create real template features
logging.info("will search for templates in local template database")
a3m_lines[0] = "\n".join([line for line in a3m_lines[0].splitlines() if not line.startswith("#")])
a3m_lines[0] = "\n".join(
[line for line in a3m_lines[0].splitlines() if not line.startswith("#")])
template_features = self.mk_template(a3m_lines[0],
pipeline, query_sequence=self.sequence)
self.feature_dict = build_monomer_feature(self.sequence, unpaired_msa[0],
pipeline, query_sequence=self.sequence)
self.feature_dict = build_monomer_feature(self.sequence, unpaired_msa[0],
template_features)

# update feature_dict with
Expand Down Expand Up @@ -284,15 +324,17 @@ def prepare_new_msa_feature(self, msa_feature, start_point, end_point):
new_seq_length = np.array([length] * length)
new_aa_type = msa_feature["aatype"][start_point:end_point, :]
new_between_segment_residue = msa_feature["between_segment_residues"][
start_point:end_point
]
start_point:end_point
]
new_domain_name = msa_feature["domain_name"]
new_residue_index = msa_feature["residue_index"][start_point:end_point]
new_sequence = np.array([msa_feature["sequence"][0][start_point:end_point]])
new_deletion_mtx = msa_feature["deletion_matrix_int"][:, start_point:end_point]
new_sequence = np.array(
[msa_feature["sequence"][0][start_point:end_point]])
new_deletion_mtx = msa_feature["deletion_matrix_int"][:,
start_point:end_point]
new_deletion_mtx_all_seq = msa_feature["deletion_matrix_int_all_seq"][
:, start_point:end_point
]
:, start_point:end_point
]
new_msa = msa_feature["msa"][:, start_point:end_point]
new_msa_all_seq = msa_feature["msa_all_seq"][:, start_point:end_point]
new_num_alignments = np.array([msa_feature["msa"].shape[0]] * length)
Expand Down Expand Up @@ -328,14 +370,14 @@ def prepare_new_template_feature_dict(
"""
start_point = start_point - 1
new_template_aatype = template_feature["template_aatype"][
:, start_point:end_point, :
]
:, start_point:end_point, :
]
new_template_all_atom_masks = template_feature["template_all_atom_masks"][
:, start_point:end_point, :
]
:, start_point:end_point, :
]
new_template_all_atom_positions = template_feature[
"template_all_atom_positions"
][:, start_point:end_point, :, :]
"template_all_atom_positions"
][:, start_point:end_point, :, :]
new_template_domain_names = template_feature["template_domain_names"]
new_template_sequence = template_feature["template_sequence"]
new_template_sum_probs = template_feature["template_sum_probs"]
Expand Down Expand Up @@ -470,10 +512,11 @@ def create_chain_id_map(self):
multimer_sequence_str = ""
for interactor in self.interactors:
multimer_sequence_str = (
multimer_sequence_str
+ f">{interactor.description}\n{interactor.sequence}\n"
multimer_sequence_str
+ f">{interactor.description}\n{interactor.sequence}\n"
)
self.input_seqs, input_descs = parsers.parse_fasta(multimer_sequence_str)
self.input_seqs, input_descs = parsers.parse_fasta(
multimer_sequence_str)
self.chain_id_map = pipeline_multimer._make_chain_id_map(
sequences=self.input_seqs, descriptions=input_descs
)
Expand All @@ -500,7 +543,8 @@ def save_binary_matrix(self, matrix, file_path):
text_width, text_height = draw.textsize(text, font=font)
x = (col + 0.5) * image.width / width - text_width / 2
y = image.height - text_height
draw.text((x, y), text, font=font, fill=(0, 0, 0)) # Set text fill color to black
# Set text fill color to black
draw.text((x, y), text, font=font, fill=(0, 0, 0))

image.save(file_path)

Expand All @@ -510,18 +554,22 @@ def create_multichain_mask(self):
no_gap_map = []
for interactor in self.interactors:
temp_length = len(interactor.sequence)
pdb_map.extend([interactor.feature_dict['template_domain_names'][0]] * temp_length)
pdb_map.extend(
[interactor.feature_dict['template_domain_names'][0]] * temp_length)
has_no_gaps = [True] * temp_length
# for each template in the interactor, check for gaps in sequence
for template_sequence in interactor.feature_dict['template_sequence']:
is_not_gap = [s != '-' for s in template_sequence.decode("utf-8").strip()]
is_not_gap = [
s != '-' for s in template_sequence.decode("utf-8").strip()]
# False if any of the templates has a gap in this position
has_no_gaps = [a and b for a, b in zip(has_no_gaps, is_not_gap)]
has_no_gaps = [a and b for a,
b in zip(has_no_gaps, is_not_gap)]
no_gap_map.extend(has_no_gaps)
multichain_mask = np.zeros((len(pdb_map), len(pdb_map)), dtype=int)
for index1, id1 in enumerate(pdb_map):
for index2, id2 in enumerate(pdb_map):
if (id1[:4] == id2[:4]): # and (no_gap_map[index1] and no_gap_map[index2]):
# and (no_gap_map[index1] and no_gap_map[index2]):
if (id1[:4] == id2[:4]):
multichain_mask[index1, index2] = 1
# DEBUG
self.save_binary_matrix(multichain_mask, "multichain_mask.png")
Expand All @@ -533,10 +581,13 @@ def pair_and_merge(self, all_chain_features):
MSA_CROP_SIZE = 2048
feature_processing.process_unmerged_features(all_chain_features)
np_chains_list = list(all_chain_features.values())
pair_msa_sequences = self.pair_msa and not feature_processing._is_homomer_or_monomer(np_chains_list)
pair_msa_sequences = self.pair_msa and not feature_processing._is_homomer_or_monomer(
np_chains_list)
if pair_msa_sequences:
np_chains_list = msa_pairing.create_paired_features(chains=np_chains_list)
np_chains_list = msa_pairing.deduplicate_unpaired_sequences(np_chains_list)
np_chains_list = msa_pairing.create_paired_features(
chains=np_chains_list)
np_chains_list = msa_pairing.deduplicate_unpaired_sequences(
np_chains_list)
np_chains_list = feature_processing.crop_chains(
np_chains_list,
msa_crop_size=MSA_CROP_SIZE,
Expand Down Expand Up @@ -587,5 +638,6 @@ def create_all_chain_features(self):
self.feature_dict['multichain_mask'] = self.multichain_mask
# save used templates
for i in self.interactors:
logging.info("Used multimeric templates for protein {}".format(i.description))
logging.info(
"Used multimeric templates for protein {}".format(i.description))
logging.info(i.feature_dict['template_domain_names'])

0 comments on commit 0615e2e

Please sign in to comment.