diff --git a/src/datachain/catalog/catalog.py b/src/datachain/catalog/catalog.py index e893d860e..9b33090f1 100644 --- a/src/datachain/catalog/catalog.py +++ b/src/datachain/catalog/catalog.py @@ -4,6 +4,7 @@ import os import os.path import posixpath +import signal import subprocess import sys import time @@ -97,6 +98,47 @@ def noop(_: str): pass +class TerminationSignal(RuntimeError): # noqa: N818 + def __init__(self, signal): + self.signal = signal + super().__init__("Received termination signal", signal) + + def __repr__(self): + return f"{self.__class__.__name__}({self.signal})" + + +if sys.platform == "win32": + SIGINT = signal.CTRL_C_EVENT +else: + SIGINT = signal.SIGINT + + +def shutdown_process( + proc: subprocess.Popen, + interrupt_timeout: Optional[int] = None, + terminate_timeout: Optional[int] = None, +) -> int: + """Shut down the process gracefully with SIGINT -> SIGTERM -> SIGKILL.""" + + logger.info("sending interrupt signal to the process %s", proc.pid) + proc.send_signal(SIGINT) + + logger.info("waiting for the process %s to finish", proc.pid) + try: + return proc.wait(interrupt_timeout) + except subprocess.TimeoutExpired: + logger.info( + "timed out waiting, sending terminate signal to the process %s", proc.pid + ) + proc.terminate() + try: + return proc.wait(terminate_timeout) + except subprocess.TimeoutExpired: + logger.info("timed out waiting, killing the process %s", proc.pid) + proc.kill() + return proc.wait() + + def _process_stream(stream: "IO[bytes]", callback: Callable[[str], None]) -> None: buffer = b"" while byt := stream.read(1): # Read one byte at a time @@ -1493,6 +1535,8 @@ def query( output_hook: Callable[[str], None] = noop, params: Optional[dict[str, str]] = None, job_id: Optional[str] = None, + interrupt_timeout: Optional[int] = None, + terminate_timeout: Optional[int] = None, ) -> None: cmd = [python_executable, "-c", query_script] env = dict(env or os.environ) @@ -1506,13 +1550,48 @@ def query( if capture_output: popen_kwargs = {"stdout": subprocess.PIPE, "stderr": subprocess.STDOUT} + def raise_termination_signal(sig: int, _: Any) -> NoReturn: + raise TerminationSignal(sig) + + thread: Optional[Thread] = None with subprocess.Popen(cmd, env=env, **popen_kwargs) as proc: # noqa: S603 - if capture_output: - args = (proc.stdout, output_hook) - thread = Thread(target=_process_stream, args=args, daemon=True) - thread.start() - thread.join() # wait for the reader thread + logger.info("Starting process %s", proc.pid) + + orig_sigint_handler = signal.getsignal(signal.SIGINT) + # ignore SIGINT in the main process. + # In the terminal, SIGINTs are received by all the processes in + # the foreground process group, so the script will receive the signal too. + # (If we forward the signal to the child, it will receive it twice.) + signal.signal(signal.SIGINT, signal.SIG_IGN) + orig_sigterm_handler = signal.getsignal(signal.SIGTERM) + signal.signal(signal.SIGTERM, raise_termination_signal) + try: + if capture_output: + args = (proc.stdout, output_hook) + thread = Thread(target=_process_stream, args=args, daemon=True) + thread.start() + + proc.wait() + except TerminationSignal as exc: + signal.signal(signal.SIGTERM, orig_sigterm_handler) + signal.signal(signal.SIGINT, orig_sigint_handler) + logging.info("Shutting down process %s, received %r", proc.pid, exc) + # Rather than forwarding the signal to the child, we try to shut it down + # gracefully. This is because we consider the script to be interactive + # and special, so we give it time to cleanup before exiting. + shutdown_process(proc, interrupt_timeout, terminate_timeout) + if proc.returncode: + raise QueryScriptCancelError( + "Query script was canceled by user", return_code=proc.returncode + ) from exc + finally: + signal.signal(signal.SIGTERM, orig_sigterm_handler) + signal.signal(signal.SIGINT, orig_sigint_handler) + if thread: + thread.join() # wait for the reader thread + + logging.info("Process %s exited with return code %s", proc.pid, proc.returncode) if proc.returncode == QUERY_SCRIPT_CANCELED_EXIT_CODE: raise QueryScriptCancelError( "Query script was canceled by user", diff --git a/tests/func/test_query.py b/tests/func/test_query.py index f16a9db71..c25183ae6 100644 --- a/tests/func/test_query.py +++ b/tests/func/test_query.py @@ -1,12 +1,19 @@ +import os.path +import signal import sys +from multiprocessing.pool import ExceptionWithTraceback # type: ignore[attr-defined] from textwrap import dedent import cloudpickle +import multiprocess import pytest +from datachain.catalog import Catalog from datachain.cli import query from datachain.data_storage import AbstractDBMetastore, JobQueryType, JobStatus +from datachain.error import QueryScriptCancelError from datachain.job import Job +from tests.utils import wait_for_condition @pytest.fixture @@ -102,3 +109,75 @@ def test_query_cli(cloud_test_catalog_tmpfile, tmp_path, catalog_info_filepath, assert job.params == {"url": src_uri} assert job.metrics == {"count": 7} assert job.python_version == f"{sys.version_info.major}.{sys.version_info.minor}" + + +if sys.platform == "win32": + SIGKILL = signal.SIGTERM +else: + SIGKILL = signal.SIGKILL + + +@pytest.mark.skipif(sys.platform == "win32", reason="Windows does not have SIGTERM") +@pytest.mark.parametrize( + "setup,expected_return_code", + [ + ("", -signal.SIGINT), + ("signal.signal(signal.SIGINT, signal.SIG_IGN)", -signal.SIGTERM), + ( + """\ +signal.signal(signal.SIGINT, signal.SIG_IGN) +signal.signal(signal.SIGTERM, signal.SIG_IGN) +""", + -SIGKILL, + ), + ], +) +def test_shutdown_on_sigterm(tmp_dir, request, catalog, setup, expected_return_code): + query = f"""\ +import os, pathlib, signal, sys, time + +pathlib.Path("ready").touch(exist_ok=False) +{setup} +time.sleep(10) +""" + + def apply(f, args, kwargs): + return f(*args, **kwargs) + + def func(ms_params, wh_params, init_params, q): + catalog = Catalog(apply(*ms_params), apply(*wh_params), **init_params) + try: + catalog.query(query, interrupt_timeout=0.5, terminate_timeout=0.5) + except Exception as e: # noqa: BLE001 + q.put(ExceptionWithTraceback(e, e.__traceback__)) + else: + q.put(None) + + mp_ctx = multiprocess.get_context("spawn") + q = mp_ctx.Queue() + p = mp_ctx.Process( + target=func, + args=( + catalog.metastore.clone_params(), + catalog.warehouse.clone_params(), + catalog.get_init_params(), + q, + ), + ) + p.start() + request.addfinalizer(p.kill) + + def is_ready(): + assert p.is_alive(), "Process is dead" + return os.path.exists("ready") + + # make sure the process is running before we send the signal + wait_for_condition(is_ready, "script to start", timeout=5) + + os.kill(p.pid, signal.SIGTERM) + p.join(timeout=3) # might take as long as 1 second to complete shutdown_process + assert not p.exitcode + + e = q.get_nowait() + assert isinstance(e, QueryScriptCancelError) + assert e.return_code == expected_return_code diff --git a/tests/unit/test_query.py b/tests/unit/test_query.py index 0cde686eb..7565bb27f 100644 --- a/tests/unit/test_query.py +++ b/tests/unit/test_query.py @@ -6,7 +6,10 @@ import pytest -from datachain.catalog.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE +from datachain.catalog.catalog import ( + QUERY_SCRIPT_CANCELED_EXIT_CODE, + TerminationSignal, +) from datachain.error import QueryScriptCancelError, QueryScriptRunError @@ -63,3 +66,15 @@ def test_non_zero_exitcode(catalog, mock_popen): catalog.query("pass") assert e.value.return_code == 1 assert "Query script exited with error code 1" in str(e.value) + + +def test_shutdown_process_on_sigterm(mocker, catalog, mock_popen): + mock_popen.returncode = -2 + mock_popen.wait.side_effect = [TerminationSignal(15)] + m = mocker.patch("datachain.catalog.catalog.shutdown_process", return_value=-2) + + with pytest.raises(QueryScriptCancelError) as e: + catalog.query("pass", interrupt_timeout=0.1, terminate_timeout=0.2) + assert e.value.return_code == -2 + assert "Query script was canceled by user" in str(e.value) + m.assert_called_once_with(mock_popen, 0.1, 0.2)