Skip to content

Commit

Permalink
Fix #486
Browse files Browse the repository at this point in the history
  • Loading branch information
DimaMolod committed Feb 11, 2025
1 parent 1e08bad commit b84b7da
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 77 deletions.
116 changes: 52 additions & 64 deletions alphapulldown/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from alphafold.data import feature_processing
from pathlib import Path as plPath
from typing import List, Dict
from colabfold.batch import get_msa_and_templates, msa_to_str, build_monomer_feature
from colabfold.batch import get_msa_and_templates, msa_to_str, build_monomer_feature, unserialize_msa
from alphapulldown.utils.multimeric_template_utils import (extract_multimeric_template_features_for_single_chain,
prepare_multimeric_template_meta_info)
from alphapulldown.utils.file_handling import temp_fasta_file
Expand Down Expand Up @@ -171,88 +171,76 @@ def make_features(
if using_zipped_msa_files:
MonomericObject.zip_msa_files(
os.path.join(output_dir, self.description))



def make_mmseq_features(
self, DEFAULT_API_SERVER,
output_dir=None,
compress_msa_files=False
compress_msa_files=False,
use_precomputed_msa=False,
):
"""
A method to use mmseq_remote to calculate msa
Modified from ColabFold: https://github.com/sokrypton/ColabFold
A method to use mmseq_remote to calculate MSA.
Modified from ColabFold to allow reusing precomputed MSAs if available.
"""
# first check if there are zipped a3m files
os.makedirs(output_dir, exist_ok=True)
using_zipped_msa_files = MonomericObject.unzip_msa_files(
output_dir)
logging.info("You chose to calculate MSA with mmseq2.\nPlease also cite: Mirdita M, Schütze K, Moriwaki Y, Heo L, Ovchinnikov S and Steinegger M. ColabFold: Making protein folding accessible to all. Nature Methods (2022) doi: 10.1038/s41592-022-01488-1")

using_zipped_msa_files = MonomericObject.unzip_msa_files(output_dir)

msa_mode = "mmseqs2_uniref_env"
keep_existing_results = True
result_dir = output_dir
use_templates = True
result_zip = os.path.join(result_dir, self.description, ".result.zip")
if keep_existing_results and plPath(result_zip).is_file():
logging.info(f"Skipping {self.description} (result.zip)")

(
unpaired_msa,
paired_msa,
query_seqs_unique,
query_seqs_cardinality,
template_features,
) = get_msa_and_templates(
jobname=self.description,
query_sequences=self.sequence,
a3m_lines=None,
result_dir=plPath(result_dir),
msa_mode=msa_mode,
use_templates=use_templates,
custom_template_path=None,
pair_mode="none",
host_url=DEFAULT_API_SERVER,
user_agent='alphapulldown'
)
msa = msa_to_str(
unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality
)
plPath(os.path.join(result_dir, self.description + ".a3m")
).write_text(msa)
a3m_lines = [
plPath(os.path.join(result_dir, self.description + ".a3m")).read_text()]

if compress_msa_files:
MonomericObject.zip_msa_files(
os.path.join(result_dir, self.description))
# unserialize_msa was from colabfold.batch and originally will only create mock template features
a3m_lines[0] = "\n".join(
[line for line in a3m_lines[0].splitlines() if not line.startswith("#")])
self.feature_dict = build_monomer_feature(self.sequence, unpaired_msa[0],
template_features[0])

# update feature_dict with
valid_feats = msa_pairing.MSA_FEATURES + (
"msa_species_identifiers",
"msa_uniprot_accession_identifiers",
)
feats = {
f"{k}_all_seq": v for k, v in self.feature_dict.items() if k in valid_feats
}

# add template_confidence_scores if it does not exist
template_confidence_scores = self.feature_dict.get('template_confidence_scores', None)
template_release_date = self.feature_dict.get('template_release_date', None)
if template_confidence_scores is None:
self.feature_dict.update(
{'template_confidence_scores': np.array([[1] * len(self.sequence)])}
a3m_path = os.path.join(result_dir, self.description + ".a3m")
if use_precomputed_msa and os.path.isfile(a3m_path):
logging.info(f"Using precomputed MSA from {a3m_path}")
a3m_lines = [plPath(a3m_path).read_text()]
(unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality,
template_features) = unserialize_msa(a3m_lines, self.sequence)
else:
logging.info("You chose to calculate MSA with mmseqs2.\nPlease also cite: "
"Mirdita M, Schütze K, Moriwaki Y, Heo L, Ovchinnikov S and Steinegger M. "
"ColabFold: Making protein folding accessible to all. "
"Nature Methods (2022) doi: 10.1038/s41592-022-01488-1")
(unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality,
template_features) = get_msa_and_templates(
jobname=self.description,
query_sequences=self.sequence,
a3m_lines=None,
result_dir=plPath(result_dir),
msa_mode=msa_mode,
use_templates=use_templates,
custom_template_path=None,
pair_mode="none",
host_url=DEFAULT_API_SERVER,
user_agent='alphapulldown'
)
if template_release_date is None:
self.feature_dict.update({"template_release_date" : ['none']})
msa = msa_to_str(unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality)
plPath(a3m_path).write_text(msa)
a3m_lines = [plPath(a3m_path).read_text()]
if compress_msa_files:
MonomericObject.zip_msa_files(os.path.join(result_dir, self.description))

# Remove header lines starting with '#' if present.
a3m_lines[0] = "\n".join([line for line in a3m_lines[0].splitlines() if not line.startswith("#")])
self.feature_dict = build_monomer_feature(self.sequence, unpaired_msa[0], template_features[0])

# Fix: Change tuple to list so that we can concatenate with msa_pairing.MSA_FEATURES.
valid_feats = msa_pairing.MSA_FEATURES + ["msa_species_identifiers", "msa_uniprot_accession_identifiers"]
feats = {f"{k}_all_seq": v for k, v in self.feature_dict.items() if k in valid_feats}

# Add default template confidence and release date if missing.
if self.feature_dict.get('template_confidence_scores', None) is None:
self.feature_dict.update({'template_confidence_scores': np.array([[1] * len(self.sequence)])})
if self.feature_dict.get('template_release_date', None) is None:
self.feature_dict.update({"template_release_date": ['none']})
self.feature_dict.update(feats)

if using_zipped_msa_files:
MonomericObject.zip_msa_files(
output_dir)
MonomericObject.zip_msa_files(output_dir)


class ChoppedObject(MonomericObject):
Expand Down
15 changes: 2 additions & 13 deletions alphapulldown/scripts/create_individual_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,22 +305,11 @@ def create_and_save_monomer_objects(monomer, pipeline):

# Create features
if FLAGS.use_mmseqs2:
if FLAGS.use_precomputed_msas:
logging.info("Using precomputed MSAs for mmseqs2: loading from disk.")
# Ensure a pipeline is available (create one if necessary)
if pipeline is None:
pipeline = create_pipeline()
monomer.make_features(
pipeline=pipeline,
output_dir=FLAGS.output_dir,
use_precomputed_msa=True,
save_msa=FLAGS.save_msa_files,
)
else:
logging.info("Running MMseqs2 for feature generation...")
monomer.make_mmseq_features(
DEFAULT_API_SERVER=DEFAULT_API_SERVER,
output_dir=FLAGS.output_dir
output_dir=FLAGS.output_dir,
use_precomputed_msa=FLAGS.use_precomputed_msas,
)
else:
monomer.make_features(
Expand Down
127 changes: 127 additions & 0 deletions test/test_mmseqs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import os
import tempfile
import numpy as np

from absl import logging
from absl.testing import absltest

# Import the class to test. Adjust the module path as needed.
from alphapulldown.objects import MonomericObject

# Dummy implementations for the functions used by make_mmseq_features.
def fake_build_monomer_feature(sequence, msa, template_features):
# Return a dummy dictionary that will be later updated.
return {"dummy_feature": 42,
"template_confidence_scores": None,
"template_release_date": None}

# Dummy msa_pairing.MSA_FEATURES list.
class FakeMSAPairing:
MSA_FEATURES = ["dummy_feature"]

# Save originals to restore later.
_original_build_monomer_feature = None
_original_get_msa_and_templates = None
_original_unserialize_msa = None

def fake_get_msa_and_templates(jobname, query_sequences, a3m_lines, result_dir,
msa_mode, use_templates, custom_template_path,
pair_mode, host_url, user_agent):
# Return fake tuple values.
fake_unpaired = ["FAKE_UNPAIRED"]
fake_paired = ["FAKE_PAIRED"]
fake_unique = ["FAKE_UNIQUE"]
fake_card = ["FAKE_CARDINALITY"]
fake_template = ["FAKE_TEMPLATE"]
return (fake_unpaired, fake_paired, fake_unique, fake_card, fake_template)

def fake_unserialize_msa(a3m_lines, sequence):
# Return fake tuple values based solely on the precomputed file.
fake_unpaired = ["PRECOMPUTED_UNPAIRED"]
fake_paired = ["PRECOMPUTED_PAIRED"]
fake_unique = ["PRECOMPUTED_UNIQUE"]
fake_card = ["PRECOMPUTED_CARDINALITY"]
fake_template = ["PRECOMPUTED_TEMPLATE"]
return (fake_unpaired, fake_paired, fake_unique, fake_card, fake_template)

class MmseqFeaturesTest(absltest.TestCase):

def setUp(self):
super(MmseqFeaturesTest, self).setUp()
# Create a dummy MonomericObject with a known description and sequence.
self.monomer = MonomericObject("dummy", "ACDE")
# Create a temporary output directory.
self.temp_dir = tempfile.TemporaryDirectory()
self.output_dir = self.temp_dir.name

# Monkey-patch the functions used inside make_mmseq_features.
import alphapulldown.objects as objects_mod
self._original_build_monomer_feature = objects_mod.build_monomer_feature
self._original_get_msa_and_templates = objects_mod.get_msa_and_templates
self._original_unserialize_msa = objects_mod.unserialize_msa

objects_mod.build_monomer_feature = fake_build_monomer_feature
objects_mod.get_msa_and_templates = fake_get_msa_and_templates
objects_mod.unserialize_msa = fake_unserialize_msa

# Override msa_pairing.MSA_FEATURES in the module where make_mmseq_features uses it.
objects_mod.msa_pairing.MSA_FEATURES = FakeMSAPairing.MSA_FEATURES

def tearDown(self):
# Restore originals.
import alphapulldown.objects as objects_mod
objects_mod.build_monomer_feature = self._original_build_monomer_feature
objects_mod.get_msa_and_templates = self._original_get_msa_and_templates
objects_mod.unserialize_msa = self._original_unserialize_msa
self.temp_dir.cleanup()
super(MmseqFeaturesTest, self).tearDown()

def test_use_precomputed_msa(self):
"""Test that if a precomputed MSA exists and use_precomputed_msa is True,
the branch using unserialize_msa is taken."""
# Create a dummy precomputed a3m file.
a3m_path = os.path.join(self.output_dir, self.monomer.description + ".a3m")
precomputed_content = ">dummy\nPRECOMPUTED_CONTENT\n"
with open(a3m_path, "w") as f:
f.write(precomputed_content)

# Call the method with use_precomputed_msa=True.
self.monomer.make_mmseq_features(
DEFAULT_API_SERVER="http://fake.api",
output_dir=self.output_dir,
use_precomputed_msa=True
)
# Our fake_unserialize_msa returns fake values that we check:
self.assertEqual(self.monomer.feature_dict["dummy_feature"], 42)
# Check that template_confidence_scores and template_release_date got set:
self.assertTrue(isinstance(self.monomer.feature_dict["template_confidence_scores"], np.ndarray))
self.assertEqual(self.monomer.feature_dict["template_release_date"], ['none'])

def test_api_generation(self):
"""Test that if no precomputed MSA exists (or use_precomputed_msa is False),
the API branch is taken and a new a3m file is created."""
a3m_path = os.path.join(self.output_dir, self.monomer.description + ".a3m")
# Ensure the file does not exist.
if os.path.exists(a3m_path):
os.remove(a3m_path)
# Call the method with use_precomputed_msa=False.
self.monomer.make_mmseq_features(
DEFAULT_API_SERVER="http://fake.api",
output_dir=self.output_dir,
use_precomputed_msa=False
)
# The fake_get_msa_and_templates returns known fake values.
self.assertEqual(self.monomer.feature_dict["dummy_feature"], 42)
# The a3m file should now exist.
self.assertTrue(os.path.isfile(a3m_path))
# Check that the file contains our dummy content from fake_get_msa_and_templates branch.
with open(a3m_path) as f:
msa_content = f.read()
self.assertIn("FAKE_UNPAIRED", msa_content)
# Check that default template values were added.
self.assertTrue(isinstance(self.monomer.feature_dict["template_confidence_scores"], np.ndarray))
self.assertEqual(self.monomer.feature_dict["template_release_date"], ['none'])


if __name__ == '__main__':
absltest.main()

0 comments on commit b84b7da

Please sign in to comment.