Skip to content

Commit c0f23b4

Browse files
authored
query: interrupt query script on SIGTERM (#858)
1 parent 2520cae commit c0f23b4

File tree

3 files changed

+179
-6
lines changed

3 files changed

+179
-6
lines changed

src/datachain/catalog/catalog.py

Lines changed: 84 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import os.path
66
import posixpath
7+
import signal
78
import subprocess
89
import sys
910
import time
@@ -97,6 +98,47 @@ def noop(_: str):
9798
pass
9899

99100

101+
class TerminationSignal(RuntimeError): # noqa: N818
102+
def __init__(self, signal):
103+
self.signal = signal
104+
super().__init__("Received termination signal", signal)
105+
106+
def __repr__(self):
107+
return f"{self.__class__.__name__}({self.signal})"
108+
109+
110+
if sys.platform == "win32":
111+
SIGINT = signal.CTRL_C_EVENT
112+
else:
113+
SIGINT = signal.SIGINT
114+
115+
116+
def shutdown_process(
117+
proc: subprocess.Popen,
118+
interrupt_timeout: Optional[int] = None,
119+
terminate_timeout: Optional[int] = None,
120+
) -> int:
121+
"""Shut down the process gracefully with SIGINT -> SIGTERM -> SIGKILL."""
122+
123+
logger.info("sending interrupt signal to the process %s", proc.pid)
124+
proc.send_signal(SIGINT)
125+
126+
logger.info("waiting for the process %s to finish", proc.pid)
127+
try:
128+
return proc.wait(interrupt_timeout)
129+
except subprocess.TimeoutExpired:
130+
logger.info(
131+
"timed out waiting, sending terminate signal to the process %s", proc.pid
132+
)
133+
proc.terminate()
134+
try:
135+
return proc.wait(terminate_timeout)
136+
except subprocess.TimeoutExpired:
137+
logger.info("timed out waiting, killing the process %s", proc.pid)
138+
proc.kill()
139+
return proc.wait()
140+
141+
100142
def _process_stream(stream: "IO[bytes]", callback: Callable[[str], None]) -> None:
101143
buffer = b""
102144
while byt := stream.read(1): # Read one byte at a time
@@ -1493,6 +1535,8 @@ def query(
14931535
output_hook: Callable[[str], None] = noop,
14941536
params: Optional[dict[str, str]] = None,
14951537
job_id: Optional[str] = None,
1538+
interrupt_timeout: Optional[int] = None,
1539+
terminate_timeout: Optional[int] = None,
14961540
) -> None:
14971541
cmd = [python_executable, "-c", query_script]
14981542
env = dict(env or os.environ)
@@ -1506,13 +1550,48 @@ def query(
15061550
if capture_output:
15071551
popen_kwargs = {"stdout": subprocess.PIPE, "stderr": subprocess.STDOUT}
15081552

1553+
def raise_termination_signal(sig: int, _: Any) -> NoReturn:
1554+
raise TerminationSignal(sig)
1555+
1556+
thread: Optional[Thread] = None
15091557
with subprocess.Popen(cmd, env=env, **popen_kwargs) as proc: # noqa: S603
1510-
if capture_output:
1511-
args = (proc.stdout, output_hook)
1512-
thread = Thread(target=_process_stream, args=args, daemon=True)
1513-
thread.start()
1514-
thread.join() # wait for the reader thread
1558+
logger.info("Starting process %s", proc.pid)
1559+
1560+
orig_sigint_handler = signal.getsignal(signal.SIGINT)
1561+
# ignore SIGINT in the main process.
1562+
# In the terminal, SIGINTs are received by all the processes in
1563+
# the foreground process group, so the script will receive the signal too.
1564+
# (If we forward the signal to the child, it will receive it twice.)
1565+
signal.signal(signal.SIGINT, signal.SIG_IGN)
15151566

1567+
orig_sigterm_handler = signal.getsignal(signal.SIGTERM)
1568+
signal.signal(signal.SIGTERM, raise_termination_signal)
1569+
try:
1570+
if capture_output:
1571+
args = (proc.stdout, output_hook)
1572+
thread = Thread(target=_process_stream, args=args, daemon=True)
1573+
thread.start()
1574+
1575+
proc.wait()
1576+
except TerminationSignal as exc:
1577+
signal.signal(signal.SIGTERM, orig_sigterm_handler)
1578+
signal.signal(signal.SIGINT, orig_sigint_handler)
1579+
logging.info("Shutting down process %s, received %r", proc.pid, exc)
1580+
# Rather than forwarding the signal to the child, we try to shut it down
1581+
# gracefully. This is because we consider the script to be interactive
1582+
# and special, so we give it time to cleanup before exiting.
1583+
shutdown_process(proc, interrupt_timeout, terminate_timeout)
1584+
if proc.returncode:
1585+
raise QueryScriptCancelError(
1586+
"Query script was canceled by user", return_code=proc.returncode
1587+
) from exc
1588+
finally:
1589+
signal.signal(signal.SIGTERM, orig_sigterm_handler)
1590+
signal.signal(signal.SIGINT, orig_sigint_handler)
1591+
if thread:
1592+
thread.join() # wait for the reader thread
1593+
1594+
logging.info("Process %s exited with return code %s", proc.pid, proc.returncode)
15161595
if proc.returncode == QUERY_SCRIPT_CANCELED_EXIT_CODE:
15171596
raise QueryScriptCancelError(
15181597
"Query script was canceled by user",

tests/func/test_query.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
1+
import os.path
2+
import signal
13
import sys
4+
from multiprocessing.pool import ExceptionWithTraceback # type: ignore[attr-defined]
25
from textwrap import dedent
36

47
import cloudpickle
8+
import multiprocess
59
import pytest
610

11+
from datachain.catalog import Catalog
712
from datachain.cli import query
813
from datachain.data_storage import AbstractDBMetastore, JobQueryType, JobStatus
14+
from datachain.error import QueryScriptCancelError
915
from datachain.job import Job
16+
from tests.utils import wait_for_condition
1017

1118

1219
@pytest.fixture
@@ -102,3 +109,75 @@ def test_query_cli(cloud_test_catalog_tmpfile, tmp_path, catalog_info_filepath,
102109
assert job.params == {"url": src_uri}
103110
assert job.metrics == {"count": 7}
104111
assert job.python_version == f"{sys.version_info.major}.{sys.version_info.minor}"
112+
113+
114+
if sys.platform == "win32":
115+
SIGKILL = signal.SIGTERM
116+
else:
117+
SIGKILL = signal.SIGKILL
118+
119+
120+
@pytest.mark.skipif(sys.platform == "win32", reason="Windows does not have SIGTERM")
121+
@pytest.mark.parametrize(
122+
"setup,expected_return_code",
123+
[
124+
("", -signal.SIGINT),
125+
("signal.signal(signal.SIGINT, signal.SIG_IGN)", -signal.SIGTERM),
126+
(
127+
"""\
128+
signal.signal(signal.SIGINT, signal.SIG_IGN)
129+
signal.signal(signal.SIGTERM, signal.SIG_IGN)
130+
""",
131+
-SIGKILL,
132+
),
133+
],
134+
)
135+
def test_shutdown_on_sigterm(tmp_dir, request, catalog, setup, expected_return_code):
136+
query = f"""\
137+
import os, pathlib, signal, sys, time
138+
139+
pathlib.Path("ready").touch(exist_ok=False)
140+
{setup}
141+
time.sleep(10)
142+
"""
143+
144+
def apply(f, args, kwargs):
145+
return f(*args, **kwargs)
146+
147+
def func(ms_params, wh_params, init_params, q):
148+
catalog = Catalog(apply(*ms_params), apply(*wh_params), **init_params)
149+
try:
150+
catalog.query(query, interrupt_timeout=0.5, terminate_timeout=0.5)
151+
except Exception as e: # noqa: BLE001
152+
q.put(ExceptionWithTraceback(e, e.__traceback__))
153+
else:
154+
q.put(None)
155+
156+
mp_ctx = multiprocess.get_context("spawn")
157+
q = mp_ctx.Queue()
158+
p = mp_ctx.Process(
159+
target=func,
160+
args=(
161+
catalog.metastore.clone_params(),
162+
catalog.warehouse.clone_params(),
163+
catalog.get_init_params(),
164+
q,
165+
),
166+
)
167+
p.start()
168+
request.addfinalizer(p.kill)
169+
170+
def is_ready():
171+
assert p.is_alive(), "Process is dead"
172+
return os.path.exists("ready")
173+
174+
# make sure the process is running before we send the signal
175+
wait_for_condition(is_ready, "script to start", timeout=5)
176+
177+
os.kill(p.pid, signal.SIGTERM)
178+
p.join(timeout=3) # might take as long as 1 second to complete shutdown_process
179+
assert not p.exitcode
180+
181+
e = q.get_nowait()
182+
assert isinstance(e, QueryScriptCancelError)
183+
assert e.return_code == expected_return_code

tests/unit/test_query.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66

77
import pytest
88

9-
from datachain.catalog.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE
9+
from datachain.catalog.catalog import (
10+
QUERY_SCRIPT_CANCELED_EXIT_CODE,
11+
TerminationSignal,
12+
)
1013
from datachain.error import QueryScriptCancelError, QueryScriptRunError
1114

1215

@@ -63,3 +66,15 @@ def test_non_zero_exitcode(catalog, mock_popen):
6366
catalog.query("pass")
6467
assert e.value.return_code == 1
6568
assert "Query script exited with error code 1" in str(e.value)
69+
70+
71+
def test_shutdown_process_on_sigterm(mocker, catalog, mock_popen):
72+
mock_popen.returncode = -2
73+
mock_popen.wait.side_effect = [TerminationSignal(15)]
74+
m = mocker.patch("datachain.catalog.catalog.shutdown_process", return_value=-2)
75+
76+
with pytest.raises(QueryScriptCancelError) as e:
77+
catalog.query("pass", interrupt_timeout=0.1, terminate_timeout=0.2)
78+
assert e.value.return_code == -2
79+
assert "Query script was canceled by user" in str(e.value)
80+
m.assert_called_once_with(mock_popen, 0.1, 0.2)

0 commit comments

Comments
 (0)