diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..9c6efbf --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,30 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.9", "3.11", "3.13"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install package + run: pip install -e . + + - name: Install test dependencies + run: pip install pytest + + - name: Test + run: python -m pytest tests/ -v diff --git a/src/runqy_python/__init__.py b/src/runqy_python/__init__.py index 698525a..9e3bcd6 100644 --- a/src/runqy_python/__init__.py +++ b/src/runqy_python/__init__.py @@ -1,7 +1,7 @@ """runqy-python: Python SDK for runqy - write distributed task handlers with simple decorators.""" # Task execution (for workers) -from .decorator import task, load +from .decorator import task, load, RetryableError from .runner import run, run_once # Client (for enqueuing tasks) @@ -20,6 +20,7 @@ # Task execution "task", "load", + "RetryableError", "run", "run_once", # Client diff --git a/src/runqy_python/decorator.py b/src/runqy_python/decorator.py index bedd745..a6bf799 100644 --- a/src/runqy_python/decorator.py +++ b/src/runqy_python/decorator.py @@ -4,6 +4,23 @@ _registered_loader = None +class RetryableError(Exception): + """Raise this from a @task handler to signal that the task should be retried. + + Usage: + from runqy_python import task, RetryableError + + @task + def process(payload): + try: + result = call_external_api(payload) + except TimeoutError: + raise RetryableError("API timed out, please retry") + return result + """ + pass + + def task(func): """Decorator to register a function as the task handler. @@ -18,6 +35,11 @@ def process(payload: dict, ctx: dict) -> dict: return ctx["model"].predict(payload) """ global _registered_handler + if _registered_handler is not None: + raise RuntimeError( + f"@task handler already registered ({_registered_handler.__name__}). " + "Only one @task handler is allowed per process." + ) _registered_handler = func return func @@ -39,6 +61,11 @@ def process(payload: dict, ctx: dict) -> dict: return ctx["model"].predict(payload) """ global _registered_loader + if _registered_loader is not None: + raise RuntimeError( + f"@load handler already registered ({_registered_loader.__name__}). " + "Only one @load handler is allowed per process." + ) _registered_loader = func return func @@ -51,3 +78,10 @@ def get_handler(): def get_loader(): """Get the registered load function.""" return _registered_loader + + +def _reset(): + """Reset registered handler and loader. For testing only.""" + global _registered_handler, _registered_loader + _registered_handler = None + _registered_loader = None diff --git a/src/runqy_python/runner.py b/src/runqy_python/runner.py index 6c64641..48176dd 100644 --- a/src/runqy_python/runner.py +++ b/src/runqy_python/runner.py @@ -1,8 +1,67 @@ """Runner loop for processing tasks from runqy-worker.""" +import os import sys import json -from .decorator import get_handler, get_loader +import signal +import traceback +from .decorator import get_handler, get_loader, RetryableError + +# Flag for graceful shutdown +_shutdown_requested = False + +# Private file object for protocol communication (set by _protect_stdout) +_protocol_stdout = None + + +def _shutdown_handler(signum, frame): + """Handle SIGTERM/SIGINT for graceful shutdown. + + First signal: set flag so the current task can complete before exit. + Second signal: force exit (in case process is stuck). + """ + global _shutdown_requested + if _shutdown_requested: + # Second signal — force exit + sys.exit(1) + _shutdown_requested = True + + +def _protect_stdout(): + """Redirect sys.stdout to stderr so print() doesn't corrupt the JSON protocol. + + The original stdout fd is saved to _protocol_stdout for _safe_write to use. + """ + global _protocol_stdout + # Duplicate the real stdout fd so it survives sys.stdout reassignment + proto_fd = os.dup(sys.stdout.fileno()) + _protocol_stdout = os.fdopen(proto_fd, "w") + # Redirect sys.stdout to stderr so user print() goes to logs + sys.stdout = sys.stderr + + +def _safe_write(data): + """Safely write JSON data to the protocol channel, handling BrokenPipeError and serialization errors.""" + out = _protocol_stdout if _protocol_stdout is not None else sys.stdout + + try: + text = json.dumps(data) + except (TypeError, ValueError) as e: + # Result not JSON-serializable — send error response instead + fallback = { + "task_id": data.get("task_id", "unknown") if isinstance(data, dict) else "unknown", + "result": None, + "error": f"Result not JSON-serializable: {e}", + "retry": False, + } + text = json.dumps(fallback) + + try: + out.write(text + "\n") + out.flush() + except BrokenPipeError: + # Pipe closed by worker — exit cleanly + sys.exit(1) def run(): @@ -15,6 +74,13 @@ def run(): 4. Calls the registered @task handler with the payload (and context if @load was used) 5. Writes JSON responses to stdout """ + # Protect stdout: redirect sys.stdout to stderr so print() doesn't corrupt protocol + _protect_stdout() + + # Install signal handlers for graceful shutdown + signal.signal(signal.SIGTERM, _shutdown_handler) + signal.signal(signal.SIGINT, _shutdown_handler) + handler = get_handler() if handler is None: raise RuntimeError("No task handler registered. Use @task decorator.") @@ -23,14 +89,20 @@ def run(): loader = get_loader() ctx = None if loader is not None: - ctx = loader() + try: + ctx = loader() + except Exception as e: + _safe_write({"status": "error", "error": f"@load failed: {e}"}) + sys.exit(1) # Ready signal - print(json.dumps({"status": "ready"})) - sys.stdout.flush() + _safe_write({"status": "ready"}) # Process tasks from stdin for line in sys.stdin: + if _shutdown_requested: + break + line = line.strip() if not line: continue @@ -53,16 +125,29 @@ def run(): "error": None, "retry": False } - except Exception as e: + except json.JSONDecodeError as e: + response = { + "task_id": task_id, + "result": None, + "error": f"Invalid JSON input: {e}", + "retry": False + } + except RetryableError as e: response = { "task_id": task_id, "result": None, "error": str(e), + "retry": True + } + except Exception as e: + response = { + "task_id": task_id, + "result": None, + "error": traceback.format_exc(), "retry": False } - print(json.dumps(response)) - sys.stdout.flush() + _safe_write(response) def run_once(): @@ -78,6 +163,13 @@ def run_once(): 5. Writes response to stdout 6. Exits """ + # Protect stdout: redirect sys.stdout to stderr so print() doesn't corrupt protocol + _protect_stdout() + + # Install signal handlers for graceful shutdown + signal.signal(signal.SIGTERM, _shutdown_handler) + signal.signal(signal.SIGINT, _shutdown_handler) + handler = get_handler() if handler is None: raise RuntimeError("No task handler registered. Use @task decorator.") @@ -86,11 +178,14 @@ def run_once(): loader = get_loader() ctx = None if loader is not None: - ctx = loader() + try: + ctx = loader() + except Exception as e: + _safe_write({"status": "error", "error": f"@load failed: {e}"}) + sys.exit(1) # Ready signal - print(json.dumps({"status": "ready"})) - sys.stdout.flush() + _safe_write({"status": "ready"}) # Read ONE task line = sys.stdin.readline().strip() @@ -115,13 +210,26 @@ def run_once(): "error": None, "retry": False } - except Exception as e: + except json.JSONDecodeError as e: + response = { + "task_id": task_id, + "result": None, + "error": f"Invalid JSON input: {e}", + "retry": False + } + except RetryableError as e: response = { "task_id": task_id, "result": None, "error": str(e), + "retry": True + } + except Exception as e: + response = { + "task_id": task_id, + "result": None, + "error": traceback.format_exc(), "retry": False } - print(json.dumps(response)) - sys.stdout.flush() + _safe_write(response) diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000..e4ab6e8 --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,220 @@ +"""Tests for RunqyClient.""" + +import json +import unittest +from unittest import mock +from urllib.error import HTTPError + +from runqy_python.client import ( + AuthenticationError, + BatchResult, + RunqyClient, + RunqyError, + TaskInfo, + TaskNotFoundError, +) + + +class TestRunqyClientInit(unittest.TestCase): + def test_stores_url_and_key(self): + client = RunqyClient("http://localhost:3000", api_key="my-key") + self.assertEqual(client.server_url, "http://localhost:3000") + self.assertEqual(client.api_key, "my-key") + self.assertEqual(client.timeout, 30) + + def test_strips_trailing_slash(self): + client = RunqyClient("http://localhost:3000/", api_key="key") + self.assertEqual(client.server_url, "http://localhost:3000") + + def test_custom_timeout(self): + client = RunqyClient("http://localhost:3000", api_key="key", timeout=60) + self.assertEqual(client.timeout, 60) + + +class TestRunqyClientEnqueue(unittest.TestCase): + def setUp(self): + self.client = RunqyClient("http://localhost:3000", api_key="test-key") + + @mock.patch("runqy_python.client.urllib.request.urlopen") + def test_enqueue_success(self, mock_urlopen): + response_data = json.dumps({ + "info": { + "id": "task-123", + "queue": "inference.default", + "state": "pending", + } + }).encode("utf-8") + + mock_response = mock.MagicMock() + mock_response.read.return_value = response_data + mock_response.__enter__ = mock.MagicMock(return_value=mock_response) + mock_response.__exit__ = mock.MagicMock(return_value=False) + mock_urlopen.return_value = mock_response + + result = self.client.enqueue("inference.default", {"msg": "hello"}) + + self.assertIsInstance(result, TaskInfo) + self.assertEqual(result.task_id, "task-123") + self.assertEqual(result.queue, "inference.default") + self.assertEqual(result.state, "pending") + + # Verify the request was made correctly + call_args = mock_urlopen.call_args + req = call_args[0][0] + self.assertTrue(req.full_url.endswith("/queue/add")) + self.assertEqual(req.get_header("Authorization"), "Bearer test-key") + self.assertEqual(req.get_header("Content-type"), "application/json") + + body = json.loads(req.data.decode("utf-8")) + self.assertEqual(body["queue"], "inference.default") + self.assertEqual(body["data"], {"msg": "hello"}) + self.assertEqual(body["timeout"], 300) + + @mock.patch("runqy_python.client.urllib.request.urlopen") + def test_enqueue_auth_error(self, mock_urlopen): + mock_urlopen.side_effect = HTTPError( + url="http://localhost:3000/queue/add", + code=401, + msg="Unauthorized", + hdrs=None, + fp=mock.MagicMock(read=mock.MagicMock(return_value=b"invalid api key")), + ) + + with self.assertRaises(AuthenticationError): + self.client.enqueue("inference.default", {"msg": "hello"}) + + @mock.patch("runqy_python.client.urllib.request.urlopen") + def test_enqueue_server_error(self, mock_urlopen): + mock_urlopen.side_effect = HTTPError( + url="http://localhost:3000/queue/add", + code=500, + msg="Internal Server Error", + hdrs=None, + fp=mock.MagicMock(read=mock.MagicMock(return_value=b"internal error")), + ) + + with self.assertRaises(RunqyError): + self.client.enqueue("inference.default", {"msg": "hello"}) + + @mock.patch("runqy_python.client.urllib.request.urlopen") + def test_enqueue_not_found(self, mock_urlopen): + mock_urlopen.side_effect = HTTPError( + url="http://localhost:3000/queue/add", + code=404, + msg="Not Found", + hdrs=None, + fp=mock.MagicMock(read=mock.MagicMock(return_value=b"queue not found")), + ) + + with self.assertRaises(TaskNotFoundError): + self.client.enqueue("nonexistent", {"msg": "hello"}) + + +class TestRunqyClientEnqueueBatch(unittest.TestCase): + def setUp(self): + self.client = RunqyClient("http://localhost:3000", api_key="test-key") + + @mock.patch("runqy_python.client.urllib.request.urlopen") + def test_enqueue_batch_success(self, mock_urlopen): + response_data = json.dumps({ + "enqueued": 2, + "failed": 0, + "task_ids": ["t1", "t2"], + "errors": [], + }).encode("utf-8") + + mock_response = mock.MagicMock() + mock_response.read.return_value = response_data + mock_response.__enter__ = mock.MagicMock(return_value=mock_response) + mock_response.__exit__ = mock.MagicMock(return_value=False) + mock_urlopen.return_value = mock_response + + result = self.client.enqueue_batch( + "inference.default", + [{"input": "a"}, {"input": "b"}], + ) + + self.assertIsInstance(result, BatchResult) + self.assertEqual(result.enqueued, 2) + self.assertEqual(result.failed, 0) + self.assertEqual(result.task_ids, ["t1", "t2"]) + self.assertEqual(result.errors, []) + + # Verify request body structure + call_args = mock_urlopen.call_args + req = call_args[0][0] + body = json.loads(req.data.decode("utf-8")) + self.assertEqual(body["queue"], "inference.default") + self.assertEqual(len(body["jobs"]), 2) + self.assertEqual(body["jobs"][0]["data"], {"input": "a"}) + + +class TestRunqyClientGetTask(unittest.TestCase): + def setUp(self): + self.client = RunqyClient("http://localhost:3000", api_key="test-key") + + @mock.patch("runqy_python.client.urllib.request.urlopen") + def test_get_task_completed(self, mock_urlopen): + response_data = json.dumps({ + "info": { + "id": "task-456", + "queue": "inference.default", + "state": "completed", + "result": json.dumps({"output": "done"}), + "payload": json.dumps({"input": "test"}), + } + }).encode("utf-8") + + mock_response = mock.MagicMock() + mock_response.read.return_value = response_data + mock_response.__enter__ = mock.MagicMock(return_value=mock_response) + mock_response.__exit__ = mock.MagicMock(return_value=False) + mock_urlopen.return_value = mock_response + + result = self.client.get_task("task-456") + + self.assertEqual(result.task_id, "task-456") + self.assertEqual(result.state, "completed") + self.assertEqual(result.result, {"output": "done"}) + self.assertEqual(result.payload, {"input": "test"}) + + @mock.patch("runqy_python.client.urllib.request.urlopen") + def test_get_task_not_found(self, mock_urlopen): + mock_urlopen.side_effect = HTTPError( + url="http://localhost:3000/queue/task-999", + code=404, + msg="Not Found", + hdrs=None, + fp=mock.MagicMock(read=mock.MagicMock(return_value=b"not found")), + ) + + with self.assertRaises(TaskNotFoundError): + self.client.get_task("task-999") + + +class TestModuleLevelFunctions(unittest.TestCase): + @mock.patch("runqy_python.client.urllib.request.urlopen") + def test_module_enqueue(self, mock_urlopen): + from runqy_python.client import enqueue + + response_data = json.dumps({ + "info": {"id": "t1", "queue": "q", "state": "pending"} + }).encode("utf-8") + + mock_response = mock.MagicMock() + mock_response.read.return_value = response_data + mock_response.__enter__ = mock.MagicMock(return_value=mock_response) + mock_response.__exit__ = mock.MagicMock(return_value=False) + mock_urlopen.return_value = mock_response + + result = enqueue( + "q", {"key": "val"}, + server_url="http://localhost:3000", + api_key="key", + ) + self.assertIsInstance(result, TaskInfo) + self.assertEqual(result.task_id, "t1") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_runner_hardening.py b/tests/test_runner_hardening.py new file mode 100644 index 0000000..d556443 --- /dev/null +++ b/tests/test_runner_hardening.py @@ -0,0 +1,541 @@ +"""Tests for Phase 1 runner.py hardening fixes. + +Covers: + - _safe_write: normal output, non-serializable fallback, BrokenPipeError + - run(): @load failure sends {"status":"error"}, invalid JSON input handling + - run_once(): @load failure sends {"status":"error"} + +Uses unittest + unittest.mock with io.StringIO for stdin/stdout redirection. +""" + +import io +import json +import sys +import unittest +from unittest import mock + +# We need to be able to reset the decorator global state between tests, +# so import the decorator module directly. +import runqy_python.decorator as decorator +import runqy_python.runner as runner + + +class SafeWriteTestCase(unittest.TestCase): + """Tests for runner._safe_write.""" + + def setUp(self): + # Ensure _safe_write uses sys.stdout (no protocol redirect active) + self._saved_protocol_stdout = runner._protocol_stdout + runner._protocol_stdout = None + + def tearDown(self): + runner._protocol_stdout = self._saved_protocol_stdout + + def test_normal_dict_outputs_json(self): + """_safe_write with a normal dict should write valid JSON + newline to stdout.""" + fake_stdout = io.StringIO() + with mock.patch.object(sys, "stdout", fake_stdout): + runner._safe_write({"task_id": "t1", "result": {"ok": True}, "error": None, "retry": False}) + + output = fake_stdout.getvalue() + # Should end with a newline + self.assertTrue(output.endswith("\n"), "Output should end with a newline") + + # Should be valid JSON + parsed = json.loads(output.strip()) + self.assertEqual(parsed["task_id"], "t1") + self.assertEqual(parsed["result"], {"ok": True}) + self.assertIsNone(parsed["error"]) + self.assertFalse(parsed["retry"]) + + def test_non_serializable_data_outputs_error_fallback(self): + """_safe_write with non-serializable data (e.g., a set) should output + a fallback error response instead of crashing.""" + fake_stdout = io.StringIO() + non_serializable = { + "task_id": "t-bad", + "result": {"items": {1, 2, 3}}, # sets are not JSON-serializable + "error": None, + "retry": False, + } + with mock.patch.object(sys, "stdout", fake_stdout): + # Should NOT raise + runner._safe_write(non_serializable) + + output = fake_stdout.getvalue() + parsed = json.loads(output.strip()) + + # Fallback should contain the task_id from the original data + self.assertEqual(parsed["task_id"], "t-bad") + # result should be None in the fallback + self.assertIsNone(parsed["result"]) + # error should mention "not JSON-serializable" + self.assertIn("not JSON-serializable", parsed["error"]) + self.assertFalse(parsed["retry"]) + + def test_non_serializable_without_task_id_uses_unknown(self): + """_safe_write with non-serializable data and no task_id should default to 'unknown'.""" + fake_stdout = io.StringIO() + # No task_id key at all + non_serializable = {"result": object()} + with mock.patch.object(sys, "stdout", fake_stdout): + runner._safe_write(non_serializable) + + output = fake_stdout.getvalue() + parsed = json.loads(output.strip()) + self.assertEqual(parsed["task_id"], "unknown") + self.assertIn("not JSON-serializable", parsed["error"]) + + def test_non_dict_non_serializable_uses_unknown(self): + """_safe_write with a non-dict, non-serializable value should use 'unknown' task_id.""" + fake_stdout = io.StringIO() + # Pass a non-dict that is also not serializable + with mock.patch.object(sys, "stdout", fake_stdout): + runner._safe_write(object()) + + output = fake_stdout.getvalue() + parsed = json.loads(output.strip()) + self.assertEqual(parsed["task_id"], "unknown") + self.assertIn("not JSON-serializable", parsed["error"]) + + def test_broken_pipe_exits_cleanly(self): + """_safe_write should call sys.exit(1) on BrokenPipeError.""" + broken_stdout = mock.MagicMock() + broken_stdout.write.side_effect = BrokenPipeError("pipe closed") + with mock.patch.object(sys, "stdout", broken_stdout): + with self.assertRaises(SystemExit) as ctx: + runner._safe_write({"status": "ready"}) + self.assertEqual(ctx.exception.code, 1) + + +class RunLoadFailureTestCase(unittest.TestCase): + """Tests for run() handling a failing @load function.""" + + def setUp(self): + # Reset global decorator state before each test + decorator._reset() + # Reset shutdown flag and protocol stdout + runner._shutdown_requested = False + runner._protocol_stdout = None + + def tearDown(self): + decorator._reset() + runner._shutdown_requested = False + runner._protocol_stdout = None + + def test_run_load_failure_sends_error_status(self): + """run() should send {"status":"error"} and exit(1) when @load raises.""" + + @decorator.task + def my_handler(payload): + return {"done": True} + + @decorator.load + def my_loader(): + raise RuntimeError("model download failed") + + fake_stdout = io.StringIO() + + with mock.patch.object(sys, "stdout", fake_stdout), \ + mock.patch("runqy_python.runner._protect_stdout"), \ + mock.patch("runqy_python.runner.signal.signal"): + with self.assertRaises(SystemExit) as ctx: + runner.run() + + self.assertEqual(ctx.exception.code, 1) + + # Parse the output line + output = fake_stdout.getvalue().strip() + parsed = json.loads(output) + + self.assertEqual(parsed["status"], "error") + self.assertIn("@load failed", parsed["error"]) + self.assertIn("model download failed", parsed["error"]) + + def test_run_without_loader_sends_ready(self): + """run() with no @load should send {"status":"ready"} and proceed normally.""" + + @decorator.task + def my_handler(payload): + return {"echo": payload} + + task_input = json.dumps({"task_id": "t1", "payload": {"msg": "hello"}}) + "\n" + fake_stdin = io.StringIO(task_input) + fake_stdout = io.StringIO() + + with mock.patch.object(sys, "stdin", fake_stdin), \ + mock.patch.object(sys, "stdout", fake_stdout), \ + mock.patch("runqy_python.runner._protect_stdout"), \ + mock.patch("runqy_python.runner.signal.signal"): + runner.run() + + lines = fake_stdout.getvalue().strip().split("\n") + # First line: ready signal + ready = json.loads(lines[0]) + self.assertEqual(ready["status"], "ready") + + # Second line: task response + resp = json.loads(lines[1]) + self.assertEqual(resp["task_id"], "t1") + self.assertEqual(resp["result"], {"echo": {"msg": "hello"}}) + self.assertIsNone(resp["error"]) + + def test_run_no_handler_raises(self): + """run() should raise RuntimeError if no @task handler is registered.""" + with mock.patch("runqy_python.runner._protect_stdout"), \ + mock.patch("runqy_python.runner.signal.signal"): + with self.assertRaises(RuntimeError) as ctx: + runner.run() + self.assertIn("No task handler registered", str(ctx.exception)) + + +class RunInvalidJsonTestCase(unittest.TestCase): + """Tests for run() handling invalid JSON input.""" + + def setUp(self): + decorator._reset() + runner._shutdown_requested = False + runner._protocol_stdout = None + + def tearDown(self): + decorator._reset() + runner._shutdown_requested = False + runner._protocol_stdout = None + + def test_invalid_json_sends_error_response(self): + """run() should send an error response with 'Invalid JSON input' for malformed input.""" + + @decorator.task + def my_handler(payload): + return {"done": True} + + # First line is invalid JSON, no more lines after + fake_stdin = io.StringIO("this is not json\n") + fake_stdout = io.StringIO() + + with mock.patch.object(sys, "stdin", fake_stdin), \ + mock.patch.object(sys, "stdout", fake_stdout), \ + mock.patch("runqy_python.runner._protect_stdout"), \ + mock.patch("runqy_python.runner.signal.signal"): + runner.run() + + lines = fake_stdout.getvalue().strip().split("\n") + # First line: ready signal + ready = json.loads(lines[0]) + self.assertEqual(ready["status"], "ready") + + # Second line: error response for invalid JSON + resp = json.loads(lines[1]) + self.assertEqual(resp["task_id"], "unknown") + self.assertIsNone(resp["result"]) + self.assertIn("Invalid JSON input", resp["error"]) + self.assertFalse(resp["retry"]) + + def test_invalid_json_does_not_crash_and_continues(self): + """run() should handle invalid JSON and then continue processing valid tasks.""" + + @decorator.task + def my_handler(payload): + return {"value": payload.get("x", 0) * 2} + + # First line is invalid, second is valid + lines_in = "NOT_JSON\n" + json.dumps({"task_id": "t2", "payload": {"x": 5}}) + "\n" + fake_stdin = io.StringIO(lines_in) + fake_stdout = io.StringIO() + + with mock.patch.object(sys, "stdin", fake_stdin), \ + mock.patch.object(sys, "stdout", fake_stdout), \ + mock.patch("runqy_python.runner._protect_stdout"), \ + mock.patch("runqy_python.runner.signal.signal"): + runner.run() + + lines = fake_stdout.getvalue().strip().split("\n") + self.assertEqual(len(lines), 3) # ready + error + success + + # Line 0: ready + self.assertEqual(json.loads(lines[0])["status"], "ready") + # Line 1: error for invalid JSON + self.assertIn("Invalid JSON input", json.loads(lines[1])["error"]) + # Line 2: successful task response + resp = json.loads(lines[2]) + self.assertEqual(resp["task_id"], "t2") + self.assertEqual(resp["result"], {"value": 10}) + self.assertIsNone(resp["error"]) + + def test_empty_lines_are_skipped(self): + """run() should skip empty lines without producing output.""" + + @decorator.task + def my_handler(payload): + return {"ok": True} + + # Only empty/whitespace lines, then a valid task + lines_in = "\n \n" + json.dumps({"task_id": "t3", "payload": {}}) + "\n" + fake_stdin = io.StringIO(lines_in) + fake_stdout = io.StringIO() + + with mock.patch.object(sys, "stdin", fake_stdin), \ + mock.patch.object(sys, "stdout", fake_stdout), \ + mock.patch("runqy_python.runner._protect_stdout"), \ + mock.patch("runqy_python.runner.signal.signal"): + runner.run() + + lines = fake_stdout.getvalue().strip().split("\n") + # Should only have ready + one task response (empty lines skipped) + self.assertEqual(len(lines), 2) + self.assertEqual(json.loads(lines[0])["status"], "ready") + self.assertEqual(json.loads(lines[1])["task_id"], "t3") + + +class RunOnceLoadFailureTestCase(unittest.TestCase): + """Tests for run_once() handling a failing @load function.""" + + def setUp(self): + decorator._reset() + runner._shutdown_requested = False + runner._protocol_stdout = None + + def tearDown(self): + decorator._reset() + runner._shutdown_requested = False + runner._protocol_stdout = None + + def test_run_once_load_failure_sends_error_status(self): + """run_once() should send {"status":"error"} and exit(1) when @load raises.""" + + @decorator.task + def my_handler(payload): + return {"done": True} + + @decorator.load + def my_loader(): + raise ValueError("bad config") + + fake_stdout = io.StringIO() + + with mock.patch.object(sys, "stdout", fake_stdout), \ + mock.patch("runqy_python.runner._protect_stdout"), \ + mock.patch("runqy_python.runner.signal.signal"): + with self.assertRaises(SystemExit) as ctx: + runner.run_once() + + self.assertEqual(ctx.exception.code, 1) + + output = fake_stdout.getvalue().strip() + parsed = json.loads(output) + + self.assertEqual(parsed["status"], "error") + self.assertIn("@load failed", parsed["error"]) + self.assertIn("bad config", parsed["error"]) + + def test_run_once_processes_single_task(self): + """run_once() should process exactly one task and return.""" + + @decorator.task + def my_handler(payload): + return {"doubled": payload.get("n", 0) * 2} + + task_input = json.dumps({"task_id": "once-1", "payload": {"n": 7}}) + "\n" + fake_stdin = io.StringIO(task_input) + fake_stdout = io.StringIO() + + with mock.patch.object(sys, "stdin", fake_stdin), \ + mock.patch.object(sys, "stdout", fake_stdout), \ + mock.patch("runqy_python.runner._protect_stdout"), \ + mock.patch("runqy_python.runner.signal.signal"): + runner.run_once() + + lines = fake_stdout.getvalue().strip().split("\n") + self.assertEqual(len(lines), 2) # ready + response + + ready = json.loads(lines[0]) + self.assertEqual(ready["status"], "ready") + + resp = json.loads(lines[1]) + self.assertEqual(resp["task_id"], "once-1") + self.assertEqual(resp["result"], {"doubled": 14}) + self.assertIsNone(resp["error"]) + + def test_run_once_invalid_json_sends_error(self): + """run_once() should handle invalid JSON input gracefully.""" + + @decorator.task + def my_handler(payload): + return {"ok": True} + + fake_stdin = io.StringIO("{broken json\n") + fake_stdout = io.StringIO() + + with mock.patch.object(sys, "stdin", fake_stdin), \ + mock.patch.object(sys, "stdout", fake_stdout), \ + mock.patch("runqy_python.runner._protect_stdout"), \ + mock.patch("runqy_python.runner.signal.signal"): + runner.run_once() + + lines = fake_stdout.getvalue().strip().split("\n") + self.assertEqual(len(lines), 2) # ready + error + + ready = json.loads(lines[0]) + self.assertEqual(ready["status"], "ready") + + resp = json.loads(lines[1]) + self.assertEqual(resp["task_id"], "unknown") + self.assertIn("Invalid JSON input", resp["error"]) + self.assertFalse(resp["retry"]) + + def test_run_once_no_handler_raises(self): + """run_once() should raise RuntimeError if no @task handler is registered.""" + with mock.patch("runqy_python.runner._protect_stdout"), \ + mock.patch("runqy_python.runner.signal.signal"): + with self.assertRaises(RuntimeError) as ctx: + runner.run_once() + self.assertIn("No task handler registered", str(ctx.exception)) + + def test_run_once_empty_input_returns_without_error(self): + """run_once() should return cleanly when stdin is empty (no task to process).""" + + @decorator.task + def my_handler(payload): + return {"ok": True} + + fake_stdin = io.StringIO("") + fake_stdout = io.StringIO() + + with mock.patch.object(sys, "stdin", fake_stdin), \ + mock.patch.object(sys, "stdout", fake_stdout), \ + mock.patch("runqy_python.runner._protect_stdout"), \ + mock.patch("runqy_python.runner.signal.signal"): + # Should not raise + runner.run_once() + + lines = fake_stdout.getvalue().strip().split("\n") + # Only the ready signal, no task response + self.assertEqual(len(lines), 1) + self.assertEqual(json.loads(lines[0])["status"], "ready") + + +class ShutdownHandlerTestCase(unittest.TestCase): + """Tests for _shutdown_handler.""" + + def setUp(self): + runner._shutdown_requested = False + + def tearDown(self): + runner._shutdown_requested = False + + def test_shutdown_handler_sets_flag_on_first_signal(self): + """First signal should set _shutdown_requested without exiting.""" + runner._shutdown_handler(15, None) # SIGTERM = 15 + self.assertTrue(runner._shutdown_requested) + + def test_shutdown_handler_exits_on_second_signal(self): + """Second signal should force exit when already shutting down.""" + runner._shutdown_requested = True # simulate first signal already received + with self.assertRaises(SystemExit) as ctx: + runner._shutdown_handler(15, None) # second SIGTERM + self.assertEqual(ctx.exception.code, 1) + + +class StdoutProtectionTestCase(unittest.TestCase): + """Tests for stdout protection (print() shouldn't corrupt protocol).""" + + def setUp(self): + decorator._reset() + runner._shutdown_requested = False + runner._protocol_stdout = None + + def tearDown(self): + decorator._reset() + runner._shutdown_requested = False + runner._protocol_stdout = None + + def test_print_in_handler_does_not_appear_in_protocol(self): + """print() inside a @task handler should not corrupt JSON protocol output.""" + + @decorator.task + def my_handler(payload): + print("debug: processing task") # This should go to stderr, not protocol + return {"ok": True} + + task_input = json.dumps({"task_id": "t-print", "payload": {}}) + "\n" + fake_stdin = io.StringIO(task_input) + fake_protocol = io.StringIO() + fake_stderr = io.StringIO() + + # Simulate what _protect_stdout does: set _protocol_stdout and redirect stdout to stderr + runner._protocol_stdout = fake_protocol + + with mock.patch.object(sys, "stdin", fake_stdin), \ + mock.patch.object(sys, "stdout", fake_stderr), \ + mock.patch("runqy_python.runner._protect_stdout"), \ + mock.patch("runqy_python.runner.signal.signal"): + runner.run() + + # Protocol output should be clean JSON only + protocol_lines = fake_protocol.getvalue().strip().split("\n") + self.assertEqual(len(protocol_lines), 2) # ready + response + ready = json.loads(protocol_lines[0]) + self.assertEqual(ready["status"], "ready") + resp = json.loads(protocol_lines[1]) + self.assertEqual(resp["task_id"], "t-print") + self.assertEqual(resp["result"], {"ok": True}) + + # print() output should have gone to "stderr" (which is sys.stdout in this test) + self.assertIn("debug: processing task", fake_stderr.getvalue()) + + +class DecoratorOverwriteTestCase(unittest.TestCase): + """Tests that @task and @load raise on double registration.""" + + def setUp(self): + decorator._reset() + + def tearDown(self): + decorator._reset() + + def test_task_double_registration_raises(self): + """Registering @task twice should raise RuntimeError.""" + @decorator.task + def handler_one(payload): + return {} + + with self.assertRaises(RuntimeError) as ctx: + @decorator.task + def handler_two(payload): + return {} + + self.assertIn("already registered", str(ctx.exception)) + self.assertIn("handler_one", str(ctx.exception)) + + def test_load_double_registration_raises(self): + """Registering @load twice should raise RuntimeError.""" + @decorator.load + def loader_one(): + return {} + + with self.assertRaises(RuntimeError) as ctx: + @decorator.load + def loader_two(): + return {} + + self.assertIn("already registered", str(ctx.exception)) + self.assertIn("loader_one", str(ctx.exception)) + + def test_reset_allows_re_registration(self): + """After _reset(), decorators can be applied again.""" + @decorator.task + def handler_one(payload): + return {} + + decorator._reset() + + # Should not raise + @decorator.task + def handler_two(payload): + return {"new": True} + + self.assertEqual(decorator.get_handler(), handler_two) + + +if __name__ == "__main__": + unittest.main()