Skip to content

Commit

Permalink
added names of older AlphaFold models
Browse files Browse the repository at this point in the history
  • Loading branch information
dingquanyu committed Jun 20, 2024
1 parent c5c4c7e commit 5c77cf3
Showing 1 changed file with 36 additions and 37 deletions.
73 changes: 36 additions & 37 deletions alphapulldown/folding_backend/alphafold_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,22 @@ def setup(
num_ensemble = 1
model_runners = {}
model_names = config.MODEL_PRESETS[model_name]

# add model names of older versionsto be compatible with older version of AlphaFold Multimer
old_model_names = (
'model_1_multimer',
'model_2_multimer',
'model_3_multimer',
'model_4_multimer',
'model_5_multimer',
'model_1_multimer_v2',
'model_2_multimer_v2',
'model_3_multimer_v2',
'model_4_multimer_v2',
'model_5_multimer_v2',
)
if model_names_custom:
model_names_custom = tuple(model_names_custom.split(","))
if all(x in model_names for x in model_names_custom):
if all(x in model_names for x in model_names_custom + old_model_names):
model_names = model_names_custom
else:
raise Exception(
Expand Down Expand Up @@ -390,12 +402,13 @@ def predict_individual_job(
b_factors=plddt_b_factors,
remove_leading_feature_dimension=not model_runner.multimer_mode,
)

# Remove jax dependency from results
np_prediction_result = _jnp_to_np(dict(prediction_result))
# Save prediction results to pickle file
result_output_path = os.path.join(output_dir, f"result_{model_name}.pkl")
with open(result_output_path, "wb") as f:
pickle.dump(np_prediction_result, f, protocol=4)
pickle.dump(np_prediction_result, f, protocol=4)
prediction_result.update(
{"seqs": multimeric_object.input_seqs if hasattr(multimeric_object,"input_seqs") else [multimeric_object.sequence]})
prediction_result.update({"unrelaxed_protein": unrelaxed_protein})
Expand All @@ -410,6 +423,7 @@ def predict_individual_job(
timings_output_path = os.path.join(output_dir, "timings.json")
with open(timings_output_path, "w") as f:
f.write(json.dumps(timings, indent=4))

return prediction_results

@staticmethod
Expand Down Expand Up @@ -437,10 +451,7 @@ def predict(model_runners: Dict,
@staticmethod
def recalculate_confidence(prediction_results: Dict, multimer_mode:bool,
total_num_res: int) -> Dict[str, Any]:
"""
A method that remove pae values of padded residues and recalculate iptm_ptm score again
Modified based on https://github.com/KosinskiLab/alphafold/blob/c844e1bb60a3beb50bb8d562c6be046da1e43e3d/alphafold/model/model.py#L31
"""
"""A method that remove pae values of padded residues and recalculate iptm_ptm score again """
if type(prediction_results['predicted_aligned_error']) == np.ndarray:
return prediction_results
else:
Expand All @@ -451,8 +462,7 @@ def recalculate_confidence(prediction_results: Dict, multimer_mode:bool,
logits=prediction_results['predicted_aligned_error']['logits'][:total_num_res,:total_num_res],
breaks=prediction_results['predicted_aligned_error']['breaks'],
asym_id=None)
output['ptm'] = ptm


pae = confidence.compute_predicted_aligned_error(
logits=prediction_results['predicted_aligned_error']['logits'],
breaks=prediction_results['predicted_aligned_error']['breaks'])
Expand All @@ -476,8 +486,8 @@ def recalculate_confidence(prediction_results: Dict, multimer_mode:bool,
ranking_confidence = np.mean(
plddt)
output.update({'ranking_confidence' : ranking_confidence})
return output

return output

@staticmethod
def postprocess(
Expand All @@ -488,7 +498,6 @@ def postprocess(
models_to_relax: ModelsToRelax,
zip_pickles: bool = False,
remove_pickles: bool = False,
convert_to_modelcif: bool = True,
use_gpu_relax: bool = True,
pae_plot_style: str = "red_blue",

Expand Down Expand Up @@ -517,8 +526,6 @@ def postprocess(
remove_pickles : bool, optional
If True, removes the pickle files after post-processing is complete.
Default is False.
convert_to_modelcif : bool, optional
If set to True, converts all predicted models to ModelCIF format, default is True.
use_gpu_relax : bool, optional
If set to True, utilizes GPU acceleration for the relaxation step, default is True.
pae_plot_style : str, optional
Expand Down Expand Up @@ -546,12 +553,6 @@ def postprocess(
for model_name, prediction_result in prediction_results.items():
prediction_result.update(AlphaFoldBackend.recalculate_confidence(prediction_result,multimer_mode,
total_num_res))
# Remove jax dependency from results
np_prediction_result = _jnp_to_np(dict(prediction_result))
# Save prediction results to pickle file
result_output_path = os.path.join(output_dir, f"result_{model_name}.pkl")
with open(result_output_path, "wb") as f:
pickle.dump(np_prediction_result, f, protocol=4)
if 'iptm' in prediction_result:
label = 'iptm+ptm'
iptm_scores[model_name] = float(prediction_result['iptm'])
Expand Down Expand Up @@ -664,23 +665,21 @@ def postprocess(
# template_file_path, ranked_output_path, temp_dir
# )

# Call convert_to_modelcif script
if convert_to_modelcif:
parent_dir = os.path.dirname(os.path.dirname((os.path.abspath(__file__))))
logging.info(f"Converting {output_dir} to ModelCIF format...")
command = f"python3 {parent_dir}/scripts/convert_to_modelcif.py " \
f"--ap_output {output_dir} "

result = subprocess.run(command,
check=True,
shell=True,
capture_output=True,
text=True)

if result.stderr:
logging.error("Error:", result.stderr)
else:
logging.info("All PDBs converted to ModelCIF format.")
#Call convert_to_modelcif script
# parent_dir = os.path.dirname(os.path.dirname((os.path.abspath(__file__))))
# command = f"python3 {parent_dir}/scripts/convert_to_modelcif.py " \
# f"--ap_output {output_dir} " \
# f"--monomer_objects_dir {''.join(features_directory)}"

#result = subprocess.run(command,
# check=True,
# shell=True,
# capture_output=True,
# text=True)

#logging.info(result.stdout)
#if result.stderr:
# logging.error("Error:", result.stderr)
post_prediction_process(
output_dir,
zip_pickles=zip_pickles,
Expand Down

0 comments on commit 5c77cf3

Please sign in to comment.