diff --git a/src/datachain/catalog/catalog.py b/src/datachain/catalog/catalog.py
index e893d860e..9a8ee8c40 100644
--- a/src/datachain/catalog/catalog.py
+++ b/src/datachain/catalog/catalog.py
@@ -4,6 +4,7 @@
 import os
 import os.path
 import posixpath
+import signal
 import subprocess
 import sys
 import time
@@ -97,6 +98,61 @@ def noop(_: str):
     pass
 
 
+class TerminationSignal(RuntimeError):  # noqa: N818
+    """Exception carrying the number of a received termination signal."""
+
+    def __init__(self, signal):
+        self.signal = signal
+        super().__init__("Received termination signal", signal)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.signal})"
+
+
+# Interrupt sent by graceful_shutdown(): CTRL_C_EVENT on Windows, else SIGINT.
+if sys.platform == "win32":
+    SIGINT = signal.CTRL_C_EVENT
+else:
+    SIGINT = signal.SIGINT
+
+
+def graceful_shutdown(
+    proc: subprocess.Popen,
+    interrupt: bool = True,
+    interrupt_timeout: Optional[int] = None,
+    terminate_timeout: Optional[int] = None,
+) -> None:
+    """Stop ``proc``, escalating from interrupt to terminate to kill.
+
+    Args:
+        proc: The child process to stop.
+        interrupt: Whether to send the interrupt signal before waiting.
+        interrupt_timeout: Seconds to wait for exit before terminating;
+            ``None`` waits indefinitely.
+        terminate_timeout: Seconds to wait after terminating before killing;
+            ``None`` waits indefinitely.
+    """
+    if interrupt:
+        logger.info("sending interrupt signal to the process %s", proc.pid)
+        proc.send_signal(SIGINT)
+
+    logger.info("waiting for the process %s to finish", proc.pid)
+    try:
+        proc.wait(interrupt_timeout)
+    except subprocess.TimeoutExpired:
+        logger.info(
+            "timed out waiting, sending terminate signal to the process %s", proc.pid
+        )
+        proc.terminate()
+        try:
+            proc.wait(terminate_timeout)
+        except subprocess.TimeoutExpired:
+            logger.info("timed out waiting, killing the process %s", proc.pid)
+            proc.kill()
+            # Reap the killed child so proc.returncode is set for the caller.
+            proc.wait()
+
+
 def _process_stream(stream: "IO[bytes]", callback: Callable[[str], None]) -> None:
     buffer = b""
     while byt := stream.read(1):  # Read one byte at a time
@@ -1493,6 +1549,9 @@ def query(
         output_hook: Callable[[str], None] = noop,
         params: Optional[dict[str, str]] = None,
         job_id: Optional[str] = None,
+        interrupt_timeout: Optional[int] = None,
+        terminate_timeout: Optional[int] = None,
+        send_interrupt: bool = False,
     ) -> None:
         cmd = [python_executable, "-c", query_script]
         env = dict(env or os.environ)
@@ -1506,13 +1565,48 @@ def query(
         if capture_output:
             popen_kwargs = {"stdout": subprocess.PIPE, "stderr": subprocess.STDOUT}
 
+        def signal_handler(sig: int, frame: Any) -> NoReturn:
+            raise TerminationSignal(sig)
+
+        orig_handler = signal.getsignal(signal.SIGTERM)
+        signal.signal(signal.SIGTERM, signal_handler)
+
+        thread = None
         with subprocess.Popen(cmd, env=env, **popen_kwargs) as proc:  # noqa: S603
-            if capture_output:
-                args = (proc.stdout, output_hook)
-                thread = Thread(target=_process_stream, args=args, daemon=True)
-                thread.start()
-                thread.join()  # wait for the reader thread
+            logger.info("Running script with PID %s", proc.pid)
+            try:
+                if capture_output:
+                    args = (proc.stdout, output_hook)
+                    thread = Thread(target=_process_stream, args=args, daemon=True)
+                    thread.start()
+
+                proc.wait()
+            except (KeyboardInterrupt, TerminationSignal) as exc:
+                if orig_handler is not None:
+                    signal.signal(signal.SIGTERM, orig_handler)
+
+                logger.info("Terminating process %s, received %r", proc.pid, exc)
+
+                # If process is running in the foreground, TTY sends signals to all
+                # processes in the process group. So, by default, we don't send
+                # interrupt signal to the child process again.
+                # If interrupt needs to be sent, this can be enabled by setting
+                # send_interrupt to True.
+                graceful_shutdown(
+                    proc,
+                    send_interrupt,
+                    interrupt_timeout=interrupt_timeout,
+                    terminate_timeout=terminate_timeout,
+                )
+                if proc.returncode:
+                    raise
+            finally:
+                if orig_handler is not None:
+                    signal.signal(signal.SIGTERM, orig_handler)
+                if thread:
+                    thread.join()
 
+        logger.info("Process %s exited with return code %s", proc.pid, proc.returncode)
         if proc.returncode == QUERY_SCRIPT_CANCELED_EXIT_CODE:
             raise QueryScriptCancelError(
                 "Query script was canceled by user",