Skip to content

Commit d5c10f6

Browse files
committed
query: interrupt query script gracefully on SIGTERM
This now supports gracefully terminating the script on SIGTERM. Also adds logging.
1 parent ef23a20 commit d5c10f6

File tree

3 files changed

+98
-7
lines changed

3 files changed

+98
-7
lines changed

src/datachain/catalog/catalog.py

Lines changed: 78 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import os.path
66
import posixpath
7+
import signal
78
import subprocess
89
import sys
910
import time
@@ -97,6 +98,46 @@ def noop(_: str):
9798
pass
9899

99100

101+
class TerminationSignal(RuntimeError):  # noqa: N818
    """Exception carrying the number of a received termination signal.

    The signal number is stored on ``signal`` and is also forwarded to
    ``RuntimeError`` as the second positional argument, after a fixed message.
    """

    def __init__(self, signal):
        super().__init__("Received termination signal", signal)
        self.signal = signal

    def __repr__(self):
        # Use the runtime class name so subclasses are rendered correctly.
        name = type(self).__name__
        return f"{name}({self.signal})"
108+
109+
110+
# Signal used to interrupt a child process: SIGINT everywhere except
# Windows, where CTRL_C_EVENT is used instead.
if sys.platform != "win32":
    SIGINT = signal.SIGINT
else:
    SIGINT = signal.CTRL_C_EVENT
114+
115+
116+
def shutdown_process(
    proc: subprocess.Popen,
    interrupt_timeout: Optional[float] = None,
    terminate_timeout: Optional[float] = None,
) -> None:
    """Shut down the process gracefully with SIGINT -> SIGTERM -> SIGKILL.

    Each step is only taken if the previous one did not make the process
    exit within its timeout. When the function returns, the process has been
    waited on, so ``proc.returncode`` is populated.

    Args:
        proc: The child process to stop.
        interrupt_timeout: Seconds to wait after the interrupt signal before
            escalating to ``proc.terminate()``. ``None`` waits indefinitely.
        terminate_timeout: Seconds to wait after ``proc.terminate()`` before
            escalating to ``proc.kill()``. ``None`` waits indefinitely.
    """
    logger.info("sending interrupt signal to the process %s", proc.pid)
    proc.send_signal(SIGINT)

    logger.info("waiting for the process %s to finish", proc.pid)
    try:
        proc.wait(interrupt_timeout)
    except subprocess.TimeoutExpired:
        logger.info(
            "timed out waiting, sending terminate signal to the process %s", proc.pid
        )
        proc.terminate()
        try:
            proc.wait(terminate_timeout)
        except subprocess.TimeoutExpired:
            logger.info("timed out waiting, killing the process %s", proc.pid)
            proc.kill()
            # Reap the killed child so `proc.returncode` is set for callers
            # that inspect it right after this function returns.
            proc.wait()
139+
140+
100141
def _process_stream(stream: "IO[bytes]", callback: Callable[[str], None]) -> None:
101142
buffer = b""
102143
while byt := stream.read(1): # Read one byte at a time
@@ -1493,6 +1534,8 @@ def query(
14931534
output_hook: Callable[[str], None] = noop,
14941535
params: Optional[dict[str, str]] = None,
14951536
job_id: Optional[str] = None,
1537+
interrupt_timeout: Optional[int] = None,
1538+
terminate_timeout: Optional[int] = None,
14961539
) -> None:
14971540
cmd = [python_executable, "-c", query_script]
14981541
env = dict(env or os.environ)
@@ -1506,13 +1549,42 @@ def query(
15061549
if capture_output:
15071550
popen_kwargs = {"stdout": subprocess.PIPE, "stderr": subprocess.STDOUT}
15081551

1509-
with subprocess.Popen(cmd, env=env, **popen_kwargs) as proc: # noqa: S603
1510-
if capture_output:
1511-
args = (proc.stdout, output_hook)
1512-
thread = Thread(target=_process_stream, args=args, daemon=True)
1513-
thread.start()
1514-
thread.join() # wait for the reader thread
1552+
def signal_handler(sig: int, _: Any) -> NoReturn:
1553+
raise TerminationSignal(sig)
1554+
1555+
original_sigterm_handler = signal.getsignal(signal.SIGTERM)
1556+
signal.signal(signal.SIGTERM, signal_handler)
15151557

1558+
with subprocess.Popen(cmd, env=env, **popen_kwargs) as proc: # noqa: S603
1559+
logger.info("Starting process %s", proc.pid)
1560+
thread: Optional[Thread] = None
1561+
try:
1562+
if capture_output:
1563+
args = (proc.stdout, output_hook)
1564+
thread = Thread(target=_process_stream, args=args, daemon=True)
1565+
thread.start()
1566+
1567+
proc.wait()
1568+
except (KeyboardInterrupt, TerminationSignal) as exc:
1569+
# NOTE: we ignore `Ctrl-C` signals on CLI so `KeyboardInterrupt`
1570+
# is not expected to be raised here, but we handle it just in case.
1571+
# If we don't ignore it in CLI, the child will receive the signal twice.
1572+
logging.info("Shutting down process %s, received %r", proc.pid, exc)
1573+
# Rather than forwarding the signal to the child, we try to shut it down
1574+
# gracefully. This is because we consider the script to be interactive
1575+
# and special, so we give it time to cleanup before exiting.
1576+
shutdown_process(proc, interrupt_timeout, terminate_timeout)
1577+
if proc.returncode:
1578+
raise QueryScriptCancelError(
1579+
"Query script was canceled by user", return_code=proc.returncode
1580+
) from exc
1581+
finally:
1582+
if original_sigterm_handler:
1583+
signal.signal(signal.SIGTERM, original_sigterm_handler)
1584+
if thread:
1585+
thread.join()
1586+
1587+
logging.info("Process %s exited with return code %s", proc.pid, proc.returncode)
15161588
if proc.returncode == QUERY_SCRIPT_CANCELED_EXIT_CODE:
15171589
raise QueryScriptCancelError(
15181590
"Query script was canceled by user",

src/datachain/cli/commands/query.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import signal
23
import sys
34
import traceback
45
from typing import TYPE_CHECKING, Optional
@@ -33,6 +34,8 @@ def query(
3334
params=params,
3435
)
3536

37+
# ignore SIGINT in the main process
38+
signal.signal(signal.SIGINT, signal.SIG_IGN)
3639
try:
3740
catalog.query(
3841
script_content,

tests/unit/test_query.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66

77
import pytest
88

9-
from datachain.catalog.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE
9+
from datachain.catalog.catalog import (
10+
QUERY_SCRIPT_CANCELED_EXIT_CODE,
11+
TerminationSignal,
12+
)
1013
from datachain.error import QueryScriptCancelError, QueryScriptRunError
1114

1215

@@ -63,3 +66,16 @@ def test_non_zero_exitcode(catalog, mock_popen):
6366
catalog.query("pass")
6467
assert e.value.return_code == 1
6568
assert "Query script exited with error code 1" in str(e.value)
69+
70+
71+
@pytest.mark.parametrize("side_effect", [KeyboardInterrupt(), TerminationSignal(15)])
def test_shutdown_process_called(mocker, catalog, mock_popen, side_effect):
    """An interruption raised while waiting on the child triggers `shutdown_process`."""
    # Make `proc.wait()` raise the interruption and report a non-zero exit code.
    mock_popen.wait.side_effect = [side_effect]
    mock_popen.returncode = -2
    shutdown_mock = mocker.patch(
        "datachain.catalog.catalog.shutdown_process", return_value=-2
    )

    with pytest.raises(QueryScriptCancelError) as exc_info:
        catalog.query("pass", interrupt_timeout=0.1, terminate_timeout=0.2)

    error = exc_info.value
    assert error.return_code == -2
    assert "Query script was canceled by user" in str(error)
    # The timeouts given to `query` must be forwarded unchanged.
    shutdown_mock.assert_called_once_with(mock_popen, 0.1, 0.2)

0 commit comments

Comments
 (0)