diff --git a/cwltool/argparser.py b/cwltool/argparser.py index 7b3125d94..2638b04f9 100644 --- a/cwltool/argparser.py +++ b/cwltool/argparser.py @@ -596,10 +596,11 @@ def arg_parser() -> argparse.ArgumentParser: parser.add_argument( "--on-error", help="Desired workflow behavior when a step fails. One of 'stop' (do " - "not submit any more steps) or 'continue' (may submit other steps that " - "are not downstream from the error). Default is 'stop'.", + "not submit any more steps), 'continue' (may submit other steps that " + "are not downstream from the error), or 'kill' (same as 'stop', but also " + "terminates running jobs in the active step(s)). Default is 'stop'.", default="stop", - choices=("stop", "continue"), + choices=("stop", "continue", "kill"), ) checkgroup = parser.add_mutually_exclusive_group() diff --git a/cwltool/context.py b/cwltool/context.py index 237a90968..a5bdcd99b 100644 --- a/cwltool/context.py +++ b/cwltool/context.py @@ -172,7 +172,7 @@ def __init__(self, kwargs: Optional[dict[str, Any]] = None) -> None: self.select_resources: Optional[select_resources_callable] = None self.eval_timeout: float = 60 self.postScatterEval: Optional[Callable[[CWLObjectType], Optional[CWLObjectType]]] = None - self.on_error: Union[Literal["stop"], Literal["continue"]] = "stop" + self.on_error: Union[Literal["stop"], Literal["continue"], Literal["kill"]] = "stop" self.strict_memory_limit: bool = False self.strict_cpu_limit: bool = False self.cidfile_dir: Optional[str] = None @@ -189,6 +189,7 @@ def __init__(self, kwargs: Optional[dict[str, Any]] = None) -> None: self.default_stderr: Optional[Union[IO[bytes], TextIO]] = None self.validate_only: bool = False self.validate_stdout: Optional["SupportsWrite[str]"] = None + self.kill_switch: Optional[threading.Event] = None super().__init__(kwargs) if self.tmp_outdir_prefix == "": self.tmp_outdir_prefix = self.tmpdir_prefix diff --git a/cwltool/errors.py b/cwltool/errors.py index 045b9b383..1eb92d21c 100644 --- a/cwltool/errors.py +++ b/cwltool/errors.py @@ -18,3 +18,16 @@ class UnsupportedRequirement(WorkflowException): class ArgumentException(Exception): """Mismatched command line arguments provided.""" + + +class WorkflowKillSwitch(Exception): + """When processStatus != "success" and on-error=kill, raise this exception.""" + + def __init__(self, job_id: str, rcode: int) -> None: + """Record the job identifier and the error code.""" + self.job_id = job_id + self.rcode = rcode + + def __str__(self) -> str: + """Represent this exception as a string.""" + return f"[job {self.job_id}] activated kill switch with return code {self.rcode}" diff --git a/cwltool/executors.py b/cwltool/executors.py index e25426c9d..2a0f69a9c 100644 --- a/cwltool/executors.py +++ b/cwltool/executors.py @@ -20,7 +20,7 @@ from .context import RuntimeContext, getdefault from .cuda import cuda_version_and_device_count from .cwlprov.provenance_profile import ProvenanceProfile -from .errors import WorkflowException +from .errors import WorkflowException, WorkflowKillSwitch from .job import JobBase from .loghandler import _logger from .mutation import MutationManager @@ -102,6 +102,7 @@ def check_for_abstract_op(tool: CWLObjectType) -> None: runtime_context.mutation_manager = MutationManager() runtime_context.toplevel = True runtime_context.workflow_eval_lock = threading.Condition(threading.RLock()) + runtime_context.kill_switch = threading.Event() job_reqs: Optional[list[CWLObjectType]] = None if "https://w3id.org/cwl/cwl#requirements" in job_order_object: @@ -251,6 +252,11 @@ 
def run_jobs( WorkflowException, ): # pylint: disable=try-except-raise raise + except WorkflowKillSwitch as err: + _logger.error( + f"Workflow kill switch activated by [job {err.job_id}] " + f"because on-error={runtime_context.on_error}" + ) except Exception as err: logger.exception("Got workflow error") raise WorkflowException(str(err)) from err @@ -323,6 +329,11 @@ def _runner( except WorkflowException as err: _logger.exception(f"Got workflow error: {err}") self.exceptions.append(err) + except WorkflowKillSwitch as err: + _logger.error( + f"Workflow kill switch activated by [job {err.job_id}] " + f"because on-error={runtime_context.on_error}" + ) except Exception as err: # pylint: disable=broad-except _logger.exception(f"Got workflow error: {err}") self.exceptions.append(WorkflowException(str(err))) @@ -429,7 +440,13 @@ def run_jobs( logger: logging.Logger, runtime_context: RuntimeContext, ) -> None: - self.taskqueue: TaskQueue = TaskQueue(threading.Lock(), int(math.ceil(self.max_cores))) + if runtime_context.kill_switch is None: + runtime_context.kill_switch = threading.Event() + + self.taskqueue: TaskQueue = TaskQueue( + threading.Lock(), int(math.ceil(self.max_cores)), runtime_context.kill_switch + ) + try: jobiter = process.job(job_order_object, self.output_callback, runtime_context) @@ -457,9 +474,9 @@ def run_jobs( while self.taskqueue.in_flight > 0: self.wait_for_next_completion(runtime_context) self.run_job(None, runtime_context) - - runtime_context.workflow_eval_lock.release() finally: + if (lock := runtime_context.workflow_eval_lock) is not None: + lock.release() self.taskqueue.drain() self.taskqueue.join() diff --git a/cwltool/job.py b/cwltool/job.py index b360be25f..ac5507a7e 100644 --- a/cwltool/job.py +++ b/cwltool/job.py @@ -30,7 +30,7 @@ from .builder import Builder from .context import RuntimeContext from .cuda import cuda_check -from .errors import UnsupportedRequirement, WorkflowException +from .errors import UnsupportedRequirement, WorkflowException, WorkflowKillSwitch from .loghandler import _logger from .pathmapper import MapperEnt, PathMapper from .process import stage_files @@ -217,7 +217,9 @@ def _execute( runtime: list[str], env: MutableMapping[str, str], runtimeContext: RuntimeContext, - monitor_function: Optional[Callable[["subprocess.Popen[str]"], None]] = None, + monitor_function: Optional[ + Callable[["subprocess.Popen[str]", threading.Event], None] + ] = None, ) -> None: """Execute the tool, either directly or via script. 
@@ -282,6 +284,7 @@ def _execute( "{}".format(runtimeContext) ) outputs: CWLObjectType = {} + processStatus = "indeterminate" try: stdin_path = None if self.stdin is not None: @@ -319,6 +322,10 @@ def stderr_stdout_log_path( builder: Optional[Builder] = getattr(self, "builder", None) if builder is not None: job_script_contents = builder.build_job_script(commands) + if runtimeContext.kill_switch is None: + runtimeContext.kill_switch = kill_switch = threading.Event() + else: + kill_switch = runtimeContext.kill_switch rcode = _job_popen( commands, stdin_path=stdin_path, @@ -327,6 +334,7 @@ def stderr_stdout_log_path( env=env, cwd=self.outdir, make_job_dir=lambda: runtimeContext.create_outdir(), + kill_switch=kill_switch, job_script_contents=job_script_contents, timelimit=self.timelimit, name=self.name, @@ -347,7 +355,10 @@ def stderr_stdout_log_path( processStatus = "permanentFail" if processStatus != "success": - if rcode < 0: + if runtimeContext.kill_switch.is_set(): + processStatus = "killed" + return + elif rcode < 0: _logger.warning( "[job %s] was terminated by signal: %s", self.name, @@ -355,6 +366,9 @@ def stderr_stdout_log_path( ) else: _logger.warning("[job %s] exited with status: %d", self.name, rcode) + if runtimeContext.on_error == "kill": + runtimeContext.kill_switch.set() + raise WorkflowKillSwitch(self.name, rcode) if "listing" in self.generatefiles: if self.generatemapper: @@ -385,61 +399,69 @@ def stderr_stdout_log_path( except WorkflowException as err: _logger.error("[job %s] Job error:\n%s", self.name, str(err)) processStatus = "permanentFail" + except WorkflowKillSwitch: + processStatus = "permanentFail" + raise except Exception: _logger.exception("Exception while running job") processStatus = "permanentFail" - if ( - runtimeContext.research_obj is not None - and self.prov_obj is not None - and runtimeContext.process_run_id is not None - ): - # creating entities for the outputs produced by each step (in the provenance document) - self.prov_obj.record_process_end( - str(self.name), - runtimeContext.process_run_id, - outputs, - datetime.datetime.now(), - ) - if processStatus != "success": - _logger.warning("[job %s] completed %s", self.name, processStatus) - else: - _logger.info("[job %s] completed %s", self.name, processStatus) - - if _logger.isEnabledFor(logging.DEBUG): - _logger.debug("[job %s] outputs %s", self.name, json_dumps(outputs, indent=4)) - - if self.generatemapper is not None and runtimeContext.secret_store is not None: - # Delete any runtime-generated files containing secrets. 
- for _, p in self.generatemapper.items(): - if p.type == "CreateFile": - if runtimeContext.secret_store.has_secret(p.resolved): - host_outdir = self.outdir - container_outdir = self.builder.outdir - host_outdir_tgt = p.target - if p.target.startswith(container_outdir + "/"): - host_outdir_tgt = os.path.join( - host_outdir, p.target[len(container_outdir) + 1 :] - ) - os.remove(host_outdir_tgt) - - if runtimeContext.workflow_eval_lock is None: - raise WorkflowException("runtimeContext.workflow_eval_lock must not be None") - - if self.output_callback: - with runtimeContext.workflow_eval_lock: - self.output_callback(outputs, processStatus) - - if runtimeContext.rm_tmpdir and self.stagedir is not None and os.path.exists(self.stagedir): - _logger.debug( - "[job %s] Removing input staging directory %s", - self.name, - self.stagedir, - ) - shutil.rmtree(self.stagedir, True) + finally: + if ( + runtimeContext.research_obj is not None + and self.prov_obj is not None + and runtimeContext.process_run_id is not None + ): + # creating entities for the outputs produced by each step (in the provenance document) + self.prov_obj.record_process_end( + str(self.name), + runtimeContext.process_run_id, + outputs, + datetime.datetime.now(), + ) + if processStatus != "success": + _logger.warning("[job %s] completed %s", self.name, processStatus) + else: + _logger.info("[job %s] completed %s", self.name, processStatus) - if runtimeContext.rm_tmpdir: - _logger.debug("[job %s] Removing temporary directory %s", self.name, self.tmpdir) - shutil.rmtree(self.tmpdir, True) + if _logger.isEnabledFor(logging.DEBUG): + _logger.debug("[job %s] outputs %s", self.name, json_dumps(outputs, indent=4)) + + if self.generatemapper is not None and runtimeContext.secret_store is not None: + # Delete any runtime-generated files containing secrets. 
+ for _, p in self.generatemapper.items(): + if p.type == "CreateFile": + if runtimeContext.secret_store.has_secret(p.resolved): + host_outdir = self.outdir + container_outdir = self.builder.outdir + host_outdir_tgt = p.target + if p.target.startswith(container_outdir + "/"): + host_outdir_tgt = os.path.join( + host_outdir, p.target[len(container_outdir) + 1 :] + ) + os.remove(host_outdir_tgt) + + if runtimeContext.workflow_eval_lock is None: + raise WorkflowException("runtimeContext.workflow_eval_lock must not be None") + + if self.output_callback: + with runtimeContext.workflow_eval_lock: + self.output_callback(outputs, processStatus) + + if ( + runtimeContext.rm_tmpdir + and self.stagedir is not None + and os.path.exists(self.stagedir) + ): + _logger.debug( + "[job %s] Removing input staging directory %s", + self.name, + self.stagedir, + ) + shutil.rmtree(self.stagedir, True) + + if runtimeContext.rm_tmpdir: + _logger.debug("[job %s] Removing temporary directory %s", self.name, self.tmpdir) + shutil.rmtree(self.tmpdir, True) @abstractmethod def _required_env(self) -> dict[str, str]: @@ -492,13 +514,14 @@ def prepare_environment( # Set on ourselves self.environment = env - def process_monitor(self, sproc: "subprocess.Popen[str]") -> None: - """Watch a process, logging its max memory usage.""" + def process_monitor(self, sproc: "subprocess.Popen[str]", kill_switch: threading.Event) -> None: + """Watch a process, logging its max memory usage or terminating it if kill_switch is activated.""" monitor = psutil.Process(sproc.pid) # Value must be list rather than integer to utilise pass-by-reference in python memory_usage: MutableSequence[Optional[int]] = [None] mem_tm: "Optional[Timer]" = None + ks_tm: "Optional[Timer]" = None def get_tree_mem_usage(memory_usage: MutableSequence[Optional[int]]) -> None: nonlocal mem_tm @@ -520,10 +543,28 @@ def get_tree_mem_usage(memory_usage: MutableSequence[Optional[int]]) -> None: if mem_tm is not None: mem_tm.cancel() + def monitor_kill_switch() -> None: + nonlocal ks_tm + if kill_switch.is_set(): + _logger.error("[job %s] terminating by kill switch", self.name) + if sproc.stdin: + sproc.stdin.close() + sproc.terminate() + else: + ks_tm = Timer(interval=1, function=monitor_kill_switch) + ks_tm.daemon = True + ks_tm.start() + + ks_tm = Timer(interval=1, function=monitor_kill_switch) + ks_tm.daemon = True + ks_tm.start() + mem_tm = Timer(interval=1, function=get_tree_mem_usage, args=(memory_usage,)) mem_tm.daemon = True mem_tm.start() + sproc.wait() + ks_tm.cancel() mem_tm.cancel() if memory_usage[0] is not None: _logger.info( @@ -835,14 +876,41 @@ def docker_monitor( cleanup_cidfile: bool, docker_exe: str, process: "subprocess.Popen[str]", + kill_switch: threading.Event, ) -> None: - """Record memory usage of the running Docker container.""" + """Record memory usage of the running Docker container. 
Terminate if kill_switch is activated.""" + ks_tm: "Optional[Timer]" = None + cid: Optional[str] = None + + def monitor_kill_switch() -> None: + nonlocal ks_tm + if kill_switch.is_set(): + _logger.error("[job %s] terminating by kill switch", self.name) + if process.stdin: + process.stdin.close() + if cid is not None: + kill_proc = subprocess.Popen( # nosec + [docker_exe, "kill", cid], shell=False # nosec + ) + try: + kill_proc.wait(timeout=10) + except subprocess.TimeoutExpired: + kill_proc.kill() + process.terminate() # Always terminate, even if we tried with the cidfile + else: + ks_tm = Timer(interval=1, function=monitor_kill_switch) + ks_tm.daemon = True + ks_tm.start() + + ks_tm = Timer(interval=1, function=monitor_kill_switch) + ks_tm.daemon = True + ks_tm.start() + # Todo: consider switching to `docker create` / `docker start` # instead of `docker run` as `docker create` outputs the container ID # to stdout, but the container is frozen, thus allowing us to start the # monitoring process without dealing with the cidfile or too-fast # container execution - cid: Optional[str] = None while cid is None: time.sleep(1) # This is needed to avoid a race condition where the job @@ -850,6 +918,7 @@ def docker_monitor( if process.returncode is None: process.poll() if process.returncode is not None: + ks_tm.cancel() if cleanup_cidfile: try: os.remove(cidfile) @@ -881,6 +950,9 @@ def docker_monitor( except OSError as exc: _logger.warning("Ignored error with %s stats: %s", docker_exe, exc) return + finally: + ks_tm.cancel() + max_mem_percent: float = 0.0 mem_percent: float = 0.0 with open(stats_file_name) as stats: @@ -911,10 +983,11 @@ def _job_popen( env: Mapping[str, str], cwd: str, make_job_dir: Callable[[], str], + kill_switch: threading.Event, job_script_contents: Optional[str] = None, timelimit: Optional[int] = None, name: Optional[str] = None, - monitor_function: Optional[Callable[["subprocess.Popen[str]"], None]] = None, + monitor_function: Optional[Callable[["subprocess.Popen[str]", "threading.Event"], None]] = None, default_stdout: Optional[Union[IO[bytes], TextIO]] = None, default_stderr: Optional[Union[IO[bytes], TextIO]] = None, ) -> int: @@ -969,7 +1042,7 @@ def terminate(): # type: () -> None tm.daemon = True tm.start() if monitor_function: - monitor_function(sproc) + monitor_function(sproc, kill_switch) rcode = sproc.wait() if tm is not None: @@ -1045,7 +1118,7 @@ def terminate(): # type: () -> None tm.daemon = True tm.start() if monitor_function: - monitor_function(sproc) + monitor_function(sproc, kill_switch) rcode = sproc.wait() diff --git a/cwltool/task_queue.py b/cwltool/task_queue.py index 59b1609e9..7606cf369 100644 --- a/cwltool/task_queue.py +++ b/cwltool/task_queue.py @@ -7,6 +7,7 @@ import threading from typing import Callable, Optional +from .errors import WorkflowKillSwitch from .loghandler import _logger @@ -33,7 +34,7 @@ class TaskQueue: in_flight: int = 0 """The number of tasks in the queue.""" - def __init__(self, lock: threading.Lock, thread_count: int): + def __init__(self, lock: threading.Lock, thread_count: int, kill_switch: threading.Event): """Create a new task queue using the specified lock and number of threads.""" self.thread_count = thread_count self.task_queue: queue.Queue[Optional[Callable[[], None]]] = queue.Queue( @@ -42,6 +43,7 @@ def __init__(self, lock: threading.Lock, thread_count: int): self.task_queue_threads = [] self.lock = lock self.error: Optional[BaseException] = None + self.kill_switch = kill_switch for _r in range(0, self.thread_count): 
t = threading.Thread(target=self._task_queue_func) @@ -51,10 +53,14 @@ def __init__(self, lock: threading.Lock, thread_count: int): def _task_queue_func(self) -> None: while True: task = self.task_queue.get() - if task is None: + if task is None or self.kill_switch.is_set(): return try: task() + except WorkflowKillSwitch: + self.kill_switch.set() + self.drain() + break except BaseException as e: # noqa: B036 _logger.exception("Unhandled exception running task", exc_info=e) self.error = e @@ -92,7 +98,7 @@ def add( try: if unlock is not None: unlock.release() - if check_done is not None and check_done.is_set(): + if (check_done is not None and check_done.is_set()) or self.kill_switch.is_set(): with self.lock: self.in_flight -= 1 return diff --git a/cwltool/workflow.py b/cwltool/workflow.py index 3bf32251f..8898a489b 100644 --- a/cwltool/workflow.py +++ b/cwltool/workflow.py @@ -2,6 +2,7 @@ import datetime import functools import logging +import os import random from collections.abc import Mapping, MutableMapping, MutableSequence from typing import Callable, Optional, cast @@ -401,12 +402,13 @@ def receive_output( processStatus: str, ) -> None: output = {} - for i in self.tool["outputs"]: - field = shortname(i["id"]) - if field in jobout: - output[i["id"]] = jobout[field] - else: - processStatus = "permanentFail" + if processStatus != "killed": + for i in self.tool["outputs"]: + field = shortname(i["id"]) + if field in jobout: + output[i["id"]] = jobout[field] + else: + processStatus = "permanentFail" output_callback(output, processStatus) def job( @@ -451,3 +453,13 @@ def job( def visit(self, op: Callable[[CommentedMap], None]) -> None: self.embedded_tool.visit(op) + + def __repr__(self) -> str: + """Return a non-expression string representation of the object instance.""" + if "#" in self.id: + wf_file, step_id = self.id.rsplit("#", 1) + step_name = "#".join([os.path.basename(wf_file), step_id]) + else: + step_name = self.id + + return f"<{self.__class__.__name__} [{step_name}] at {hex(id(self))}>" diff --git a/cwltool/workflow_job.py b/cwltool/workflow_job.py index 6cd0b2e7c..70df00277 100644 --- a/cwltool/workflow_job.py +++ b/cwltool/workflow_job.py @@ -66,6 +66,10 @@ def job( yield from self.step.job(joborder, output_callback, runtimeContext) + def __repr__(self) -> str: + """Return a non-expression string representation of the object instance.""" + return f"<{self.__class__.__name__} [{self.name}] at {hex(id(self))}>" + class ReceiveScatterOutput: """Produced by the scatter generators.""" @@ -89,7 +93,9 @@ def completed(self) -> int: """The number of completed internal jobs.""" return len(self._completed) - def receive_scatter_output(self, index: int, jobout: CWLObjectType, processStatus: str) -> None: + def receive_scatter_output( + self, index: int, runtimeContext: RuntimeContext, jobout: CWLObjectType, processStatus: str + ) -> None: """Record the results of a scatter operation.""" for key, val in jobout.items(): self.dest[key][index] = val @@ -102,6 +108,8 @@ def receive_scatter_output(self, index: int, jobout: CWLObjectType, processStatu if processStatus != "success": if self.processStatus != "permanentFail": self.processStatus = processStatus + if runtimeContext.on_error == "kill": + self.output_callback(self.dest, self.processStatus) if index not in self._completed: self._completed.add(index) @@ -130,10 +138,11 @@ def parallel_steps( rc: ReceiveScatterOutput, runtimeContext: RuntimeContext, ) -> JobsGeneratorType: + """Yield scatter jobs (or None if there's no work to do) until 
all scatter jobs complete.""" while rc.completed < rc.total: made_progress = False for index, step in enumerate(steps): - if getdefault(runtimeContext.on_error, "stop") == "stop" and rc.processStatus not in ( + if runtimeContext.on_error != "continue" and rc.processStatus not in ( "success", "skipped", ): @@ -142,9 +151,10 @@ def parallel_steps( continue try: for j in step: - if getdefault( - runtimeContext.on_error, "stop" - ) == "stop" and rc.processStatus not in ("success", "skipped"): + if runtimeContext.on_error != "continue" and rc.processStatus not in ( + "success", + "skipped", + ): break if j is not None: made_progress = True @@ -156,7 +166,7 @@ def parallel_steps( except WorkflowException as exc: _logger.error("Cannot make scatter job: %s", str(exc)) _logger.debug("", exc_info=True) - rc.receive_scatter_output(index, {}, "permanentFail") + rc.receive_scatter_output(index, runtimeContext, {}, "permanentFail") if not made_progress and rc.completed < rc.total: yield None @@ -185,7 +195,7 @@ def nested_crossproduct_scatter( if len(scatter_keys) == 1: if runtimeContext.postScatterEval is not None: sjob = runtimeContext.postScatterEval(sjob) - curriedcallback = functools.partial(rc.receive_scatter_output, index) + curriedcallback = functools.partial(rc.receive_scatter_output, index, runtimeContext) if sjob is not None: steps.append(process.job(sjob, curriedcallback, runtimeContext)) else: @@ -197,7 +207,7 @@ def nested_crossproduct_scatter( process, sjob, scatter_keys[1:], - functools.partial(rc.receive_scatter_output, index), + functools.partial(rc.receive_scatter_output, index, runtimeContext), runtimeContext, ) ) @@ -257,7 +267,9 @@ def _flat_crossproduct_scatter( if len(scatter_keys) == 1: if runtimeContext.postScatterEval is not None: sjob = runtimeContext.postScatterEval(sjob) - curriedcallback = functools.partial(callback.receive_scatter_output, put) + curriedcallback = functools.partial( + callback.receive_scatter_output, put, runtimeContext + ) if sjob is not None: steps.append(process.job(sjob, curriedcallback, runtimeContext)) else: @@ -307,7 +319,7 @@ def dotproduct_scatter( if runtimeContext.postScatterEval is not None: sjobo = runtimeContext.postScatterEval(sjobo) - curriedcallback = functools.partial(rc.receive_scatter_output, index) + curriedcallback = functools.partial(rc.receive_scatter_output, index, runtimeContext) if sjobo is not None: steps.append(process.job(sjobo, curriedcallback, runtimeContext)) else: @@ -548,16 +560,17 @@ def receive_output( jobout: CWLObjectType, processStatus: str, ) -> None: - for i in outputparms: - if "id" in i: - iid = cast(str, i["id"]) - if iid in jobout: - self.state[iid] = WorkflowStateItem(i, jobout[iid], processStatus) - else: - _logger.error("[%s] Output is missing expected field %s", step.name, iid) - processStatus = "permanentFail" if _logger.isEnabledFor(logging.DEBUG): _logger.debug("[%s] produced output %s", step.name, json_dumps(jobout, indent=4)) + if processStatus != "killed": + for i in outputparms: + if "id" in i: + iid = cast(str, i["id"]) + if iid in jobout: + self.state[iid] = WorkflowStateItem(i, jobout[iid], processStatus) + else: + _logger.error("[%s] Output is missing expected field %s", step.name, iid) + processStatus = "permanentFail" if processStatus not in ("success", "skipped"): if self.processStatus != "permanentFail": @@ -804,10 +817,7 @@ def job( self.made_progress = False for step in self.steps: - if ( - getdefault(runtimeContext.on_error, "stop") == "stop" - and self.processStatus != "success" - ): + if 
runtimeContext.on_error != "continue" and self.processStatus != "success": break if not step.submitted: @@ -822,7 +832,7 @@ def job( try: for newjob in step.iterable: if ( - getdefault(runtimeContext.on_error, "stop") == "stop" + runtimeContext.on_error != "continue" and self.processStatus != "success" ): break @@ -850,6 +860,10 @@ def job( # depends which one comes first. All steps are completed # or all outputs have been produced. + def __repr__(self) -> str: + """Return a non-expression string representation of the object instance.""" + return f"<{self.__class__.__name__} [{self.name}] at {hex(id(self))}>" + class WorkflowJobLoopStep: """Generated for each step in Workflow.steps() containing a `loop` directive.""" diff --git a/tests/process_roulette.cwl b/tests/process_roulette.cwl new file mode 100644 index 000000000..e1b02ff31 --- /dev/null +++ b/tests/process_roulette.cwl @@ -0,0 +1,35 @@ +#!/usr/bin/env cwl-runner + +cwlVersion: v1.2 +class: CommandLineTool + + +doc: | + This tool selects a random process whose associated command matches + search_str, terminates it, and reports the PID of the terminated process. + The search_str supports regex. Example search_strs: + - "sleep" + - "sleep 33" + - "sleep [0-9]+" + + +baseCommand: [ 'bash', '-c' ] +arguments: + - | + sleep $(inputs.delay) + pid=\$(ps -ef | grep '$(inputs.search_str)' | grep -v grep | awk '{print $2}' | shuf | head -n 1) + echo "$pid" | tee >(xargs kill -SIGTERM) +inputs: + search_str: + type: string + delay: + type: int? + default: 3 +stdout: "pid.txt" +outputs: + pid: + type: string + outputBinding: + glob: pid.txt + loadContents: true + outputEval: $(self[0].contents) \ No newline at end of file diff --git a/tests/test_parallel.py b/tests/test_parallel.py index 8c86d41fc..93f735adf 100644 --- a/tests/test_parallel.py +++ b/tests/test_parallel.py @@ -1,9 +1,12 @@ import json +import math +import time from pathlib import Path +from typing import Union, cast from cwltool.context import RuntimeContext from cwltool.executors import MultithreadedJobExecutor -from cwltool.factory import Factory +from cwltool.factory import Factory, WorkflowStatus from .util import get_data, needs_docker @@ -29,3 +32,42 @@ def test_scattered_workflow() -> None: echo = factory.make(get_data(test_file)) with open(get_data(job_file)) as job: assert echo(**json.load(job)) == {"out": ["foo one three", "foo two four"]} + + +def test_on_error_kill() -> None: + test_file = "tests/wf/on-error_kill.cwl" + + def selectResources( + request: dict[str, Union[int, float]], _: RuntimeContext + ) -> dict[str, Union[int, float]]: + # Remove the "one job per core" resource constraint so that + # parallel jobs aren't withheld on machines with few cores + return { + "cores": 0, + "ram": math.ceil(request["ramMin"]), # default + "tmpdirSize": math.ceil(request["tmpdirMin"]), # default + "outdirSize": math.ceil(request["outdirMin"]), # default + } + + runtime_context = RuntimeContext() + runtime_context.on_error = "kill" + runtime_context.select_resources = selectResources + factory = Factory(MultithreadedJobExecutor(), None, runtime_context) + ks_test = factory.make(get_data(test_file)) + + # arbitrary test values + sleep_time = 3333 # a "sufficiently large" timeout + n_sleepers = 4 + start_time = 0.0 + + try: + start_time = time.time() + ks_test( + sleep_time=sleep_time, + n_sleepers=n_sleepers, + ) + except WorkflowStatus as e: + end_time = time.time() + output = cast(dict[str, list[bool]], e.out)["roulette_mask"] + assert len(output) == n_sleepers and 
sum(output) == 1 + assert end_time - start_time < sleep_time diff --git a/tests/wf/on-error_kill.cwl b/tests/wf/on-error_kill.cwl new file mode 100644 index 000000000..82846a9de --- /dev/null +++ b/tests/wf/on-error_kill.cwl @@ -0,0 +1,86 @@ +#!/usr/bin/env cwl-runner + +cwlVersion: v1.2 +class: Workflow +requirements: + ScatterFeatureRequirement: {} + InlineJavascriptRequirement: {} + StepInputExpressionRequirement: {} + + +doc: | + This workflow tests the optional argument --on-error kill. + MultithreadedJobExecutor() or --parallel should be used. + A successful run should: + 1) Finish in (much) less than sleep_time seconds. + 2) Return outputs produced by successful steps. + + +inputs: + sleep_time: {type: int, default: 3333} + n_sleepers: {type: int, default: 4} + + +steps: + roulette: + doc: | + This step produces a boolean array with exactly one true value + whose index is assigned at random. + in: {n_sleepers: n_sleepers} + out: [mask] + run: + class: ExpressionTool + inputs: {n_sleepers: {type: int}} + outputs: {mask: {type: "boolean[]"}} + expression: | + ${ + var mask = Array(inputs.n_sleepers).fill(false); + var spin = Math.floor(Math.random() * inputs.n_sleepers); + mask[spin] = true; + return {"mask": mask} + } + + scatter_step: + doc: | + This step starts several parallel jobs that each sleep for + sleep_time seconds. The job whose k_mask value is true will + self-terminate early, thereby activating the kill switch. + in: + time: sleep_time + k_mask: roulette/mask + scatter: k_mask + out: [placeholder] + run: + class: CommandLineTool + requirements: + ToolTimeLimit: + timelimit: '${return inputs.k_mask ? 5 : inputs.time + 5}' # 5 is an arbitrary value + baseCommand: sleep + inputs: + time: {type: int, inputBinding: {position: 1}} + k_mask: {type: boolean} + outputs: + placeholder: {type: string, outputBinding: {outputEval: $("foo")}} + + dangling_step: + doc: | + This step should never run. It confirms that additional jobs aren't + submitted and allowed to run to completion after the kill switch has + been set. The input force_downstream_order ensures that this step + doesn't run before scatter_step completes. + in: + force_downstream_order: scatter_step/placeholder + time: sleep_time + out: [] + run: + class: CommandLineTool + baseCommand: sleep + inputs: + time: {type: int, inputBinding: {position: 1}} + outputs: {} + + +outputs: + roulette_mask: + type: boolean[] + outputSource: roulette/mask
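
Reviewer note: the changes above coordinate through a single threading.Event stored on RuntimeContext.kill_switch. A failing job sets the event (and raises WorkflowKillSwitch), while the Timer-based monitors added to process_monitor(), docker_monitor(), and TaskQueue poll it and terminate whatever is still running. Below is a minimal, self-contained sketch of that pattern, not cwltool's actual API; the helper name watch_kill_switch and the sleep/false commands are illustrative assumptions only.

import subprocess
import threading
from threading import Timer

def watch_kill_switch(proc: "subprocess.Popen[bytes]", kill_switch: threading.Event) -> None:
    """Re-arm a 1-second Timer until the kill switch fires or the process exits."""

    def check() -> None:
        nonlocal tm
        if kill_switch.is_set():
            proc.terminate()  # stop the still-running sibling job
        elif proc.poll() is None:
            tm = Timer(interval=1, function=check)  # re-arm, like ks_tm in the diff
            tm.daemon = True
            tm.start()

    tm = Timer(interval=1, function=check)
    tm.daemon = True
    tm.start()

kill_switch = threading.Event()
long_job = subprocess.Popen(["sleep", "3333"])  # stands in for a long-running step
watch_kill_switch(long_job, kill_switch)

failing_job = subprocess.Popen(["false"])  # stands in for the step that fails
if failing_job.wait() != 0:
    kill_switch.set()  # what on-error=kill does when a job exits non-zero

long_job.wait()  # returns within a second or two instead of after 3333 s
print("long job terminated with return code", long_job.returncode)

Polling with short, re-armed daemon Timers (rather than blocking on the event) mirrors the diff's approach and lets the same monitor also track memory usage between checks. With this PR applied, the equivalent command-line invocation would be along the lines of "cwltool --parallel --on-error kill wf.cwl job.yml" (filenames illustrative), which is the behavior tests/wf/on-error_kill.cwl exercises through MultithreadedJobExecutor.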