From 3655f21109dac247aac9438fc9c2bda09a00222f Mon Sep 17 00:00:00 2001
From: Dingquan Yu
Date: Mon, 17 Jun 2024 15:44:36 +0200
Subject: [PATCH] update recalculation function

---
 .../folding_backend/alphafold_backend.py | 68 ++++++++++---------
 1 file changed, 35 insertions(+), 33 deletions(-)

diff --git a/alphapulldown/folding_backend/alphafold_backend.py b/alphapulldown/folding_backend/alphafold_backend.py
index 146c222d..3e80db7a 100644
--- a/alphapulldown/folding_backend/alphafold_backend.py
+++ b/alphapulldown/folding_backend/alphafold_backend.py
@@ -441,40 +441,43 @@ def recalculate_confidence(prediction_results: Dict, multimer_mode:bool,
         A method that remove pae values of padded residues and recalculate iptm_ptm score again
         Modified based on https://github.com/KosinskiLab/alphafold/blob/c844e1bb60a3beb50bb8d562c6be046da1e43e3d/alphafold/model/model.py#L31
         """
-        output = {}
-        plddt = prediction_results['plddt'][:total_num_res]
-        if 'predicted_aligned_error' in prediction_results:
-            ptm = confidence.predicted_tm_score(
-                logits=prediction_results['predicted_aligned_error']['logits'][:total_num_res,:total_num_res],
-                breaks=prediction_results['predicted_aligned_error']['breaks'],
-                asym_id=None)
-            output['ptm'] = ptm
-
-            pae = confidence.compute_predicted_aligned_error(
-                logits=prediction_results['predicted_aligned_error']['logits'],
-                breaks=prediction_results['predicted_aligned_error']['breaks'])
-            max_pae = pae.pop('max_predicted_aligned_error')
-
-            for k,v in pae.items():
-                output.update({k:v[:total_num_res, :total_num_res]})
-            output['max_predicted_aligned_error'] = max_pae
-            if multimer_mode:
-                # Compute the ipTM only for the multimer model.
-                iptm = confidence.predicted_tm_score(
+        if type(prediction_results['predicted_aligned_error']) == np.ndarray:
+            return prediction_results
+        else:
+            output = {}
+            plddt = prediction_results['plddt'][:total_num_res]
+            if 'predicted_aligned_error' in prediction_results:
+                ptm = confidence.predicted_tm_score(
                     logits=prediction_results['predicted_aligned_error']['logits'][:total_num_res,:total_num_res],
                     breaks=prediction_results['predicted_aligned_error']['breaks'],
-                    asym_id=prediction_results['predicted_aligned_error']['asym_id'][:total_num_res],
-                    interface=True)
-                output.update({'iptm' : iptm})
-                ranking_confidence = 0.8 * iptm + 0.2 * ptm
-                output.update({'ranking_confidence' : ranking_confidence})
-        if not multimer_mode:
-            # Monomer models use mean pLDDT for model ranking.
-            ranking_confidence = np.mean(
-                plddt)
-            output.update({'ranking_confidence' : ranking_confidence})
-
-        return output
+                    asym_id=None)
+                output['ptm'] = ptm
+
+                pae = confidence.compute_predicted_aligned_error(
+                    logits=prediction_results['predicted_aligned_error']['logits'],
+                    breaks=prediction_results['predicted_aligned_error']['breaks'])
+                max_pae = pae.pop('max_predicted_aligned_error')
+
+                for k,v in pae.items():
+                    output.update({k:v[:total_num_res, :total_num_res]})
+                output['max_predicted_aligned_error'] = max_pae
+                if multimer_mode:
+                    # Compute the ipTM only for the multimer model.
+                    iptm = confidence.predicted_tm_score(
+                        logits=prediction_results['predicted_aligned_error']['logits'][:total_num_res,:total_num_res],
+                        breaks=prediction_results['predicted_aligned_error']['breaks'],
+                        asym_id=prediction_results['predicted_aligned_error']['asym_id'][:total_num_res],
+                        interface=True)
+                    output.update({'iptm' : iptm})
+                    ranking_confidence = 0.8 * iptm + 0.2 * ptm
+                    output.update({'ranking_confidence' : ranking_confidence})
+            if not multimer_mode:
+                # Monomer models use mean pLDDT for model ranking.
+                ranking_confidence = np.mean(
+                    plddt)
+                output.update({'ranking_confidence' : ranking_confidence})
+
+            return output
 
     @staticmethod
     def postprocess(
@@ -540,7 +543,6 @@ def postprocess(
 
         for model_name, prediction_result in prediction_results.items():
             prediction_result.update(AlphaFoldBackend.recalculate_confidence(prediction_result,multimer_mode, total_num_res))
-            logging.info(f"prediction_result has keys : {prediction_result.keys()}")
             # Remove jax dependency from results
             np_prediction_result = _jnp_to_np(dict(prediction_result))
             # Save prediction results to pickle file
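What the hunk above changes: recalculate_confidence now returns its input untouched whenever prediction_results['predicted_aligned_error'] is already a plain NumPy array (the PAE has already been collapsed from logits, so there is nothing left to recompute); otherwise the pLDDT and PAE matrices are trimmed to the first total_num_res residues, pTM/ipTM are recomputed from the trimmed logits via confidence.predicted_tm_score, and the ranking score stays 0.8 * ipTM + 0.2 * pTM for multimer models and mean pLDDT for monomers. The sketch below restates only that control flow; the helper name and toy inputs are hypothetical and AlphaFold's confidence module is intentionally not called, so this is an illustration of the guard, not the code the patch installs.

from typing import Any, Dict
import numpy as np

def sketch_recalculate(prediction_results: Dict[str, Any],
                       multimer_mode: bool,
                       total_num_res: int) -> Dict[str, Any]:
    pae = prediction_results['predicted_aligned_error']
    # Guard added by the patch: a bare ndarray means the PAE was already
    # post-processed, so the existing scores are returned as-is.
    if isinstance(pae, np.ndarray):
        return prediction_results

    output: Dict[str, Any] = {}
    # Drop padded residues before any statistics are taken.
    plddt = prediction_results['plddt'][:total_num_res]
    # Placeholder scores; the real code derives ptm/iptm from the trimmed
    # PAE logits with confidence.predicted_tm_score.
    ptm, iptm = 0.70, 0.60

    if multimer_mode:
        output['iptm'] = iptm
        # Multimer ranking rule kept by the patch.
        output['ranking_confidence'] = 0.8 * iptm + 0.2 * ptm
    else:
        # Monomer models rank by mean pLDDT over the unpadded residues.
        output['ranking_confidence'] = float(np.mean(plddt))
    return output

# Toy call: 5 real residues out of 8 padded positions, monomer ranking.
toy = {'plddt': np.linspace(95.0, 60.0, 8),
       'predicted_aligned_error': {'logits': None, 'breaks': None}}
print(sketch_recalculate(toy, multimer_mode=False, total_num_res=5))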