diff --git a/aiida_nanotech_empa/workflows/cp2k/reftraj_md_workchain.py b/aiida_nanotech_empa/workflows/cp2k/reftraj_md_workchain.py index a5b01cd..b6e0da8 100644 --- a/aiida_nanotech_empa/workflows/cp2k/reftraj_md_workchain.py +++ b/aiida_nanotech_empa/workflows/cp2k/reftraj_md_workchain.py @@ -10,6 +10,31 @@ # Cp2kRefTrajWorkChain = plugins.WorkflowFactory("cp2k.reftraj") TrajectoryData = plugins.DataFactory("array.trajectory") +@engine.calcfunction +def merge_trajectories(*trajectories): + """Merge a list of trajectories into a single one.""" + positions=[] + cells=[] + forces=[] + for trajectory in trajectories: + positions.append(trajectory.get_array("positions") ) + try: + cells.append(trajectory.get_array("cells")) + except KeyError: + pass + forces.append(trajectory.get_array("forces")) + symbols = trajectory.symbols + positions = np.concatenate(positions) + cells = np.concatenate(cells) + forces = np.concatenate(forces) + merged_trajectory = TrajectoryData() + if len(cells) == 0: + merged_trajectory.set_trajectory(symbols, positions) + else: + merged_trajectory.set_trajectory(symbols, positions, cells=cells) + merged_trajectory.set_array("forces", forces) + + return merged_trajectory @engine.calcfunction def create_batches(trajectory, num_batches, steps_completed): @@ -31,7 +56,12 @@ def create_batches(trajectory, num_batches, steps_completed): current_list = [] if current_list: consecutive_lists.append(current_list) - return orm.List(consecutive_lists) + + # [[1,2,3],[4,5,6]] --> [[1],[2,3],[4,5,6]] + batches = [[consecutive_lists[0].pop(0)]] + for batch in consecutive_lists: + batches.append(batch) + return orm.List(batches) class Cp2kRefTrajWorkChain(engine.WorkChain): @@ -104,14 +134,12 @@ def setup(self): def first_structure(self): """Run scf on the initial geometry.""" input_dict = deepcopy(self.ctx.input_dict) - batches = self.ctx.batches - first_snapshot = batches[0].pop(0) - self.ctx.batches = batches + batch = self.ctx.batches[0] - self.report(f"Running structure {first_snapshot} to {first_snapshot} ") + self.report(f"Running structure {batch[0]} to {batch[-1]} ") - input_dict["MOTION"]["MD"]["REFTRAJ"]["FIRST_SNAPSHOT"] = first_snapshot - input_dict["MOTION"]["MD"]["REFTRAJ"]["LAST_SNAPSHOT"] = first_snapshot + input_dict["MOTION"]["MD"]["REFTRAJ"]["FIRST_SNAPSHOT"] = batch[0] + input_dict["MOTION"]["MD"]["REFTRAJ"]["LAST_SNAPSHOT"] = batch[-1] # create the input for the reftraj workchain builder = Cp2kBaseWorkChain.get_builder() @@ -122,22 +150,25 @@ def first_structure(self): if "parent_calc_folder" in self.inputs: builder.cp2k.parent_calc_folder = self.inputs.parent_calc_folder builder.cp2k.metadata.options = self.inputs.options - builder.cp2k.metadata.label = f"structures_{first_snapshot}_to_{first_snapshot}" + builder.cp2k.metadata.label = f"structures_{batch[0]}_to_{batch[-1]}" builder.cp2k.metadata.options.parser_name = "cp2k_advanced_parser" builder.cp2k.parameters = orm.Dict(dict=input_dict) future = self.submit(builder) - self.report( - f"Submitted structures {first_snapshot} to {first_snapshot}: {future.pk}" - ) - self.to_context(first_structure=future) + + key = f"reftraj_batch_{batch[0]}_to_{batch[-1]}" + self.report(f"Submitted reftraj batch: {key} with pk: {future.pk}") + + self.to_context(**{key: future}) def run_reftraj_batches(self): - self.report(f"Running the reftraj batches {self.ctx.batches} ") - if not self.ctx.first_structure.is_finished_ok: + """Check if all calculations completed and merge trejectories.""" + key0 = f"reftraj_batch_{self.ctx.batches[0][0]}_to_{self.ctx.batches[0][0]}" + if not getattr(self.ctx, key0).is_finished_ok: + self.report(f"Batch {ke0y} failed") return self.exit_codes.ERROR_TERMINATION - for batch in self.ctx.batches: + for batch in self.ctx.batches[1:]: self.report(f"Running structures {batch[0]} to {batch[-1]} ") # update the input_dict with the new batch @@ -156,10 +187,8 @@ def run_reftraj_batches(self): builder.cp2k.metadata.label = f"structures_{batch[0]}_to_{batch[-1]}" builder.cp2k.metadata.options.parser_name = "cp2k_advanced_parser" builder.cp2k.parameters = orm.Dict(dict=input_dict) - builder.cp2k.parent_calc_folder = ( - self.ctx.first_structure.outputs.remote_folder - ) - + builder.cp2k.parent_calc_folder = getattr(self.ctx, key0).outputs.remote_folder + future = self.submit(builder) key = f"reftraj_batch_{batch[0]}_to_{batch[-1]}" @@ -173,34 +202,17 @@ def merge_batches_output(self): # merged_traj = [] # for i_batch in range(self.ctx.n_batches): # merged_traj.extend(self.ctx[f"reftraj_batch_{i_batch}"].outputs.trajectory) - positions = [ - self.ctx.first_structure.outputs.output_trajectory.get_array("positions") - ] - cells = [self.ctx.first_structure.outputs.output_trajectory.get_array("cells")] - forces = [ - self.ctx.first_structure.outputs.output_trajectory.get_array("forces") - ] - for batch in self.ctx.batches: + + + trajectories_to_merge=[getattr(self.ctx, f"reftraj_batch_{self.ctx.batches[0][0]}_to_{self.ctx.batches[0][0]}").outputs.output_trajectory] + for batch in self.ctx.batches[1:]: key = f"reftraj_batch_{batch[0]}_to_{batch[-1]}" if not getattr(self.ctx, key).is_finished_ok: self.report(f"Batch {key} failed") return self.exit_codes.ERROR_TERMINATION - positions.append( - getattr(self.ctx, key).outputs.output_trajectory.get_array("positions") - ) - cells.append( - getattr(self.ctx, key).outputs.output_trajectory.get_array("cells") - ) - forces.append( - getattr(self.ctx, key).outputs.output_trajectory.get_array("forces") - ) - - positions = np.concatenate(positions) - cells = np.concatenate(cells) - forces = np.concatenate(forces) - symbols = self.ctx.first_structure.outputs.output_trajectory.symbols - output_trajectory = TrajectoryData() - output_trajectory.set_trajectory(symbols, positions, cells=cells) - self.out("output_trajectory", output_trajectory) + trajectories_to_merge.append(getattr(self.ctx, key).outputs.output_trajectory) + merged_trajectory = merge_trajectories(*trajectories_to_merge) + + self.out("output_trajectory", merged_trajectory) self.report("done") return engine.ExitCode(0) diff --git a/examples/workflows/example_cp2k_md_reftraj.py b/examples/workflows/example_cp2k_md_reftraj.py index 6987430..ec9b17f 100644 --- a/examples/workflows/example_cp2k_md_reftraj.py +++ b/examples/workflows/example_cp2k_md_reftraj.py @@ -1,5 +1,4 @@ import os -import random import ase.io import click @@ -22,15 +21,15 @@ def _example_cp2k_reftraj(cp2k_code): positions = np.array( [ [ - [2.52851027, 3.96611323, 3.75 + 0.05 * random.random()], - [2.52851027, 3.96611323, 3], + [2.528, 3.966, 3.75 + 0.0001 * i], + [2.528, 3.966, 3], ] for i in range(steps) ] ) cells = np.array( [ - [[5, 0, 0], [0, 5, 0], [0, 0, 5 + 0.05 * random.random()]] + [[5, 0, 0], [0, 5, 0], [0, 0, 5 + 0.0001 * i]] for i in range(steps) ] ) @@ -56,7 +55,7 @@ def _example_cp2k_reftraj(cp2k_code): # builder.structure = structure builder.trajectory = trajectory - builder.num_batches = orm.Int(2) + builder.num_batches = orm.Int(3) builder.protocol = orm.Str("debug") dft_params = {