diff --git a/alphapulldown/objects.py b/alphapulldown/objects.py
index 30008168..c3fdc277 100644
--- a/alphapulldown/objects.py
+++ b/alphapulldown/objects.py
@@ -17,7 +17,7 @@
 from alphafold.data import feature_processing
 from pathlib import Path as plPath
 from typing import List, Dict
-from colabfold.batch import get_msa_and_templates, msa_to_str, build_monomer_feature
+from colabfold.batch import get_msa_and_templates, msa_to_str, build_monomer_feature, unserialize_msa
 from alphapulldown.utils.multimeric_template_utils import (extract_multimeric_template_features_for_single_chain,
                                                            prepare_multimeric_template_meta_info)
 from alphapulldown.utils.file_handling import temp_fasta_file
@@ -171,22 +171,21 @@ def make_features(
         if using_zipped_msa_files:
             MonomericObject.zip_msa_files(
                 os.path.join(output_dir, self.description))
-
+
+
     def make_mmseq_features(
             self, DEFAULT_API_SERVER,
             output_dir=None,
-            compress_msa_files=False
+            compress_msa_files=False,
+            use_precomputed_msa=False,
     ):
         """
-        A method to use mmseq_remote to calculate msa
-        Modified from ColabFold: https://github.com/sokrypton/ColabFold
+        A method to use mmseq_remote to calculate MSA.
+        Modified from ColabFold to allow reusing precomputed MSAs if available.
         """
-        # first check if there are zipped a3m files
         os.makedirs(output_dir, exist_ok=True)
-        using_zipped_msa_files = MonomericObject.unzip_msa_files(
-            output_dir)
-        logging.info("You chose to calculate MSA with mmseq2.\nPlease also cite: Mirdita M, Schütze K, Moriwaki Y, Heo L, Ovchinnikov S and Steinegger M. ColabFold: Making protein folding accessible to all. Nature Methods (2022) doi: 10.1038/s41592-022-01488-1")
-
+        using_zipped_msa_files = MonomericObject.unzip_msa_files(output_dir)
+
         msa_mode = "mmseqs2_uniref_env"
         keep_existing_results = True
         result_dir = output_dir
@@ -194,65 +193,54 @@ def make_mmseq_features(
         result_zip = os.path.join(result_dir, self.description, ".result.zip")
         if keep_existing_results and plPath(result_zip).is_file():
             logging.info(f"Skipping {self.description} (result.zip)")
 
-
-        (
-            unpaired_msa,
-            paired_msa,
-            query_seqs_unique,
-            query_seqs_cardinality,
-            template_features,
-        ) = get_msa_and_templates(
-            jobname=self.description,
-            query_sequences=self.sequence,
-            a3m_lines=None,
-            result_dir=plPath(result_dir),
-            msa_mode=msa_mode,
-            use_templates=use_templates,
-            custom_template_path=None,
-            pair_mode="none",
-            host_url=DEFAULT_API_SERVER,
-            user_agent='alphapulldown'
-        )
-        msa = msa_to_str(
-            unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality
-        )
-        plPath(os.path.join(result_dir, self.description + ".a3m")
-               ).write_text(msa)
-        a3m_lines = [
-            plPath(os.path.join(result_dir, self.description + ".a3m")).read_text()]
-
-        if compress_msa_files:
-            MonomericObject.zip_msa_files(
-                os.path.join(result_dir, self.description))
-        # unserialize_msa was from colabfold.batch and originally will only create mock template features
-        a3m_lines[0] = "\n".join(
-            [line for line in a3m_lines[0].splitlines() if not line.startswith("#")])
-        self.feature_dict = build_monomer_feature(self.sequence, unpaired_msa[0],
-                                                  template_features[0])
-
-        # update feature_dict with
-        valid_feats = msa_pairing.MSA_FEATURES + (
-            "msa_species_identifiers",
-            "msa_uniprot_accession_identifiers",
-        )
-        feats = {
-            f"{k}_all_seq": v for k, v in self.feature_dict.items() if k in valid_feats
-        }
-        # add template_confidence_scores if it does not exist
-        template_confidence_scores = self.feature_dict.get('template_confidence_scores', None)
-        template_release_date = self.feature_dict.get('template_release_date', None)
-        if template_confidence_scores is None:
-            self.feature_dict.update(
-                {'template_confidence_scores': np.array([[1] * len(self.sequence)])}
+        a3m_path = os.path.join(result_dir, self.description + ".a3m")
+        if use_precomputed_msa and os.path.isfile(a3m_path):
+            logging.info(f"Using precomputed MSA from {a3m_path}")
+            a3m_lines = [plPath(a3m_path).read_text()]
+            (unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality,
+             template_features) = unserialize_msa(a3m_lines, self.sequence)
+        else:
+            logging.info("You chose to calculate MSA with mmseqs2.\nPlease also cite: "
+                         "Mirdita M, Schütze K, Moriwaki Y, Heo L, Ovchinnikov S and Steinegger M. "
+                         "ColabFold: Making protein folding accessible to all. "
+                         "Nature Methods (2022) doi: 10.1038/s41592-022-01488-1")
+            (unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality,
+             template_features) = get_msa_and_templates(
+                jobname=self.description,
+                query_sequences=self.sequence,
+                a3m_lines=None,
+                result_dir=plPath(result_dir),
+                msa_mode=msa_mode,
+                use_templates=use_templates,
+                custom_template_path=None,
+                pair_mode="none",
+                host_url=DEFAULT_API_SERVER,
+                user_agent='alphapulldown'
             )
-        if template_release_date is None:
-            self.feature_dict.update({"template_release_date" : ['none']})
+            msa = msa_to_str(unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality)
+            plPath(a3m_path).write_text(msa)
+            a3m_lines = [plPath(a3m_path).read_text()]
+        if compress_msa_files:
+            MonomericObject.zip_msa_files(os.path.join(result_dir, self.description))
+
+        # Remove header lines starting with '#' if present.
+        a3m_lines[0] = "\n".join([line for line in a3m_lines[0].splitlines() if not line.startswith("#")])
+        self.feature_dict = build_monomer_feature(self.sequence, unpaired_msa[0], template_features[0])
+
+        # MSA_FEATURES was concatenated with a tuple before this patch; coerce so tuple or list both work.
+        valid_feats = tuple(msa_pairing.MSA_FEATURES) + ("msa_species_identifiers", "msa_uniprot_accession_identifiers")
+        feats = {f"{k}_all_seq": v for k, v in self.feature_dict.items() if k in valid_feats}
+
+        # Add default template confidence and release date if missing.
+        if self.feature_dict.get('template_confidence_scores', None) is None:
+            self.feature_dict.update({'template_confidence_scores': np.array([[1] * len(self.sequence)])})
+        if self.feature_dict.get('template_release_date', None) is None:
+            self.feature_dict.update({"template_release_date": ['none']})
         self.feature_dict.update(feats)
 
         if using_zipped_msa_files:
-            MonomericObject.zip_msa_files(
-                output_dir)
+            MonomericObject.zip_msa_files(output_dir)
 
 
 class ChoppedObject(MonomericObject):
diff --git a/alphapulldown/scripts/create_individual_features.py b/alphapulldown/scripts/create_individual_features.py
index a8159d70..5343f022 100644
--- a/alphapulldown/scripts/create_individual_features.py
+++ b/alphapulldown/scripts/create_individual_features.py
@@ -305,22 +305,11 @@ def create_and_save_monomer_objects(monomer, pipeline):
 
     # Create features
     if FLAGS.use_mmseqs2:
-        if FLAGS.use_precomputed_msas:
-            logging.info("Using precomputed MSAs for mmseqs2: loading from disk.")
-            # Ensure a pipeline is available (create one if necessary)
-            if pipeline is None:
-                pipeline = create_pipeline()
-            monomer.make_features(
-                pipeline=pipeline,
-                output_dir=FLAGS.output_dir,
-                use_precomputed_msa=True,
-                save_msa=FLAGS.save_msa_files,
-            )
-        else:
-            logging.info("Running MMseqs2 for feature generation...")
-            monomer.make_mmseq_features(
-                DEFAULT_API_SERVER=DEFAULT_API_SERVER,
-                output_dir=FLAGS.output_dir
-            )
+        logging.info("Running MMseqs2 for feature generation...")
+        monomer.make_mmseq_features(
+            DEFAULT_API_SERVER=DEFAULT_API_SERVER,
+            output_dir=FLAGS.output_dir,
+            use_precomputed_msa=FLAGS.use_precomputed_msas,
+        )
     else:
         monomer.make_features(
diff --git a/test/test_mmseqs.py b/test/test_mmseqs.py
new file mode 100644
index 00000000..9b9a253c
--- /dev/null
+++ b/test/test_mmseqs.py
@@ -0,0 +1,129 @@
+import os
+import tempfile
+import numpy as np
+
+from absl import logging
+from absl.testing import absltest
+
+# Import the class to test. Adjust the module path as needed.
+from alphapulldown.objects import MonomericObject
+
+# Dummy implementations for the functions used by make_mmseq_features.
+def fake_build_monomer_feature(sequence, msa, template_features):
+    # Return a dummy dictionary that will be later updated.
+    return {"dummy_feature": 42,
+            "template_confidence_scores": None,
+            "template_release_date": None}
+
+# Dummy msa_pairing.MSA_FEATURES; a tuple, like the constant it stands in for.
+class FakeMSAPairing:
+    MSA_FEATURES = ("dummy_feature",)
+
+# Save originals to restore later.
+_original_build_monomer_feature = None
+_original_get_msa_and_templates = None
+_original_unserialize_msa = None
+
+def fake_get_msa_and_templates(jobname, query_sequences, a3m_lines, result_dir,
+                               msa_mode, use_templates, custom_template_path,
+                               pair_mode, host_url, user_agent):
+    # Return fake tuple values.
+    fake_unpaired = ["FAKE_UNPAIRED"]
+    fake_paired = ["FAKE_PAIRED"]
+    fake_unique = ["FAKE_UNIQUE"]
+    fake_card = ["FAKE_CARDINALITY"]
+    fake_template = ["FAKE_TEMPLATE"]
+    return (fake_unpaired, fake_paired, fake_unique, fake_card, fake_template)
+
+def fake_unserialize_msa(a3m_lines, sequence):
+    # Return fake tuple values based solely on the precomputed file.
+    fake_unpaired = ["PRECOMPUTED_UNPAIRED"]
+    fake_paired = ["PRECOMPUTED_PAIRED"]
+    fake_unique = ["PRECOMPUTED_UNIQUE"]
+    fake_card = ["PRECOMPUTED_CARDINALITY"]
+    fake_template = ["PRECOMPUTED_TEMPLATE"]
+    return (fake_unpaired, fake_paired, fake_unique, fake_card, fake_template)
+
+class MmseqFeaturesTest(absltest.TestCase):
+
+    def setUp(self):
+        super(MmseqFeaturesTest, self).setUp()
+        # Create a dummy MonomericObject with a known description and sequence.
+        self.monomer = MonomericObject("dummy", "ACDE")
+        # Create a temporary output directory.
+        self.temp_dir = tempfile.TemporaryDirectory()
+        self.output_dir = self.temp_dir.name
+
+        # Monkey-patch the functions used inside make_mmseq_features.
+        import alphapulldown.objects as objects_mod
+        self._original_build_monomer_feature = objects_mod.build_monomer_feature
+        self._original_get_msa_and_templates = objects_mod.get_msa_and_templates
+        self._original_unserialize_msa = objects_mod.unserialize_msa
+        self._original_msa_features = objects_mod.msa_pairing.MSA_FEATURES
+
+        objects_mod.build_monomer_feature = fake_build_monomer_feature
+        objects_mod.get_msa_and_templates = fake_get_msa_and_templates
+        objects_mod.unserialize_msa = fake_unserialize_msa
+
+        # Override msa_pairing.MSA_FEATURES in the module where make_mmseq_features uses it.
+        objects_mod.msa_pairing.MSA_FEATURES = FakeMSAPairing.MSA_FEATURES
+
+    def tearDown(self):
+        # Restore originals, including MSA_FEATURES, so other tests see the real values.
+        import alphapulldown.objects as objects_mod
+        objects_mod.build_monomer_feature = self._original_build_monomer_feature
+        objects_mod.get_msa_and_templates = self._original_get_msa_and_templates
+        objects_mod.unserialize_msa = self._original_unserialize_msa
+        objects_mod.msa_pairing.MSA_FEATURES = self._original_msa_features
+        self.temp_dir.cleanup()
+        super(MmseqFeaturesTest, self).tearDown()
+
+    def test_use_precomputed_msa(self):
+        """Test that if a precomputed MSA exists and use_precomputed_msa is True,
+        the branch using unserialize_msa is taken."""
+        # Create a dummy precomputed a3m file.
+        a3m_path = os.path.join(self.output_dir, self.monomer.description + ".a3m")
+        precomputed_content = ">dummy\nPRECOMPUTED_CONTENT\n"
+        with open(a3m_path, "w") as f:
+            f.write(precomputed_content)
+
+        # Call the method with use_precomputed_msa=True.
+        self.monomer.make_mmseq_features(
+            DEFAULT_API_SERVER="http://fake.api",
+            output_dir=self.output_dir,
+            use_precomputed_msa=True
+        )
+        # Our fake_unserialize_msa returns fake values that we check:
+        self.assertEqual(self.monomer.feature_dict["dummy_feature"], 42)
+        # Check that template_confidence_scores and template_release_date got set:
+        self.assertTrue(isinstance(self.monomer.feature_dict["template_confidence_scores"], np.ndarray))
+        self.assertEqual(self.monomer.feature_dict["template_release_date"], ['none'])
+
+    def test_api_generation(self):
+        """Test that if no precomputed MSA exists (or use_precomputed_msa is False),
+        the API branch is taken and a new a3m file is created."""
+        a3m_path = os.path.join(self.output_dir, self.monomer.description + ".a3m")
+        # Ensure the file does not exist.
+        if os.path.exists(a3m_path):
+            os.remove(a3m_path)
+        # Call the method with use_precomputed_msa=False.
+        self.monomer.make_mmseq_features(
+            DEFAULT_API_SERVER="http://fake.api",
+            output_dir=self.output_dir,
+            use_precomputed_msa=False
+        )
+        # The fake_get_msa_and_templates returns known fake values.
+        self.assertEqual(self.monomer.feature_dict["dummy_feature"], 42)
+        # The a3m file should now exist.
+        self.assertTrue(os.path.isfile(a3m_path))
+        # Check that the file contains our dummy content from fake_get_msa_and_templates branch.
+        with open(a3m_path) as f:
+            msa_content = f.read()
+        self.assertIn("FAKE_UNPAIRED", msa_content)
+        # Check that default template values were added.
+        self.assertTrue(isinstance(self.monomer.feature_dict["template_confidence_scores"], np.ndarray))
+        self.assertEqual(self.monomer.feature_dict["template_release_date"], ['none'])
+
+
+if __name__ == '__main__':
+    absltest.main()