Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 17, 2024
1 parent 0caa390 commit af7c694
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 15 deletions.
18 changes: 11 additions & 7 deletions aiida_nanotech_empa/workflows/cp2k/reftraj_md_workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,26 @@ def merge_trajectories(*trajectories):
arrays = {}
traj_keys = trajectories[0].get_arraynames()
symbols = trajectories[0].symbols
traj_keys.remove('steps')
traj_keys.remove("steps")
for key in traj_keys:
arrays[key]=[]
arrays[key] = []
for trajectory in trajectories:
for key in traj_keys:
arrays[key].append(trajectory.get_array(key))

merged_trajectory = TrajectoryData()
if 'cells' in traj_keys:
merged_trajectory.set_trajectory(symbols, np.concatenate(arrays['positions']),cells=np.concatenate(arrays['cells']))
if "cells" in traj_keys:
merged_trajectory.set_trajectory(
symbols,
np.concatenate(arrays["positions"]),
cells=np.concatenate(arrays["cells"]),
)
else:
merged_trajectory.set_trajectory(symbols, np.concatenate(arrays['positions']))
traj_keys = [key for key in traj_keys if key not in ['cells','positions']]
merged_trajectory.set_trajectory(symbols, np.concatenate(arrays["positions"]))
traj_keys = [key for key in traj_keys if key not in ["cells", "positions"]]
for key in traj_keys:
merged_trajectory.set_array(key, np.concatenate(arrays[key]))

return merged_trajectory


Expand Down
18 changes: 10 additions & 8 deletions examples/workflows/example_cp2k_md_reftraj.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
TrajectoryData = DataFactory("core.array.trajectory")


def _example_cp2k_reftraj(cp2k_code,num_batches=2):
def _example_cp2k_reftraj(cp2k_code, num_batches=2):
os.path.dirname(os.path.realpath(__file__))

# Structure.
Expand Down Expand Up @@ -82,16 +82,18 @@ def example_cp2k_reftraj(cp2k_code):
@click.option("-c", "--n-cores-per-node", default=1)
def run_all(cp2k_code, n_nodes, n_cores_per_node):
print("#### RKS one batch")
uuid1 = _example_cp2k_reftraj(
cp2k_code=orm.load_code(cp2k_code),num_batches=1
)
uuid1 = _example_cp2k_reftraj(cp2k_code=orm.load_code(cp2k_code), num_batches=1)
print("#### RKS two batches")
uuid2 = _example_cp2k_reftraj(
cp2k_code=orm.load_code(cp2k_code),num_batches=2
)
uuid2 = _example_cp2k_reftraj(cp2k_code=orm.load_code(cp2k_code), num_batches=2)
traj1 = orm.load_node(uuid1).outputs.output_trajectory
traj2 = orm.load_node(uuid2).outputs.output_trajectory
assert np.allclose(traj1.get_array('cells'), traj2.get_array('cells'), rtol=1e-07, atol=1e-08, equal_nan=False)
assert np.allclose(
traj1.get_array("cells"),
traj2.get_array("cells"),
rtol=1e-07,
atol=1e-08,
equal_nan=False,
)
print(f"arrays match")


Expand Down

0 comments on commit af7c694

Please sign in to comment.