Skip to content

Commit

Permalink
working, tbd: loop on keys of available arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
cpignedoli committed Feb 16, 2024
1 parent 360672a commit 1701a2d
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 49 deletions.
100 changes: 56 additions & 44 deletions aiida_nanotech_empa/workflows/cp2k/reftraj_md_workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,31 @@
# Cp2kRefTrajWorkChain = plugins.WorkflowFactory("cp2k.reftraj")
TrajectoryData = plugins.DataFactory("array.trajectory")

@engine.calcfunction
def merge_trajectories(*trajectories):
"""Merge a list of trajectories into a single one."""
positions=[]
cells=[]
forces=[]
for trajectory in trajectories:
positions.append(trajectory.get_array("positions") )
try:
cells.append(trajectory.get_array("cells"))
except KeyError:
pass
forces.append(trajectory.get_array("forces"))
symbols = trajectory.symbols
positions = np.concatenate(positions)
cells = np.concatenate(cells)
forces = np.concatenate(forces)
merged_trajectory = TrajectoryData()
if len(cells) == 0:
merged_trajectory.set_trajectory(symbols, positions)
else:
merged_trajectory.set_trajectory(symbols, positions, cells=cells)
merged_trajectory.set_array("forces", forces)

return merged_trajectory

@engine.calcfunction
def create_batches(trajectory, num_batches, steps_completed):
Expand All @@ -31,7 +56,12 @@ def create_batches(trajectory, num_batches, steps_completed):
current_list = []
if current_list:
consecutive_lists.append(current_list)
return orm.List(consecutive_lists)

# [[1,2,3],[4,5,6]] --> [[1],[2,3],[4,5,6]]
batches = [[consecutive_lists[0].pop(0)]]
for batch in consecutive_lists:
batches.append(batch)
return orm.List(batches)


class Cp2kRefTrajWorkChain(engine.WorkChain):
Expand Down Expand Up @@ -104,14 +134,12 @@ def setup(self):
def first_structure(self):
"""Run scf on the initial geometry."""
input_dict = deepcopy(self.ctx.input_dict)
batches = self.ctx.batches
first_snapshot = batches[0].pop(0)
self.ctx.batches = batches
batch = self.ctx.batches[0]

self.report(f"Running structure {first_snapshot} to {first_snapshot} ")
self.report(f"Running structure {batch[0]} to {batch[-1]} ")

input_dict["MOTION"]["MD"]["REFTRAJ"]["FIRST_SNAPSHOT"] = first_snapshot
input_dict["MOTION"]["MD"]["REFTRAJ"]["LAST_SNAPSHOT"] = first_snapshot
input_dict["MOTION"]["MD"]["REFTRAJ"]["FIRST_SNAPSHOT"] = batch[0]
input_dict["MOTION"]["MD"]["REFTRAJ"]["LAST_SNAPSHOT"] = batch[-1]

# create the input for the reftraj workchain
builder = Cp2kBaseWorkChain.get_builder()
Expand All @@ -122,22 +150,25 @@ def first_structure(self):
if "parent_calc_folder" in self.inputs:
builder.cp2k.parent_calc_folder = self.inputs.parent_calc_folder
builder.cp2k.metadata.options = self.inputs.options
builder.cp2k.metadata.label = f"structures_{first_snapshot}_to_{first_snapshot}"
builder.cp2k.metadata.label = f"structures_{batch[0]}_to_{batch[-1]}"
builder.cp2k.metadata.options.parser_name = "cp2k_advanced_parser"

builder.cp2k.parameters = orm.Dict(dict=input_dict)

future = self.submit(builder)
self.report(
f"Submitted structures {first_snapshot} to {first_snapshot}: {future.pk}"
)
self.to_context(first_structure=future)

key = f"reftraj_batch_{batch[0]}_to_{batch[-1]}"
self.report(f"Submitted reftraj batch: {key} with pk: {future.pk}")

self.to_context(**{key: future})

def run_reftraj_batches(self):
self.report(f"Running the reftraj batches {self.ctx.batches} ")
if not self.ctx.first_structure.is_finished_ok:
"""Check if all calculations completed and merge trejectories."""
key0 = f"reftraj_batch_{self.ctx.batches[0][0]}_to_{self.ctx.batches[0][0]}"
if not getattr(self.ctx, key0).is_finished_ok:
self.report(f"Batch {ke0y} failed")
return self.exit_codes.ERROR_TERMINATION
for batch in self.ctx.batches:
for batch in self.ctx.batches[1:]:
self.report(f"Running structures {batch[0]} to {batch[-1]} ")

# update the input_dict with the new batch
Expand All @@ -156,10 +187,8 @@ def run_reftraj_batches(self):
builder.cp2k.metadata.label = f"structures_{batch[0]}_to_{batch[-1]}"
builder.cp2k.metadata.options.parser_name = "cp2k_advanced_parser"
builder.cp2k.parameters = orm.Dict(dict=input_dict)
builder.cp2k.parent_calc_folder = (
self.ctx.first_structure.outputs.remote_folder
)

builder.cp2k.parent_calc_folder = getattr(self.ctx, key0).outputs.remote_folder

future = self.submit(builder)

key = f"reftraj_batch_{batch[0]}_to_{batch[-1]}"
Expand All @@ -173,34 +202,17 @@ def merge_batches_output(self):
# merged_traj = []
# for i_batch in range(self.ctx.n_batches):
# merged_traj.extend(self.ctx[f"reftraj_batch_{i_batch}"].outputs.trajectory)
positions = [
self.ctx.first_structure.outputs.output_trajectory.get_array("positions")
]
cells = [self.ctx.first_structure.outputs.output_trajectory.get_array("cells")]
forces = [
self.ctx.first_structure.outputs.output_trajectory.get_array("forces")
]
for batch in self.ctx.batches:


trajectories_to_merge=[getattr(self.ctx, f"reftraj_batch_{self.ctx.batches[0][0]}_to_{self.ctx.batches[0][0]}").outputs.output_trajectory]
for batch in self.ctx.batches[1:]:
key = f"reftraj_batch_{batch[0]}_to_{batch[-1]}"
if not getattr(self.ctx, key).is_finished_ok:
self.report(f"Batch {key} failed")
return self.exit_codes.ERROR_TERMINATION
positions.append(
getattr(self.ctx, key).outputs.output_trajectory.get_array("positions")
)
cells.append(
getattr(self.ctx, key).outputs.output_trajectory.get_array("cells")
)
forces.append(
getattr(self.ctx, key).outputs.output_trajectory.get_array("forces")
)

positions = np.concatenate(positions)
cells = np.concatenate(cells)
forces = np.concatenate(forces)
symbols = self.ctx.first_structure.outputs.output_trajectory.symbols
output_trajectory = TrajectoryData()
output_trajectory.set_trajectory(symbols, positions, cells=cells)
self.out("output_trajectory", output_trajectory)
trajectories_to_merge.append(getattr(self.ctx, key).outputs.output_trajectory)
merged_trajectory = merge_trajectories(*trajectories_to_merge)

self.out("output_trajectory", merged_trajectory)
self.report("done")
return engine.ExitCode(0)
9 changes: 4 additions & 5 deletions examples/workflows/example_cp2k_md_reftraj.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import random

import ase.io
import click
Expand All @@ -22,15 +21,15 @@ def _example_cp2k_reftraj(cp2k_code):
positions = np.array(
[
[
[2.52851027, 3.96611323, 3.75 + 0.05 * random.random()],
[2.52851027, 3.96611323, 3],
[2.528, 3.966, 3.75 + 0.0001 * i],
[2.528, 3.966, 3],
]
for i in range(steps)
]
)
cells = np.array(
[
[[5, 0, 0], [0, 5, 0], [0, 0, 5 + 0.05 * random.random()]]
[[5, 0, 0], [0, 5, 0], [0, 0, 5 + 0.0001 * i]]
for i in range(steps)
]
)
Expand All @@ -56,7 +55,7 @@ def _example_cp2k_reftraj(cp2k_code):

# builder.structure = structure
builder.trajectory = trajectory
builder.num_batches = orm.Int(2)
builder.num_batches = orm.Int(3)
builder.protocol = orm.Str("debug")

dft_params = {
Expand Down

0 comments on commit 1701a2d

Please sign in to comment.