diff --git a/alphapulldown/predict_structure.py b/alphapulldown/predict_structure.py index 2abbc6ad..965b59b0 100644 --- a/alphapulldown/predict_structure.py +++ b/alphapulldown/predict_structure.py @@ -13,6 +13,7 @@ from alphafold.relax import relax import numpy as np from alphapulldown.utils import get_run_alphafold +import jax.numpy as jnp run_af = get_run_alphafold() @@ -24,6 +25,15 @@ ModelsToRelax = run_af.ModelsToRelax +def _jnp_to_np(output): + """Recursively changes jax arrays to numpy arrays.""" + for k, v in output.items(): + if isinstance(v, dict): + output[k] = _jnp_to_np(v) + elif isinstance(v, jnp.ndarray): + output[k] = np.array(v) + return output + def get_score_from_result_pkl(pkl_path): """Get the score from the model result pkl file""" @@ -176,9 +186,12 @@ def predict( plddt = prediction_result["plddt"] ranking_confidences[model_name] = prediction_result["ranking_confidence"] + # Remove jax dependency from results. + np_prediction_result = _jnp_to_np(dict(prediction_result)) + result_output_path = os.path.join(output_dir, f"result_{model_name}.pkl") with open(result_output_path, "wb") as f: - pickle.dump(prediction_result, f, protocol=4) + pickle.dump(np_prediction_result, f, protocol=4) plddt_b_factors = np.repeat( plddt[:, None], residue_constants.atom_type_num, axis=-1