Update on reftraj workchain (#176)
* The calculation of a trajectory is split into batches, with the total number of batches as close as possible to 1 + num_batches; the first batch always contains a single element (see below).
* Batches contain only consecutive stepIDs, for example `{0: [3], 1: [7, 8, 9], 2: [11], 3: [14, 15, 16]}`. The missing stepIDs are those already completed in a previous work chain.
* At most `num_batches` calculations are submitted at the same time.
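
The batching rule can be illustrated with a short, self-contained sketch (illustrative only: the real implementation is `create_batches` in reftraj_md_workchain.py below, which operates on AiiDA trajectory and `Int` nodes; `sketch_batches` is a hypothetical helper):

def sketch_batches(step_ids, num_batches):
    """Group the remaining stepIDs into runs of consecutive IDs.
    Each batch holds at most len(step_ids) // num_batches + 1 IDs
    (the maximum batch size used in the diff), and the first batch
    always holds a single ID."""
    if not step_ids:
        return {}
    max_size = len(step_ids) // num_batches + 1
    batches = [[step_ids[0]]]  # the first batch contains a single stepID
    for step in step_ids[1:]:
        last = batches[-1]
        # Open a new batch on a gap, on a full batch, or right after the first batch.
        if step != last[-1] + 1 or len(last) >= max_size or len(batches) == 1:
            batches.append([step])
        else:
            last.append(step)
    return dict(enumerate(batches))

# StepIDs still to run; the gaps are steps completed by a previous work chain.
print(sketch_batches([3, 7, 8, 9, 11, 14, 15, 16], num_batches=4))
# {0: [3], 1: [7, 8, 9], 2: [11], 3: [14, 15, 16]}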
cpignedoli authored Feb 11, 2025
1 parent de579d1 commit 3008b07
Showing 3 changed files with 213 additions and 88 deletions.
@@ -20,6 +20,9 @@ standard:
PRINT:
RESTART_HISTORY:
_: 'OFF'
RESTART:
EACH:
MD: '10'
FORCES:
EACH:
MD: '1'
@@ -103,6 +106,9 @@ low_accuracy:
PRINT:
RESTART_HISTORY:
_: 'OFF'
RESTART:
EACH:
MD: '10'
FORCES:
EACH:
MD: '1'
@@ -186,6 +192,9 @@ debug:
PRINT:
RESTART_HISTORY:
_: 'OFF'
RESTART:
EACH:
MD: '1'
FORCES:
EACH:
MD: '1'
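
All three protocol hunks above add the same restart-printing keys (a restart file every 10 MD steps for the standard and low_accuracy protocols, every step for debug). A minimal sketch of the resulting dict fragment, assuming these keys sit under the MOTION/PRINT section of the protocol file, as the neighbouring FORCES keys do (the nesting above the shown lines is not visible in the diff):

# Hypothetical rendering of the changed keys for the "standard" protocol.
restart_print = {
    "MOTION": {
        "PRINT": {
            "RESTART_HISTORY": {"_": "OFF"},    # keep no numbered restart history
            "RESTART": {"EACH": {"MD": "10"}},  # new: write the restart file every 10 MD steps
            "FORCES": {"EACH": {"MD": "1"}},    # forces are printed every MD step
        }
    }
}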
156 changes: 99 additions & 57 deletions aiida_nanotech_empa/workflows/cp2k/reftraj_md_workchain.py
@@ -69,7 +69,8 @@ def retireve_previous_trajectories(reftraj_wc):
cp2k_calcs = [
calc
for calc in base_wc.called_descendants
if calc.process_label == "Cp2kCalculation" and calc.is_finished_ok
if calc.process_label
== "Cp2kCalculation" # and calc.is_finished_ok
]
for calc in cp2k_calcs:
trajectories.append(calc.outputs.output_trajectory)
@@ -108,37 +109,45 @@ def merge_trajectories(*trajectories):

# @engine.calcfunction
def create_batches(trajectory, num_batches, steps_completed):
"""Create lists of consecutive integers. Counting start from 1 for CP2K input. The first list contains only one element."""

"""Create balanced batches of consecutive or isolated indices.
- The first batch contains only one element.
- Each batch has at most ceil(total_number_of_frames / num_batches) frames.
- The constraint of num_batches is overridden if necessary due to consecutive constraints.
"""
# Generate the initial list of indices
input_list = [i + 1 for i in range(trajectory.get_shape("positions")[0])]
for i in steps_completed:
input_list.remove(i)

# Remove steps that have already been completed
input_list = [i for i in input_list if i not in steps_completed]

if len(input_list) == 0:
return {}
# If there are fewer elements than num_batches + 1, return each element as a separate list
if len(input_list) < num_batches.value + 1:
return {i: [value] for i, value in enumerate(input_list)}

# Initialize the batches with the first batch containing only the first element
batches = [[input_list[0]]]

# Calculate the number of remaining elements to distribute among other batches
remaining_elements = input_list[1:]
total_remaining = len(remaining_elements)

# Calculate the minimum number of elements each batch must have
min_elements_per_batch = total_remaining // num_batches.value
extra_elements = (
total_remaining % num_batches.value
) # Determine how many batches will have an extra element

start_idx = 0
for i in range(num_batches.value):
# If there are extra elements, add one more to this batch
end_idx = start_idx + min_elements_per_batch + (1 if i < extra_elements else 0)
batches.append(remaining_elements[start_idx:end_idx])
start_idx = end_idx

total_elements = len(input_list)
max_batch_size = int((total_elements) / num_batches.value) + 1

batches = []
current_batch = [input_list[0]]

for i in range(1, len(input_list)):
current = input_list[i]
previous = input_list[i - 1]

# Start a new batch if:
# 1. There's a gap (non-consecutive)
# 2. The current batch has reached the max batch size
if current != previous + 1 or len(current_batch) >= max_batch_size:
batches.append(current_batch)
current_batch = [current]
else:
current_batch.append(current)

# Add the last batch
batches.append(current_batch)

# Ensure the first batch contains only one element
if len(batches[0]) > 1:
batches = [[batches[0][0]]] + [batches[0][1:]] + batches[1:]

return dict(enumerate(batches))

@@ -176,10 +185,16 @@ def define(cls, spec):

spec.outline(
cls.setup, # create batches, if reordering of structures create indexing
engine.if_(cls.something_to_run)(
engine.if_(cls.still_batches_to_run)(
cls.first_structure, # Run the first SCF to get the initial wavefunction
cls.check_submitted_batches,
),
engine.while_(
cls.still_batches_to_run
)( # Run the batches of the reftraj simulations
cls.run_reftraj_batches,
), # Run the batches of the reftraj simulations
cls.check_submitted_batches,
),
cls.merge_batches_output,
)

@@ -198,12 +213,16 @@ def setup(self):
last_wc = last_reftraj_wc(self.inputs.trajectory)
self.report(f"Restrating from last workchain found: {last_wc}")
previous_trajectories = retireve_previous_trajectories(last_wc)
self.report(
f"Retrieved {len(previous_trajectories)} trajectories {[traj.pk for traj in previous_trajectories]}"
)
self.ctx.previuos_trajectory = merge_trajectory_data_unique(
*previous_trajectories
)
self.ctx.steps_completed = (
self.ctx.previuos_trajectory.get_stepids().tolist()
)
self.report(f"Steps previously completed: {self.ctx.steps_completed}")

(
self.ctx.files,
@@ -215,25 +234,31 @@
"md_reftraj_protocol.yml",
self.inputs.protocol.value,
)
self.ctx.input_dict["GLOBAL"]["WALLTIME"] = max(
600, self.inputs.options["max_wallclock_seconds"] - 600
self.ctx.input_dict["GLOBAL"]["WALLTIME"] = (
self.inputs.options["max_wallclock_seconds"] - 600
if self.inputs.options["max_wallclock_seconds"] > 600
else self.inputs.options["max_wallclock_seconds"]
)
# create batches avoiding steps already completed.
self.ctx.something_to_run = False
self.ctx.batches_to_be_done = []
self.ctx.batches = create_batches(
self.inputs.trajectory, self.inputs.num_batches, self.ctx.steps_completed
)
if len(self.ctx.batches) > 0:
self.ctx.something_to_run = True
self.report(f"Created {len(self.ctx.batches)} batches {self.ctx.batches}")
self.ctx.n_batches = len(self.ctx.batches)
self.ctx.batches_to_be_done = [i for i in range(self.ctx.n_batches)]
self.ctx.batches_to_check = []
return engine.ExitCode(0)

def something_to_run(self):
"""Function that returnns whether to run or not soem batch"""
return self.ctx.something_to_run
def still_batches_to_run(self):
"""Check if there are still batches to run."""
return len(self.ctx.batches_to_be_done) > 0

def first_structure(self):
"""Run scf on the initial geometry."""
self.ctx.batches_to_be_done.remove(0)
self.ctx.batches_to_check.append(0)
input_dict = deepcopy(self.ctx.input_dict)
batch = self.ctx.batches[0]

@@ -244,6 +269,10 @@ def first_structure(self):

# create the input for the reftraj workchain
builder = Cp2kBaseWorkChain.get_builder()
# Switch on resubmit_unconverged_geometry disabled by default.
builder.handler_overrides = orm.Dict(
{"restart_incomplete_calculation": {"enabled": True}}
)
builder.cp2k.structure = orm.StructureData(ase=self.ctx.structure_with_tags)
builder.cp2k.trajectory = self.inputs.trajectory
builder.cp2k.code = self.inputs.code
@@ -263,25 +292,46 @@

self.to_context(**{key: future})

def check_submitted_batches(self):
"""Check if submitted calculations completed."""
for batch in self.ctx.batches_to_check:
key = f"reftraj_batch_{batch}"
if not getattr(self.ctx, key).is_finished_ok:
self.report(f"Batch {key} failed")
return self.exit_codes.ERROR_TERMINATION
self.ctx.batches_to_check.remove(batch)
return engine.ExitCode(0)

def run_reftraj_batches(self):
"""Check if all calculations completed and merge trejectories."""
key0 = "reftraj_batch_0"
if not getattr(self.ctx, key0).is_finished_ok:
self.report(f"Batch {key0} failed")
return self.exit_codes.ERROR_TERMINATION
for batch in range(1, self.ctx.n_batches):
"""Submit remaining batches up to the maximum number of concurrent calculations."""
# Submit new batches if under the limit
count = 0
n_to_submit = min(
len(self.ctx.batches_to_be_done), self.inputs.num_batches.value
)
while count < n_to_submit:
self.report(f"batches to be done {self.ctx.batches_to_be_done}")
batch = self.ctx.batches_to_be_done[0]
self.ctx.batches_to_be_done.remove(batch)

count += 1

self.ctx.batches_to_check.append(batch)
first = self.ctx.batches[batch][0]
last = self.ctx.batches[batch][-1]
self.report(f"Running structures {first} to {last} ")
self.report(f"Running structures {first} to {last}")

# update the input_dict with the new batch
# Update the input_dict with the new batch
input_dict = deepcopy(self.ctx.input_dict)
input_dict["MOTION"]["MD"]["STEPS"] = 1 + first - last
input_dict["MOTION"]["MD"]["STEPS"] = 1 + last - first
input_dict["MOTION"]["MD"]["REFTRAJ"]["FIRST_SNAPSHOT"] = first
input_dict["MOTION"]["MD"]["REFTRAJ"]["LAST_SNAPSHOT"] = last

# create the input for the reftraj workchain
# Create the input for the reftraj workchain
builder = Cp2kBaseWorkChain.get_builder()
builder.handler_overrides = orm.Dict(
{"restart_incomplete_calculation": {"enabled": True}}
)
builder.cp2k.structure = orm.StructureData(ase=self.ctx.structure_with_tags)
builder.cp2k.trajectory = self.inputs.trajectory
builder.cp2k.code = self.inputs.code
@@ -291,29 +341,21 @@ def run_reftraj_batches(self):
builder.cp2k.metadata.options.parser_name = "cp2k_advanced_parser"
builder.cp2k.parameters = orm.Dict(dict=input_dict)
builder.cp2k.parent_calc_folder = getattr(
self.ctx, key0
self.ctx, "reftraj_batch_0"
).outputs.remote_folder

future = self.submit(builder)

key = f"reftraj_batch_{batch}"
self.report(f"Submitted reftraj batch: {key} with pk: {future.pk}")

self.to_context(**{key: future})

def merge_batches_output(self):
"""Merge the output of the succefull batches only."""

# merged_traj = []
# for i_batch in range(self.ctx.n_batches):
# merged_traj.extend(self.ctx[f"reftraj_batch_{i_batch}"].outputs.trajectory)
"""Merge the output of the batches."""

trajectories_to_merge = []
for batch in self.ctx.batches:
key = f"reftraj_batch_{batch}"
if not getattr(self.ctx, key).is_finished_ok:
self.report(f"Batch {key} failed")
return self.exit_codes.ERROR_TERMINATION
trajectories_to_merge.append(
getattr(self.ctx, key).outputs.output_trajectory
)
