From af7c69416cb39873825817c3ff8c21fff50bc50a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 17 Feb 2024 08:10:26 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../workflows/cp2k/reftraj_md_workchain.py | 18 +++++++++++------- examples/workflows/example_cp2k_md_reftraj.py | 18 ++++++++++-------- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/aiida_nanotech_empa/workflows/cp2k/reftraj_md_workchain.py b/aiida_nanotech_empa/workflows/cp2k/reftraj_md_workchain.py index b976287..1fea1e1 100644 --- a/aiida_nanotech_empa/workflows/cp2k/reftraj_md_workchain.py +++ b/aiida_nanotech_empa/workflows/cp2k/reftraj_md_workchain.py @@ -18,22 +18,26 @@ def merge_trajectories(*trajectories): arrays = {} traj_keys = trajectories[0].get_arraynames() symbols = trajectories[0].symbols - traj_keys.remove('steps') + traj_keys.remove("steps") for key in traj_keys: - arrays[key]=[] + arrays[key] = [] for trajectory in trajectories: for key in traj_keys: arrays[key].append(trajectory.get_array(key)) merged_trajectory = TrajectoryData() - if 'cells' in traj_keys: - merged_trajectory.set_trajectory(symbols, np.concatenate(arrays['positions']),cells=np.concatenate(arrays['cells'])) + if "cells" in traj_keys: + merged_trajectory.set_trajectory( + symbols, + np.concatenate(arrays["positions"]), + cells=np.concatenate(arrays["cells"]), + ) else: - merged_trajectory.set_trajectory(symbols, np.concatenate(arrays['positions'])) - traj_keys = [key for key in traj_keys if key not in ['cells','positions']] + merged_trajectory.set_trajectory(symbols, np.concatenate(arrays["positions"])) + traj_keys = [key for key in traj_keys if key not in ["cells", "positions"]] for key in traj_keys: merged_trajectory.set_array(key, np.concatenate(arrays[key])) - + return merged_trajectory diff --git a/examples/workflows/example_cp2k_md_reftraj.py b/examples/workflows/example_cp2k_md_reftraj.py index 1677787..618c226 100644 --- a/examples/workflows/example_cp2k_md_reftraj.py +++ b/examples/workflows/example_cp2k_md_reftraj.py @@ -10,7 +10,7 @@ TrajectoryData = DataFactory("core.array.trajectory") -def _example_cp2k_reftraj(cp2k_code,num_batches=2): +def _example_cp2k_reftraj(cp2k_code, num_batches=2): os.path.dirname(os.path.realpath(__file__)) # Structure. @@ -82,16 +82,18 @@ def example_cp2k_reftraj(cp2k_code): @click.option("-c", "--n-cores-per-node", default=1) def run_all(cp2k_code, n_nodes, n_cores_per_node): print("#### RKS one batch") - uuid1 = _example_cp2k_reftraj( - cp2k_code=orm.load_code(cp2k_code),num_batches=1 - ) + uuid1 = _example_cp2k_reftraj(cp2k_code=orm.load_code(cp2k_code), num_batches=1) print("#### RKS two batches") - uuid2 = _example_cp2k_reftraj( - cp2k_code=orm.load_code(cp2k_code),num_batches=2 - ) + uuid2 = _example_cp2k_reftraj(cp2k_code=orm.load_code(cp2k_code), num_batches=2) traj1 = orm.load_node(uuid1).outputs.output_trajectory traj2 = orm.load_node(uuid2).outputs.output_trajectory - assert np.allclose(traj1.get_array('cells'), traj2.get_array('cells'), rtol=1e-07, atol=1e-08, equal_nan=False) + assert np.allclose( + traj1.get_array("cells"), + traj2.get_array("cells"), + rtol=1e-07, + atol=1e-08, + equal_nan=False, + ) print(f"arrays match")