class FoldingBackendManager:
    """
    Manager for structure prediction backends.

    Attribute access that is not resolved on the manager itself is forwarded
    to the currently active backend, so ``manager.predict(...)`` calls the
    backend's ``predict``.

    Attributes
    ----------
    _BACKEND_REGISTRY : dict
        A dictionary mapping backend names to their respective classes.
    _backend : object
        An instance of the currently active backend. Defaults to AlphaFold.
    _backend_name : str
        Name of the current backend.
    _backend_args : Dict
        Arguments passed to create current backend.
    """

    def __init__(self):
        self._BACKEND_REGISTRY = {
            "alphafold": AlphaFold,
        }
        self._backend = AlphaFold()
        self._backend_name = "alphafold"
        self._backend_args = {}

    def __repr__(self):
        return (
            f"{type(self).__name__}(backend={self._backend_name!r}, "
            f"args={self._backend_args!r})"
        )

    def __getattr__(self, name):
        # __getattr__ only runs after normal lookup failed. If ``_backend``
        # itself is missing (instance built without __init__, e.g. during
        # copy or unpickling), delegating would re-enter this method forever;
        # fail with a regular AttributeError instead.
        if name == "_backend":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self._backend, name)

    def __dir__(self) -> List:
        """
        Return a list of attributes available in this object,
        including those from the backend.

        Returns
        -------
        list
            Sorted list of unique attribute names.
        """
        attributes = set(dir(type(self)))
        attributes.update(self.__dict__)
        attributes.update(dir(self._backend))
        return sorted(attributes)

    def change_backend(self, backend_name: str, **backend_kwargs: Dict) -> None:
        """
        Change the backend.

        The manager state is only updated once the new backend has been
        constructed successfully.

        Parameters
        ----------
        backend_name : str
            Name of the new backend that should be used.
        **backend_kwargs : Dict, optional
            Parameters passed to __init__ method of backend.

        Raises
        ------
        NotImplementedError
            If no backend is found with the provided name.
        """
        if backend_name not in self._BACKEND_REGISTRY:
            available_backends = ", ".join(str(x) for x in self._BACKEND_REGISTRY)
            raise NotImplementedError(
                f"Available backends are {available_backends} - not {backend_name}."
            )
        self._backend = self._BACKEND_REGISTRY[backend_name](**backend_kwargs)
        self._backend_name = backend_name
        self._backend_args = backend_kwargs
+ """ + if backend_name not in self._BACKEND_REGISTRY: + available_backends = ", ".join( + [str(x) for x in self._BACKEND_REGISTRY.keys()] + ) + raise NotImplementedError( + f"Available backends are {available_backends} - not {backend_name}." + ) + self._backend = self._BACKEND_REGISTRY[backend_name](**backend_kwargs) + self._backend_name = backend_name + self._backend_args = backend_kwargs + + +backend = FoldingBackendManager() diff --git a/alphapulldown/folding_backend/alphafold_backend.py b/alphapulldown/folding_backend/alphafold_backend.py new file mode 100644 index 00000000..fda47ef7 --- /dev/null +++ b/alphapulldown/folding_backend/alphafold_backend.py @@ -0,0 +1,296 @@ +""" Implements structure prediction backend using AlphaFold. + + Copyright (c) 2024 European Molecular Biology Laboratory + + Author: Valentin Maurer +""" + +import time +import json +import pickle + +import numpy as np +from os.path import join, exists +from typing import List, Dict + +import jax.numpy as jnp +from alphapulldown.predict_structure import get_existing_model_info +from alphapulldown.objects import MultimericObject +from alphapulldown.utils import ( + create_and_save_pae_plots, + post_prediction_process, +) +# Avoid module not found error by importing after AP +import run_alphafold +from run_alphafold import ModelsToRelax +from alphafold.relax import relax +from alphafold.common import protein, residue_constants + +from .folding_backend import FoldingBackend + + +def _jnp_to_np(output): + """Recursively changes jax arrays to numpy arrays.""" + for k, v in output.items(): + if isinstance(v, dict): + output[k] = _jnp_to_np(v) + elif isinstance(v, jnp.ndarray): + output[k] = np.array(v) + return output + + +class AlphaFold(FoldingBackend): + """ + A backend to perform structure prediction using AlphaFold. 
class AlphaFold(FoldingBackend):
    """
    A backend to perform structure prediction using AlphaFold.
    """

    @staticmethod
    def predict(
        model_runners: Dict,
        output_dir: str,
        feature_dict: Dict,
        random_seed: int,
        fasta_name: str,
        models_to_relax: object = ModelsToRelax,
        allow_resume: bool = True,
        seqs: List = None,
        use_gpu_relax: bool = True,
        multimeric_mode: bool = False,
        **kwargs,
    ):
        """
        Predicts the structure of proteins using a specified set of models and features.

        Parameters
        ----------
        model_runners : dict
            A dictionary of model names to their respective runners obtained from
            :py:meth:`alphapulldown.utils.create_model_runners_and_random_seed`.
        output_dir : str
            The directory where prediction results, including PDB files and metrics,
            will be saved.
        feature_dict : dict
            A dictionary containing the features required by the models for prediction.
        random_seed : int
            A seed for random number generation to ensure reproducibility obtained
            from :py:meth:`alphapulldown.utils.create_model_runners_and_random_seed`.
        fasta_name : str
            The name of the fasta file. Currently not used by this method.
        models_to_relax : object, optional
            Which predictions to relax, compared against ``ModelsToRelax.BEST``,
            ``ModelsToRelax.ALL`` and ``ModelsToRelax.NONE``.
            NOTE(review): the default is the enum class itself, which matches
            neither BEST nor ALL, so nothing is relaxed but an empty
            relax_metrics.json is still written - confirm this is intended.
        allow_resume : bool, optional
            If True, attempts to resume prediction from partially completed runs.
            Default is True.
        seqs : List, optional
            A list of sequences for which predictions are being made; stored under
            the ``"seqs"`` key of every prediction result. Default is None, which
            is treated as an empty list.
        use_gpu_relax : bool, optional
            If True, uses GPU acceleration for the relaxation step. Default is True.
        multimeric_mode : bool, optional
            If True, requires non-zero template positions in the processed
            features. Default is False.
        **kwargs
            Accepted for interface compatibility; not used by this method.

        Raises
        ------
        ValueError
            If multimeric mode is enabled but no valid templates are found in
            the feature dictionary.

        Notes
        -----
        This function is a cleaned up version of alphapulldown.predict_structure.predict
        """
        # A fresh list per call: a shared mutable default would end up inside
        # every pickled prediction result.
        if seqs is None:
            seqs = []

        timings = {}
        unrelaxed_pdbs = {}
        relaxed_pdbs = {}
        relax_metrics = {}
        ranking_confidences = {}
        unrelaxed_proteins = {}
        prediction_result = {}
        first_model_index = 0

        num_models = len(model_runners)
        ranking_output_path = join(output_dir, "ranking_debug.json")

        if allow_resume:
            (
                ranking_confidences,
                unrelaxed_proteins,
                unrelaxed_pdbs,
                first_model_index,
            ) = get_existing_model_info(output_dir, model_runners)

            # Ranking file present and one PDB per model: nothing left to predict.
            if exists(ranking_output_path) and len(unrelaxed_pdbs) == num_models:
                first_model_index = num_models

        for model_index, (model_name, model_runner) in enumerate(model_runners.items()):
            if model_index < first_model_index:
                continue
            t_0 = time.time()

            model_random_seed = model_index + random_seed * num_models
            processed_feature_dict = model_runner.process_features(
                feature_dict, random_seed=model_random_seed
            )
            timings[f"process_features_{model_name}"] = time.time() - t_0

            # Die if --multimeric_mode=True but no non-zero templates are in the feature dict
            if multimeric_mode:
                if "template_all_atom_positions" not in processed_feature_dict:
                    raise ValueError(
                        "No template_all_atom_positions key found in processed_feature_dict."
                    )
                if not np.any(processed_feature_dict["template_all_atom_positions"]):
                    raise ValueError(
                        "No valid templates found: all positions are zero."
                    )

            t_0 = time.time()
            prediction_result = model_runner.predict(
                processed_feature_dict, random_seed=model_random_seed
            )

            # update prediction_result with input seqs
            prediction_result.update({"seqs": seqs})

            timings[f"predict_and_compile_{model_name}"] = time.time() - t_0

            plddt = prediction_result["plddt"]
            ranking_confidences[model_name] = prediction_result["ranking_confidence"]

            # Remove jax dependency from results.
            np_prediction_result = _jnp_to_np(dict(prediction_result))

            result_output_path = join(output_dir, f"result_{model_name}.pkl")
            with open(result_output_path, "wb") as f:
                pickle.dump(np_prediction_result, f, protocol=4)

            # Per-residue pLDDT is repeated for every atom type and handed to
            # from_prediction as b_factors.
            plddt_b_factors = np.repeat(
                plddt[:, None], residue_constants.atom_type_num, axis=-1
            )

            unrelaxed_protein = protein.from_prediction(
                features=processed_feature_dict,
                result=prediction_result,
                b_factors=plddt_b_factors,
                remove_leading_feature_dimension=not model_runner.multimer_mode,
            )

            unrelaxed_proteins[model_name] = unrelaxed_protein
            unrelaxed_pdbs[model_name] = protein.to_pdb(unrelaxed_protein)
            unrelaxed_pdb_path = join(output_dir, f"unrelaxed_{model_name}.pdb")
            with open(unrelaxed_pdb_path, "w") as f:
                f.write(unrelaxed_pdbs[model_name])

        # Rank by model confidence.
        ranked_order = [
            model_name
            for model_name, _ in sorted(
                ranking_confidences.items(), key=lambda x: x[1], reverse=True
            )
        ]

        to_relax = []
        if models_to_relax == ModelsToRelax.BEST:
            # Slice instead of [0] so an empty ranking does not raise.
            to_relax = ranked_order[:1]
        elif models_to_relax == ModelsToRelax.ALL:
            to_relax = ranked_order

        # Relax predictions. The relaxer is only built when there is work for it.
        if to_relax:
            amber_relaxer = relax.AmberRelaxation(
                max_iterations=run_alphafold.RELAX_MAX_ITERATIONS,
                tolerance=run_alphafold.RELAX_ENERGY_TOLERANCE,
                stiffness=run_alphafold.RELAX_STIFFNESS,
                exclude_residues=run_alphafold.RELAX_EXCLUDE_RESIDUES,
                max_outer_iterations=run_alphafold.RELAX_MAX_OUTER_ITERATIONS,
                use_gpu=use_gpu_relax,
            )

        for model_name in to_relax:
            t_0 = time.time()
            relaxed_pdb_str, _, violations = amber_relaxer.process(
                prot=unrelaxed_proteins[model_name]
            )
            relax_metrics[model_name] = {
                "remaining_violations": violations,
                "remaining_violations_count": sum(violations),
            }
            timings[f"relax_{model_name}"] = time.time() - t_0

            relaxed_pdbs[model_name] = relaxed_pdb_str

            # Save the relaxed PDB.
            relaxed_output_path = join(output_dir, f"relaxed_{model_name}.pdb")
            with open(relaxed_output_path, "w") as f:
                f.write(relaxed_pdb_str)

        # Write out PDBs in rank order, preferring the relaxed structure if any.
        for idx, model_name in enumerate(ranked_order):
            ranked_output_path = join(output_dir, f"ranked_{idx}.pdb")
            with open(ranked_output_path, "w") as f:
                f.write(relaxed_pdbs.get(model_name, unrelaxed_pdbs[model_name]))

        if not exists(ranking_output_path):  # already exists if restored.
            # NOTE(review): the label is taken from the last prediction made in
            # this call; if every model was skipped, prediction_result is empty
            # and "plddts" is used - verify this is acceptable for resumed runs.
            label = "iptm+ptm" if "iptm" in prediction_result else "plddts"
            with open(ranking_output_path, "w") as f:
                f.write(
                    json.dumps(
                        {label: ranking_confidences, "order": ranked_order}, indent=4
                    )
                )

        timings_output_path = join(output_dir, "timings.json")
        with open(timings_output_path, "w") as f:
            f.write(json.dumps(timings, indent=4))
        if models_to_relax != ModelsToRelax.NONE:
            relax_metrics_path = join(output_dir, "relax_metrics.json")
            with open(relax_metrics_path, "w") as f:
                f.write(json.dumps(relax_metrics, indent=4))

    @staticmethod
    def postprocess(
        multimer: MultimericObject,
        output_path: str,
        zip_pickles: bool = False,
        remove_pickles: bool = False,
        **kwargs: Dict,
    ) -> None:
        """
        Performs post-processing operations on predicted protein structures and
        writes results and plots to output_path.

        Parameters
        ----------
        multimer : MultimericObject
            The multimeric object containing the predicted structures and
            associated data.
        output_path : str
            The directory where post-processed files and plots will be saved.
        zip_pickles : bool, optional
            Forwarded to ``post_prediction_process``. Default is False.
        remove_pickles : bool, optional
            Forwarded to ``post_prediction_process``. Default is False.
        **kwargs : dict
            Accepted for interface compatibility; not used by this method.
        """
        create_and_save_pae_plots(multimer, output_path)
        post_prediction_process(
            output_path,
            zip_pickles=zip_pickles,
            remove_pickles=remove_pickles,
        )
class FoldingBackend(ABC):
    """
    Abstract interface shared by all structure prediction backends.

    Concrete backends have to provide both ``predict`` and ``postprocess``.
    """

    @abstractmethod
    def predict(self, **kwargs) -> None:
        """
        Run structure prediction.

        Subclasses implement the actual prediction; which parameters are
        accepted and how the prediction is carried out is up to the
        implementation.

        Parameters
        ----------
        **kwargs : dict
            Keyword arguments such as input features, model configuration
            and other prediction-related parameters.
        """

    @abstractmethod
    def postprocess(self, **kwargs):
        """
        Post-process predicted structures.

        Subclasses implement whatever follow-up work their predictions need,
        for instance creating plots, modifying structure data or removing
        temporary files.

        Parameters
        ----------
        **kwargs : dict
            Keyword arguments such as paths to prediction results, file
            handling options and other post-processing parameters.
        """
+ """ + diff --git a/alphapulldown/run_structure_prediction.py b/alphapulldown/run_structure_prediction.py new file mode 100644 index 00000000..135e577b --- /dev/null +++ b/alphapulldown/run_structure_prediction.py @@ -0,0 +1,261 @@ +#!python3 +""" CLI inferface for performing structure prediction. + + Copyright (c) 2024 European Molecular Biology Laboratory + + Author: Valentin Maurer +""" +import argparse +from os import makedirs +from os.path import exists, join + +from alphapulldown.run_multimer_jobs import create_custom_info +from alphapulldown.utils import create_model_runners_and_random_seed, create_interactors +from alphapulldown.objects import MultimericObject +from alphapulldown.folding_backend import backend + + +def parse_args(): + parser = argparse.ArgumentParser(description="Run protein folding.") + parser.add_argument( + "-i", + "--input", + dest="input", + type=str, + required=True, + help="Folds in format [fasta_path:number:start-stop],[...],.", + ), + parser.add_argument( + "-o", + "--output_directory", + dest="output_directory", + type=str, + required=True, + help="Path to output directory. 
Will be created if not exists.", + ), + parser.add_argument( + "--num_cycle", + dest="num_cycle", + type=int, + required=False, + default=3, + help="Number of recycles, defaults to 3.", + ), + parser.add_argument( + "--num_predictions_per_model", + dest="num_predictions_per_model", + type=int, + required=False, + default=1, + help="Number of predictions per model, defaults to 1.", + ), + parser.add_argument( + "--data_directory", + dest="data_directory", + type=str, + required=True, + help="Path to data directory.", + ), + parser.add_argument( + "--features_directory", + dest="features_directory", + type=str, + nargs="+", + required=True, + help="Path to computed monomer features.", + ), + parser.add_argument( + "--no_pair_msa", + dest="no_pair_msa", + action="store_true", + default=False, + help="Do not pair the MSAs when constructing multimer objects.", + ), + parser.add_argument( + "--gradient_msa_depth", + dest="gradient_msa_depth", + action="store_true", + default=None, + help="Run predictions for each model with logarithmically distributed MSA depth.", + ), + parser.add_argument( + "--multimeric_template", + dest="multimeric_template", + action="store_true", + default=None, + help="Whether to use multimeric templates.", + ), + parser.add_argument( + "--model_names", + dest="model_names", + type=str, + default=None, + help="Names of models to use, e.g. 
model_2_multimer_v3 (default: all models).", + ), + parser.add_argument( + "--msa_depth", + dest="msa_depth", + type=int, + default=None, + help="Number of sequences to use from the MSA (by default is taken from AF model config).", + ), + parser.add_argument( + "--protein_delimiter", + dest="protein_delimiter", + type=str, + default=";", + help="Delimiter for proteins of a singel fold.", + ), + args = parser.parse_args() + + makedirs(args.output_directory, exist_ok=True) + + formatted_folds, missing_features, unique_features = [], [], [] + protein_folds = [x.split(":") for x in args.input.split(args.protein_delimiter)] + for protein_fold in protein_folds: + name, number, region = None, 1, "all" + + match len(protein_fold): + case 1: + name = protein_fold[0] + case 2: + name, number = protein_fold[0], protein_fold[1] + if ("-") in protein_fold[1]: + number = 1 + region = protein_fold[1].split("-") + case 3: + name, number, region = protein_fold + + number = int(number) + if len(region) != 2 and region != "all": + raise ValueError(f"Region {region} is malformatted expected start-stop.") + + if len(region) == 2: + region = [tuple(int(x) for x in region)] + + unique_features.append(name) + for monomer_dir in args.features_directory: + if exists(join(monomer_dir, f"{name}.pkl")): + continue + missing_features.append(name) + + formatted_folds.extend([{name: region} for _ in range(number)]) + + missing_features = set(missing_features) + if len(missing_features): + raise FileNotFoundError( + f"{missing_features} not found in {args.features_directory}" + ) + + args.parsed_input = formatted_folds + + return args + + +def predict_multimer( + multimer: MultimericObject, + num_recycles: int, + data_directory: str, + num_predictions_per_model: int, + output_directory: str, + gradient_msa_depth: bool = False, + model_names: str = None, + msa_depth: int = None, + random_seed: int = 42, + fold_backend: str = "alphafold", +) -> None: + """ + Predict structural features of multimers 
def predict_multimer(
    multimer: MultimericObject,
    num_recycles: int,
    data_directory: str,
    num_predictions_per_model: int,
    output_directory: str,
    gradient_msa_depth: bool = False,
    model_names: str = None,
    msa_depth: int = None,
    random_seed: int = 42,
    fold_backend: str = "alphafold",
) -> None:
    """
    Run structure prediction and post-processing for a single object.

    Parameters
    ----------
    multimer : MultimericObject
        Object to predict. If it is a `MultimericObject` the multimer preset
        is used, otherwise the monomer_ptm preset is used and ``input_seqs``
        is set from the object's ``sequence``.
    num_recycles : int
        Number of recycles, passed on as ``num_cycle``.
    data_directory : str
        Path of the data directory, passed on as ``data_dir``.
    num_predictions_per_model : int
        Number of predictions to generate per model.
    output_directory : str
        Directory the backend writes its results to.
    gradient_msa_depth : bool, optional
        Only forwarded for `MultimericObject` input. Default is False.
    model_names : str, optional
        Only forwarded (as ``model_names_custom``) for `MultimericObject`
        input. Default is None.
    msa_depth : int, optional
        Only forwarded for `MultimericObject` input. Default is None.
    random_seed : int, optional
        Seed handed to the model runner factory. Default is 42.
    fold_backend : str, optional
        Backend used for folding, defaults to alphafold.
    """
    is_multimer = isinstance(multimer, MultimericObject)

    flags_dict = {
        "model_preset": "multimer" if is_multimer else "monomer_ptm",
        "random_seed": random_seed,
        "num_cycle": num_recycles,
        "data_dir": data_directory,
        "num_multimer_predictions_per_model": num_predictions_per_model,
    }

    if is_multimer:
        flags_dict.update(
            gradient_msa_depth=gradient_msa_depth,
            model_names_custom=model_names,
            msa_depth=msa_depth,
        )
    else:
        # Monomers get a one-element sequence list so the call below can
        # read input_seqs in both cases.
        multimer.input_seqs = [multimer.sequence]

    # The factory returns the seed that is handed to the backend below.
    model_runners, random_seed = create_model_runners_and_random_seed(**flags_dict)

    backend.change_backend(backend_name=fold_backend)

    predict_kwargs = {
        "model_runners": model_runners,
        "output_dir": output_directory,
        "feature_dict": multimer.feature_dict,
        "random_seed": random_seed,
        "fasta_name": multimer.description,
        "seqs": multimer.input_seqs,
    }
    backend.predict(**predict_kwargs)

    backend.postprocess(
        multimer=multimer,
        output_path=output_directory,
        zip_pickles=False,
        remove_pickles=False,
    )
def main():
    """
    Command line entry point.

    Parses the arguments, builds the interactors from the requested folds
    and predicts either the single interactor or, if there are several,
    the MultimericObject assembled from them.
    """
    args = parse_args()

    fold_info = create_custom_info(args.parsed_input)
    interactors = create_interactors(fold_info, args.features_directory, 0)

    if len(interactors) > 1:
        target = MultimericObject(
            interactors=interactors,
            pair_msa=not args.no_pair_msa,
            multimeric_mode=args.multimeric_template,
        )
    else:
        target = interactors[0]

    predict_multimer(
        multimer=target,
        num_recycles=args.num_cycle,
        data_directory=args.data_directory,
        num_predictions_per_model=args.num_predictions_per_model,
        output_directory=args.output_directory,
        gradient_msa_depth=args.gradient_msa_depth,
        model_names=args.model_names,
        msa_depth=args.msa_depth,
    )


if __name__ == "__main__":
    main()
jupyterlab ipywidgets -scripts = ./alphafold/run_alphafold.py, ./alphapulldown/create_individual_features.py, ./alphapulldown/run_multimer_jobs.py, ./alphapulldown/analysis_pipeline/create_notebook.py, ./alphapulldown/rename_colab_search_a3m.py, ./alphapulldown/prepare_seq_names.py, ./alphapulldown/generate_crosslink_pickle.py, ./alphapulldown/convert_to_modelcif.py +scripts = ./alphafold/run_alphafold.py, ./alphapulldown/create_individual_features.py, ./alphapulldown/run_multimer_jobs.py, ./alphapulldown/analysis_pipeline/create_notebook.py, ./alphapulldown/rename_colab_search_a3m.py, ./alphapulldown/prepare_seq_names.py, ./alphapulldown/generate_crosslink_pickle.py, ./alphapulldown/convert_to_modelcif.py, ./alphapulldown/run_structure_prediction.py [options.data_files] lib/python3.10/site-packages/alphafold/common/ = stereo_chemical_props.txt