query: interrupt query script on SIGTERM #858

Merged
merged 2 commits into from
Jan 28, 2025
89 changes: 84 additions & 5 deletions src/datachain/catalog/catalog.py
@@ -4,6 +4,7 @@
import os
import os.path
import posixpath
import signal
import subprocess
import sys
import time
@@ -97,6 +98,47 @@
pass


class TerminationSignal(RuntimeError):  # noqa: N818
    def __init__(self, signal):
        self.signal = signal
        super().__init__("Received termination signal", signal)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.signal})"
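
A quick check of how this exception renders, as a minimal sketch. The class below only restates the definition from the diff so that the block runs on its own. Because two arguments are passed to `RuntimeError`, `str()` shows the argument tuple, while `repr()` uses the custom format.

```python
class TerminationSignal(RuntimeError):
    """Restated from the diff so the example is self-contained."""

    def __init__(self, signal):
        self.signal = signal
        super().__init__("Received termination signal", signal)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.signal})"


exc = TerminationSignal(15)  # 15 is SIGTERM on Linux
print(repr(exc))  # TerminationSignal(15)
print(str(exc))  # ('Received termination signal', 15)
```

The `%r` placeholder in the later log call therefore prints the compact `TerminationSignal(15)` form.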

Codecov (codecov/patch) warning: added line #L107 of src/datachain/catalog/catalog.py was not covered by tests.


if sys.platform == "win32":
    SIGINT = signal.CTRL_C_EVENT
else:
    SIGINT = signal.SIGINT


def shutdown_process(
Member

unit test this?

Contributor Author

There's no point in unit-testing this. We would only be checking mock_popen.terminate.assert_called_once(), which is not that useful.

Member

I mean functional-test this. It should not be about testing mocks, of course.

    proc: subprocess.Popen,
    interrupt_timeout: Optional[int] = None,
    terminate_timeout: Optional[int] = None,
) -> int:
    """Shut down the process gracefully with SIGINT -> SIGTERM -> SIGKILL."""

    logger.info("sending interrupt signal to the process %s", proc.pid)
    proc.send_signal(SIGINT)

    logger.info("waiting for the process %s to finish", proc.pid)
    try:
        return proc.wait(interrupt_timeout)
    except subprocess.TimeoutExpired:
        logger.info(
            "timed out waiting, sending terminate signal to the process %s", proc.pid
        )
        proc.terminate()
        try:
            return proc.wait(terminate_timeout)
        except subprocess.TimeoutExpired:
            logger.info("timed out waiting, killing the process %s", proc.pid)
            proc.kill()
    return proc.wait()
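
One way to exercise the escalation against a real child process, in the spirit of the "functional test" request above. This is a sketch under assumptions, not the project's test: `shutdown` restates the function from the diff without logging, `CHILD` and `run_demo` are hypothetical names, and a POSIX platform is assumed. The child ignores SIGINT, so the first wait times out and SIGTERM ends it.

```python
import signal
import subprocess
import sys

# Hypothetical stand-in child: ignores SIGINT, announces readiness, then sleeps.
CHILD = """
import signal, time
signal.signal(signal.SIGINT, signal.SIG_IGN)
print("ready", flush=True)
time.sleep(30)
"""


def shutdown(proc, interrupt_timeout=None, terminate_timeout=None):
    """SIGINT -> SIGTERM -> SIGKILL, waiting between the steps."""
    proc.send_signal(signal.SIGINT)
    try:
        return proc.wait(interrupt_timeout)
    except subprocess.TimeoutExpired:
        proc.terminate()
        try:
            return proc.wait(terminate_timeout)
        except subprocess.TimeoutExpired:
            proc.kill()
    return proc.wait()


def run_demo():
    proc = subprocess.Popen([sys.executable, "-c", CHILD], stdout=subprocess.PIPE)
    assert proc.stdout is not None
    # Block until the child has installed its SIGINT handler.
    proc.stdout.readline()
    try:
        return shutdown(proc, interrupt_timeout=0.5, terminate_timeout=0.5)
    finally:
        proc.stdout.close()
```

On POSIX, `Popen.returncode` is the negated signal number when the child is killed by a signal, so `run_demo()` is expected to return `-signal.SIGTERM` here.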


def _process_stream(stream: "IO[bytes]", callback: Callable[[str], None]) -> None:
    buffer = b""
    while byt := stream.read(1):  # Read one byte at a time
@@ -1493,6 +1535,8 @@
        output_hook: Callable[[str], None] = noop,
        params: Optional[dict[str, str]] = None,
        job_id: Optional[str] = None,
        interrupt_timeout: Optional[int] = None,
        terminate_timeout: Optional[int] = None,
    ) -> None:
        cmd = [python_executable, "-c", query_script]
        env = dict(env or os.environ)
@@ -1506,13 +1550,48 @@
        if capture_output:
            popen_kwargs = {"stdout": subprocess.PIPE, "stderr": subprocess.STDOUT}

        def raise_termination_signal(sig: int, _: Any) -> NoReturn:
            raise TerminationSignal(sig)

        thread: Optional[Thread] = None
        with subprocess.Popen(cmd, env=env, **popen_kwargs) as proc:  # noqa: S603
-            if capture_output:
-                args = (proc.stdout, output_hook)
-                thread = Thread(target=_process_stream, args=args, daemon=True)
-                thread.start()
-                thread.join()  # wait for the reader thread
            logger.info("Starting process %s", proc.pid)

            orig_sigint_handler = signal.getsignal(signal.SIGINT)
            # ignore SIGINT in the main process.
            # In the terminal, SIGINTs are received by all the processes in
            # the foreground process group, so the script will receive the signal too.
            # (If we forward the signal to the child, it will receive it twice.)
            signal.signal(signal.SIGINT, signal.SIG_IGN)

            orig_sigterm_handler = signal.getsignal(signal.SIGTERM)
Member

this should be happening inside try (to avoid getting the TerminationSignal raised in between)

Contributor Author

TerminationSignal is raised only after we set a signal handler, which happens after this line.

Member

so it can be raised right before try?

Contributor Author

Theoretically yes, it could.

Member

So then we should move it inside the try, catch the exception, and make sure we handle the case where, for example, one of the original handlers is not yet initialized.

Member

This is a few-line code change, no?

Also, multithreading and multiprocessing are always hard precisely because many problems show up rarely but bite hard in the time someone later spends debugging them. So yes, if there are obvious improvements that are still possible without obvious downsides, I would always err on the side of safer code.

Contributor Author

The best thing to do here is to move the SIGTERM handler out of this `with` block.

Contributor Author

> Also, multithreading and multiprocessing are always hard precisely because many problems show up rarely but bite hard in the time someone later spends debugging them. So yes, if there are obvious improvements that are still possible without obvious downsides, I would always err on the side of safer code.

There is no multithreading or multiprocessing involved here. The main thread always gets the signal.

Member

> The best thing to do here is to move the SIGTERM handler out of this `with` block.

Probably. And in the except clause and in finally, check whether proc is set up and whether the original handlers are set.

> There is no multithreading or multiprocessing involved here. The main thread always gets the signal.

There is, to my mind: an external process sends a signal to another one.

Contributor Author

#873.
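
The window discussed in this thread could be narrowed along the following lines. This is only a sketch and not the change made in #873, whose contents are not shown here. `Terminated`, `_restore`, and `sigterm_as_exception` are hypothetical names, and the code assumes a POSIX platform and that it runs in the main thread. The original handlers are recorded first, and the new ones are installed inside the `try`, so the `finally` always restores them.

```python
import contextlib
import os
import signal
import time


class Terminated(RuntimeError):
    """Hypothetical stand-in for the PR's TerminationSignal."""


def _raise_terminated(signum, frame):
    raise Terminated(signum)


def _restore(signum, handler):
    # getsignal() returns None for handlers not installed from Python;
    # signal.signal() does not accept None, so skip that case.
    if handler is not None:
        signal.signal(signum, handler)


@contextlib.contextmanager
def sigterm_as_exception():
    orig_int = signal.getsignal(signal.SIGINT)
    orig_term = signal.getsignal(signal.SIGTERM)
    try:
        # Installed inside the try: if SIGTERM arrives right after this,
        # the exception still passes through the finally below.
        signal.signal(signal.SIGINT, signal.SIG_IGN)
        signal.signal(signal.SIGTERM, _raise_terminated)
        yield
    finally:
        _restore(signal.SIGTERM, orig_term)
        _restore(signal.SIGINT, orig_int)


def demo():
    before = signal.getsignal(signal.SIGTERM)
    caught = False
    try:
        with sigterm_as_exception():
            os.kill(os.getpid(), signal.SIGTERM)  # simulate an external SIGTERM
            time.sleep(1)  # stands in for proc.wait()
    except Terminated:
        caught = True
    return caught, signal.getsignal(signal.SIGTERM) == before
```

Calling `demo()` sends SIGTERM to the current process, so it should only be run where that is acceptable, such as a throwaway interpreter.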

            signal.signal(signal.SIGTERM, raise_termination_signal)
            try:
                if capture_output:
                    args = (proc.stdout, output_hook)
                    thread = Thread(target=_process_stream, args=args, daemon=True)
                    thread.start()

                proc.wait()
            except TerminationSignal as exc:
                signal.signal(signal.SIGTERM, orig_sigterm_handler)
                signal.signal(signal.SIGINT, orig_sigint_handler)
                logging.info("Shutting down process %s, received %r", proc.pid, exc)
                # Rather than forwarding the signal to the child, we try to shut it down
                # gracefully. This is because we consider the script to be interactive
                # and special, so we give it time to cleanup before exiting.
                shutdown_process(proc, interrupt_timeout, terminate_timeout)
                if proc.returncode:
                    raise QueryScriptCancelError(
                        "Query script was canceled by user", return_code=proc.returncode
                    ) from exc
            finally:
                signal.signal(signal.SIGTERM, orig_sigterm_handler)
                signal.signal(signal.SIGINT, orig_sigint_handler)
                if thread:
                    thread.join()  # wait for the reader thread
Contributor

[Q] Will this still run into deadlocks while we are waiting for tqdm/tqdm#1649?

Contributor Author

The deadlock was happening when fp.write() raised an exception. That won't happen now that we use SIGTERM to cancel and wait.

I can imagine tqdm still deadlocking when the conditions line up, but the chances are slim. For example, if you press Ctrl-C right when the display runs, it may deadlock if other tqdm processes are running in another thread.


        logging.info("Process %s exited with return code %s", proc.pid, proc.returncode)
        if proc.returncode == QUERY_SCRIPT_CANCELED_EXIT_CODE:
            raise QueryScriptCancelError(
                "Query script was canceled by user",
79 changes: 79 additions & 0 deletions tests/func/test_query.py
@@ -1,12 +1,19 @@
import os.path
import signal
import sys
from multiprocessing.pool import ExceptionWithTraceback # type: ignore[attr-defined]
from textwrap import dedent

import cloudpickle
import multiprocess
import pytest

from datachain.catalog import Catalog
from datachain.cli import query
from datachain.data_storage import AbstractDBMetastore, JobQueryType, JobStatus
from datachain.error import QueryScriptCancelError
from datachain.job import Job
from tests.utils import wait_for_condition


@pytest.fixture
@@ -102,3 +109,75 @@ def test_query_cli(cloud_test_catalog_tmpfile, tmp_path, catalog_info_filepath,
    assert job.params == {"url": src_uri}
    assert job.metrics == {"count": 7}
    assert job.python_version == f"{sys.version_info.major}.{sys.version_info.minor}"


if sys.platform == "win32":
    SIGKILL = signal.SIGTERM
else:
    SIGKILL = signal.SIGKILL


@pytest.mark.skipif(sys.platform == "win32", reason="Windows does not have SIGTERM")
@pytest.mark.parametrize(
    "setup,expected_return_code",
    [
        ("", -signal.SIGINT),
        ("signal.signal(signal.SIGINT, signal.SIG_IGN)", -signal.SIGTERM),
        (
            """\
signal.signal(signal.SIGINT, signal.SIG_IGN)
signal.signal(signal.SIGTERM, signal.SIG_IGN)
""",
            -SIGKILL,
        ),
    ],
)
def test_shutdown_on_sigterm(tmp_dir, request, catalog, setup, expected_return_code):
    query = f"""\
import os, pathlib, signal, sys, time

pathlib.Path("ready").touch(exist_ok=False)
{setup}
time.sleep(10)
"""

    def apply(f, args, kwargs):
        return f(*args, **kwargs)

    def func(ms_params, wh_params, init_params, q):
        catalog = Catalog(apply(*ms_params), apply(*wh_params), **init_params)
        try:
            catalog.query(query, interrupt_timeout=0.5, terminate_timeout=0.5)
        except Exception as e:  # noqa: BLE001
            q.put(ExceptionWithTraceback(e, e.__traceback__))
        else:
            q.put(None)

    mp_ctx = multiprocess.get_context("spawn")
    q = mp_ctx.Queue()
    p = mp_ctx.Process(
        target=func,
        args=(
            catalog.metastore.clone_params(),
            catalog.warehouse.clone_params(),
            catalog.get_init_params(),
            q,
        ),
    )
    p.start()
    request.addfinalizer(p.kill)

    def is_ready():
        assert p.is_alive(), "Process is dead"
        return os.path.exists("ready")

    # make sure the process is running before we send the signal
    wait_for_condition(is_ready, "script to start", timeout=5)

    os.kill(p.pid, signal.SIGTERM)
    p.join(timeout=3)  # might take as long as 1 second to complete shutdown_process
    assert not p.exitcode

    e = q.get_nowait()
    assert isinstance(e, QueryScriptCancelError)
    assert e.return_code == expected_return_code
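
The "ready" file handshake used by this test can be shown in isolation. The sketch below is an assumption-laden illustration: `wait_until` is a hypothetical polling helper and not the project's `tests.utils.wait_for_condition`, whose implementation is not shown here, and `demo_handshake` assumes a POSIX platform. The parent signals the child only after the child has created the marker file, so the signal cannot arrive before the script has started.

```python
import pathlib
import signal
import subprocess
import sys
import tempfile
import time


def wait_until(predicate, timeout=5.0, interval=0.05):
    """Hypothetical helper: poll `predicate` until it is true or time runs out."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if predicate():
            return True
        time.sleep(interval)
    return predicate()


def demo_handshake():
    # The child touches the path given as its first argument, then sleeps.
    script = "import pathlib, sys, time; pathlib.Path(sys.argv[1]).touch(); time.sleep(30)"
    with tempfile.TemporaryDirectory() as tmp:
        ready = pathlib.Path(tmp) / "ready"
        proc = subprocess.Popen([sys.executable, "-c", script, str(ready)])
        try:
            started = wait_until(ready.exists)
            proc.terminate()  # SIGTERM, sent only once the child is known to run
            return started, proc.wait(timeout=5)
        finally:
            if proc.poll() is None:
                proc.kill()
                proc.wait()
```

The second element of the result is the child's return code, which on POSIX is the negated number of the signal that killed it. That convention is what the `expected_return_code` values in the parametrization rely on.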
17 changes: 16 additions & 1 deletion tests/unit/test_query.py
@@ -6,7 +6,10 @@

import pytest

-from datachain.catalog.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE
+from datachain.catalog.catalog import (
+    QUERY_SCRIPT_CANCELED_EXIT_CODE,
+    TerminationSignal,
+)
from datachain.error import QueryScriptCancelError, QueryScriptRunError


@@ -63,3 +66,15 @@ def test_non_zero_exitcode(catalog, mock_popen):
        catalog.query("pass")
    assert e.value.return_code == 1
    assert "Query script exited with error code 1" in str(e.value)


def test_shutdown_process_on_sigterm(mocker, catalog, mock_popen):
Member

not sure tbh this is a very meaningful test :) seems we are testing mocks primarily

Contributor Author
@skshetry, Jan 28, 2025

The main goal of this test here is to assert that shutdown_process was called (and with expected values).

Member

Yep, I understand. It doesn't seem very useful, though. Or, to put it a different way: I think we rarely test things this way (checking that a function is called without testing its logic in some way). It still doesn't seem very meaningful to me, while in this case we can most likely write a proper test.

Contributor Author
@skshetry, Jan 29, 2025

Handling signals is tricky, especially when the main process needs to receive them (in tests, that is the test process itself).

> I think we rarely test things this way (checking that a function is called without testing its logic in some way)

We do that in plenty of places in DVC, especially in CLI tests.

> in this case we can most likely write a proper test

If you have an example unit test that we can replace this with (without crossing into the inner workings of shutdown_process), I'd be very interested.
Again, my goal with this test is to make sure that shutdown_process gets called with the expected values. If you see a way to do that, I'll be happy to write it or review it.

I did add a functional test, but it's complex, and you have to tread very carefully not to get the main process killed, blocked, or waiting forever.


    mock_popen.returncode = -2
    mock_popen.wait.side_effect = [TerminationSignal(15)]
    m = mocker.patch("datachain.catalog.catalog.shutdown_process", return_value=-2)

    with pytest.raises(QueryScriptCancelError) as e:
        catalog.query("pass", interrupt_timeout=0.1, terminate_timeout=0.2)
    assert e.value.return_code == -2
    assert "Query script was canceled by user" in str(e.value)
    m.assert_called_once_with(mock_popen, 0.1, 0.2)
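
The mechanics this unit test relies on can be reproduced with the standard library alone. The sketch below is illustrative only: `run`, `shutdown`, `Terminated`, and `Cancelled` are hypothetical stand-ins for `Catalog.query`, `shutdown_process`, `TerminationSignal`, and `QueryScriptCancelError`. An exception instance placed in a `side_effect` list is raised by the mocked call instead of being returned, which is how the test simulates SIGTERM arriving during `proc.wait()`.

```python
from unittest import mock


class Terminated(RuntimeError):
    """Hypothetical stand-in for TerminationSignal."""


class Cancelled(RuntimeError):
    """Hypothetical stand-in for QueryScriptCancelError."""


def shutdown(proc, interrupt_timeout, terminate_timeout):
    """Placeholder; replaced by a mock in check_shutdown_called()."""
    raise NotImplementedError


def run(proc, interrupt_timeout=None, terminate_timeout=None):
    # Mirrors the control flow in the diff: wait, and on termination
    # shut the child down and report a non-zero return code as a cancel.
    try:
        proc.wait()
    except Terminated as exc:
        shutdown(proc, interrupt_timeout, terminate_timeout)
        if proc.returncode:
            raise Cancelled(proc.returncode) from exc


def check_shutdown_called():
    proc = mock.Mock()
    proc.returncode = -2
    proc.wait.side_effect = [Terminated(15)]  # raised on the first wait() call
    fake_shutdown = mock.Mock(return_value=-2)
    # Swap the module-level name that run() looks up at call time.
    with mock.patch.dict(globals(), {"shutdown": fake_shutdown}):
        try:
            run(proc, 0.1, 0.2)
        except Cancelled as exc:
            code = exc.args[0]
        else:
            code = None
    fake_shutdown.assert_called_once_with(proc, 0.1, 0.2)
    return code
```

As the review discussion notes, a test of this shape verifies the wiring (that the shutdown function is called with the timeouts) rather than the shutdown behavior itself.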