Skip to content

Commit

Permalink
update recalculation function
Browse files Browse the repository at this point in the history
  • Loading branch information
dingquanyu committed Jun 17, 2024
1 parent 82f2b3b commit 3655f21
Showing 1 changed file with 35 additions and 33 deletions.
68 changes: 35 additions & 33 deletions alphapulldown/folding_backend/alphafold_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,40 +441,43 @@ def recalculate_confidence(prediction_results: Dict, multimer_mode:bool,
A method that remove pae values of padded residues and recalculate iptm_ptm score again
Modified based on https://github.com/KosinskiLab/alphafold/blob/c844e1bb60a3beb50bb8d562c6be046da1e43e3d/alphafold/model/model.py#L31
"""
output = {}
plddt = prediction_results['plddt'][:total_num_res]
if 'predicted_aligned_error' in prediction_results:
ptm = confidence.predicted_tm_score(
logits=prediction_results['predicted_aligned_error']['logits'][:total_num_res,:total_num_res],
breaks=prediction_results['predicted_aligned_error']['breaks'],
asym_id=None)
output['ptm'] = ptm

pae = confidence.compute_predicted_aligned_error(
logits=prediction_results['predicted_aligned_error']['logits'],
breaks=prediction_results['predicted_aligned_error']['breaks'])
max_pae = pae.pop('max_predicted_aligned_error')

for k,v in pae.items():
output.update({k:v[:total_num_res, :total_num_res]})
output['max_predicted_aligned_error'] = max_pae
if multimer_mode:
# Compute the ipTM only for the multimer model.
iptm = confidence.predicted_tm_score(
if type(prediction_results['predicted_aligned_error']) == np.ndarray:
return prediction_results
else:
output = {}
plddt = prediction_results['plddt'][:total_num_res]
if 'predicted_aligned_error' in prediction_results:
ptm = confidence.predicted_tm_score(
logits=prediction_results['predicted_aligned_error']['logits'][:total_num_res,:total_num_res],
breaks=prediction_results['predicted_aligned_error']['breaks'],
asym_id=prediction_results['predicted_aligned_error']['asym_id'][:total_num_res],
interface=True)
output.update({'iptm' : iptm})
ranking_confidence = 0.8 * iptm + 0.2 * ptm
output.update({'ranking_confidence' : ranking_confidence})
if not multimer_mode:
# Monomer models use mean pLDDT for model ranking.
ranking_confidence = np.mean(
plddt)
output.update({'ranking_confidence' : ranking_confidence})

return output
asym_id=None)
output['ptm'] = ptm

pae = confidence.compute_predicted_aligned_error(
logits=prediction_results['predicted_aligned_error']['logits'],
breaks=prediction_results['predicted_aligned_error']['breaks'])
max_pae = pae.pop('max_predicted_aligned_error')

for k,v in pae.items():
output.update({k:v[:total_num_res, :total_num_res]})
output['max_predicted_aligned_error'] = max_pae
if multimer_mode:
# Compute the ipTM only for the multimer model.
iptm = confidence.predicted_tm_score(
logits=prediction_results['predicted_aligned_error']['logits'][:total_num_res,:total_num_res],
breaks=prediction_results['predicted_aligned_error']['breaks'],
asym_id=prediction_results['predicted_aligned_error']['asym_id'][:total_num_res],
interface=True)
output.update({'iptm' : iptm})
ranking_confidence = 0.8 * iptm + 0.2 * ptm
output.update({'ranking_confidence' : ranking_confidence})
if not multimer_mode:
# Monomer models use mean pLDDT for model ranking.
ranking_confidence = np.mean(
plddt)
output.update({'ranking_confidence' : ranking_confidence})

return output

@staticmethod
def postprocess(
Expand Down Expand Up @@ -540,7 +543,6 @@ def postprocess(
for model_name, prediction_result in prediction_results.items():
prediction_result.update(AlphaFoldBackend.recalculate_confidence(prediction_result,multimer_mode,
total_num_res))
logging.info(f"prediction_result has keys : {prediction_result.keys()}")
# Remove jax dependency from results
np_prediction_result = _jnp_to_np(dict(prediction_result))
# Save prediction results to pickle file
Expand Down

0 comments on commit 3655f21

Please sign in to comment.