Skip to content

Commit

Permalink
Remove jax dependency from the result.pkl
Browse files Browse the repository at this point in the history
  • Loading branch information
DimaMolod committed Dec 18, 2023
1 parent 4bf71fd commit 06461b8
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion alphapulldown/predict_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from alphafold.relax import relax
import numpy as np
from alphapulldown.utils import get_run_alphafold
import jax.numpy as jnp


run_af = get_run_alphafold()
Expand All @@ -24,6 +25,15 @@

ModelsToRelax = run_af.ModelsToRelax

def _jnp_to_np(output):
"""Recursively changes jax arrays to numpy arrays."""
for k, v in output.items():
if isinstance(v, dict):
output[k] = _jnp_to_np(v)
elif isinstance(v, jnp.ndarray):
output[k] = np.array(v)
return output

def get_score_from_result_pkl(pkl_path):
"""Get the score from the model result pkl file"""

Expand Down Expand Up @@ -176,9 +186,12 @@ def predict(
plddt = prediction_result["plddt"]
ranking_confidences[model_name] = prediction_result["ranking_confidence"]

# Remove jax dependency from results.
np_prediction_result = _jnp_to_np(dict(prediction_result))

result_output_path = os.path.join(output_dir, f"result_{model_name}.pkl")
with open(result_output_path, "wb") as f:
pickle.dump(prediction_result, f, protocol=4)
pickle.dump(np_prediction_result, f, protocol=4)

plddt_b_factors = np.repeat(
plddt[:, None], residue_constants.atom_type_num, axis=-1
Expand Down

0 comments on commit 06461b8

Please sign in to comment.