diff --git a/src/aiida_quantumespresso/workflows/pw/base.py b/src/aiida_quantumespresso/workflows/pw/base.py index 4a3ab0381..56b03bb35 100644 --- a/src/aiida_quantumespresso/workflows/pw/base.py +++ b/src/aiida_quantumespresso/workflows/pw/base.py @@ -444,19 +444,13 @@ def handle_diagonalization_errors(self, calculation): def handle_out_of_walltime(self, calculation): """Handle `ERROR_OUT_OF_WALLTIME` exit code. - In this case the calculation shut down neatly and we can simply restart. We consider two cases: - - 1. If the structure is unchanged, we do a full restart. - 2. If the structure has changed during the calculation, we restart from scratch. + In this case the calculation shut down cleanly and we can do a full restart. """ - try: + if 'output_structure' in calculation.outputs: self.ctx.inputs.structure = calculation.outputs.output_structure - except exceptions.NotExistent: - self.set_restart_type(RestartType.FULL, calculation.outputs.remote_folder) - self.report_error_handled(calculation, 'simply restart from the last calculation') - else: - self.set_restart_type(RestartType.FROM_SCRATCH) - self.report_error_handled(calculation, 'out of walltime: structure changed so restarting from scratch') + + self.set_restart_type(RestartType.FULL, calculation.outputs.remote_folder) + self.report_error_handled(calculation, "restarting in full with `CONTROL.restart_mode` = 'restart'") return ProcessHandlerReport(True) diff --git a/tests/workflows/pw/test_base.py b/tests/workflows/pw/test_base.py index a1fd4b468..4c2b34e73 100644 --- a/tests/workflows/pw/test_base.py +++ b/tests/workflows/pw/test_base.py @@ -32,12 +32,25 @@ def test_handle_unrecoverable_failure(generate_workchain_pw): assert result == PwBaseWorkChain.exit_codes.ERROR_UNRECOVERABLE_FAILURE -def test_handle_out_of_walltime(generate_workchain_pw, fixture_localhost, generate_remote_data): +@pytest.mark.parametrize('structure_changed', ( + True, + False, +)) +def test_handle_out_of_walltime( + generate_workchain_pw, fixture_localhost, generate_remote_data, generate_structure, structure_changed +): """Test `PwBaseWorkChain.handle_out_of_walltime`.""" - remote_data = generate_remote_data(computer=fixture_localhost, remote_path='/path/to/remote') - process = generate_workchain_pw( - exit_code=PwCalculation.exit_codes.ERROR_OUT_OF_WALLTIME, pw_outputs={'remote_folder': remote_data} - ) + generate_inputs = { + 'exit_code': PwCalculation.exit_codes.ERROR_OUT_OF_WALLTIME, + 'pw_outputs': { + 'remote_folder': generate_remote_data(computer=fixture_localhost, remote_path='/path/to/remote') + } + } + if structure_changed: + output_structure = generate_structure() + generate_inputs['pw_outputs']['output_structure'] = output_structure + + process = generate_workchain_pw(**generate_inputs) process.setup() result = process.handle_electronic_convergence_not_reached(process.ctx.children[-1]) @@ -49,22 +62,8 @@ def test_handle_out_of_walltime(generate_workchain_pw, fixture_localhost, genera result = process.inspect_process() assert result.status == 0 - -def test_handle_out_of_walltime_structure_changed(generate_workchain_pw, generate_structure): - """Test `PwBaseWorkChain.handle_out_of_walltime`.""" - structure = generate_structure() - process = generate_workchain_pw( - exit_code=PwCalculation.exit_codes.ERROR_OUT_OF_WALLTIME, pw_outputs={'output_structure': structure} - ) - process.setup() - - result = process.handle_out_of_walltime(process.ctx.children[-1]) - assert isinstance(result, ProcessHandlerReport) - assert process.ctx.inputs.parameters['CONTROL']['restart_mode'] == 'from_scratch' - assert result.do_break - - result = process.inspect_process() - assert result.status == 0 + if structure_changed: + assert process.ctx.inputs.structure == output_structure def test_handle_electronic_convergence_not_reached(generate_workchain_pw, fixture_localhost, generate_remote_data):