diff --git a/python/README.md b/python/README.md index aa82e0c3..3f84c0f1 100644 --- a/python/README.md +++ b/python/README.md @@ -23,6 +23,45 @@ python chat.py ## Quick Start +### Using Context Managers (Recommended) + +The SDK supports Python's async context manager protocol for automatic resource cleanup: + +```python +import asyncio +from copilot import CopilotClient + +async def main(): + # Client automatically starts on enter and cleans up on exit + async with CopilotClient() as client: + # Create a session with automatic cleanup + async with await client.create_session({"model": "gpt-4"}) as session: + # Wait for response using session.idle event + done = asyncio.Event() + + def on_event(event): + if event.type.value == "assistant.message": + print(event.data.content) + elif event.type.value == "session.idle": + done.set() + + session.on(on_event) + + # Send a message and wait for completion + await session.send({"prompt": "What is 2+2?"}) + await done.wait() + + # Session automatically destroyed here + + # Client automatically stopped here + +asyncio.run(main()) +``` + +### Manual Resource Management + +You can also manage resources manually: + ```python import asyncio from copilot import CopilotClient @@ -65,6 +104,7 @@ asyncio.run(main()) - ✅ Session history with `get_messages()` - ✅ Type hints throughout - ✅ Async/await native +- ✅ Async context manager support for automatic resource cleanup ## API Reference @@ -149,6 +189,44 @@ unsubscribe() - `session.foreground` - A session became the foreground session in TUI - `session.background` - A session is no longer the foreground session +### Context Manager Support + +Both `CopilotClient` and `CopilotSession` support Python's async context manager protocol for automatic resource cleanup. This is the recommended pattern as it ensures resources are properly cleaned up even if exceptions occur. + +**CopilotClient Context Manager:** + +```python +async with CopilotClient() as client: + # Client automatically starts on enter + session = await client.create_session() + await session.send({"prompt": "Hello!"}) + # Client automatically stops on exit, cleaning up all sessions +``` + +**CopilotSession Context Manager:** + +```python +async with await client.create_session() as session: + await session.send({"prompt": "Hello!"}) + # Session automatically destroyed on exit +``` + +**Nested Context Managers:** + +```python +async with CopilotClient() as client: + async with await client.create_session() as session: + await session.send({"prompt": "Hello!"}) + # Session destroyed here +# Client stopped here +``` + +**Benefits:** +- Prevents resource leaks by ensuring cleanup even if exceptions occur +- Eliminates the need to manually call `stop()` and `destroy()` +- Follows Python best practices for resource management +- Particularly useful in batch operations and evaluations to prevent process accumulation + ### Tools Define tools with automatic JSON schema generation using the `@define_tool` decorator and Pydantic models: diff --git a/python/copilot/client.py b/python/copilot/client.py index 90260ffb..716901f2 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -21,6 +21,7 @@ import threading from dataclasses import asdict, is_dataclass from pathlib import Path +from types import TracebackType from typing import Any, Callable, Optional, cast from .generated.rpc import ServerRpc @@ -208,6 +209,54 @@ def __init__(self, options: Optional[CopilotClientOptions] = None): self._lifecycle_handlers_lock = threading.Lock() self._rpc: Optional[ServerRpc] = None + async def __aenter__(self) -> "CopilotClient": + """ + Enter the async context manager. + + Automatically starts the CLI server and establishes a connection if not + already connected. + + Returns: + The CopilotClient instance. + + Example: + >>> async with CopilotClient() as client: + ... session = await client.create_session() + ... await session.send({"prompt": "Hello!"}) + """ + await self.start() + return self + + async def __aexit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: + """ + Exit the async context manager. + + Performs graceful cleanup by destroying all active sessions and stopping + the CLI server. If a cleanup error occurs and no exception was raised + inside the context, the cleanup error is propagated. If an exception + was already raised inside the context, the cleanup error is suppressed + so the original exception is not masked. + + Args: + exc_type: The type of exception that occurred, if any. + exc_val: The exception instance that occurred, if any. + exc_tb: The traceback of the exception that occurred, if any. + + Returns: + False to propagate any exception that occurred in the context. + """ + try: + await self.stop() + except Exception: + if exc_type is None: + raise + return False + @property def rpc(self) -> ServerRpc: """Typed server-scoped RPC methods.""" diff --git a/python/copilot/session.py b/python/copilot/session.py index af02a312..33a473ec 100644 --- a/python/copilot/session.py +++ b/python/copilot/session.py @@ -8,6 +8,7 @@ import asyncio import inspect import threading +from types import TracebackType from typing import Any, Callable, Optional from .generated.rpc import SessionRpc @@ -72,6 +73,7 @@ def __init__(self, session_id: str, client: Any, workspace_path: Optional[str] = self.session_id = session_id self._client = client self._workspace_path = workspace_path + self._destroyed = False self._event_handlers: set[Callable[[SessionEvent], None]] = set() self._event_handlers_lock = threading.Lock() self._tool_handlers: dict[str, ToolHandler] = {} @@ -84,6 +86,52 @@ def __init__(self, session_id: str, client: Any, workspace_path: Optional[str] = self._hooks_lock = threading.Lock() self._rpc: Optional[SessionRpc] = None + async def __aenter__(self) -> "CopilotSession": + """ + Enter the async context manager. + + Returns the session instance, ready for use. The session must already be + created (via CopilotClient.create_session or resume_session). + + Returns: + The CopilotSession instance. + + Example: + >>> async with await client.create_session() as session: + ... await session.send({"prompt": "Hello!"}) + """ + return self + + async def __aexit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: + """ + Exit the async context manager. + + Automatically destroys the session and releases all associated resources. + If a cleanup error occurs and no exception was raised inside the context, + the cleanup error is propagated. If an exception was already raised inside + the context, the cleanup error is suppressed so the original exception is + not masked. + + Args: + exc_type: The type of exception that occurred, if any. + exc_val: The exception instance that occurred, if any. + exc_tb: The traceback of the exception that occurred, if any. + + Returns: + False to propagate any exception that occurred in the context. + """ + try: + await self.destroy() + except Exception: + if exc_type is None: + raise + return False + @property def rpc(self) -> SessionRpc: """Typed session-scoped RPC methods.""" @@ -483,20 +531,33 @@ async def destroy(self) -> None: handlers and tool handlers are cleared. To continue the conversation, use :meth:`CopilotClient.resume_session` with the session ID. + This method is idempotent—calling it multiple times is safe and will + not raise an error if the session is already destroyed. + Raises: - Exception: If the connection fails. + Exception: If the connection fails (on first destroy call). Example: >>> # Clean up when done >>> await session.destroy() """ - await self._client.request("session.destroy", {"sessionId": self.session_id}) + # Ensure that the check and update of _destroyed are atomic so that + # only the first caller proceeds to send the destroy RPC. with self._event_handlers_lock: - self._event_handlers.clear() - with self._tool_handlers_lock: - self._tool_handlers.clear() - with self._permission_handler_lock: - self._permission_handler = None + if self._destroyed: + return + self._destroyed = True + + try: + await self._client.request("session.destroy", {"sessionId": self.session_id}) + finally: + # Clear handlers even if the request fails + with self._event_handlers_lock: + self._event_handlers.clear() + with self._tool_handlers_lock: + self._tool_handlers.clear() + with self._permission_handler_lock: + self._permission_handler = None async def abort(self) -> None: """ diff --git a/python/test_client.py b/python/test_client.py index 0bc99ea6..cbeed50a 100644 --- a/python/test_client.py +++ b/python/test_client.py @@ -189,3 +189,46 @@ async def mock_request(method, params): assert captured["session.resume"]["clientName"] == "my-app" finally: await client.force_stop() + + +class TestContextManager: + @pytest.mark.asyncio + async def test_client_context_manager_returns_self(self): + """Test that __aenter__ returns the client instance.""" + client = CopilotClient({"cli_path": CLI_PATH}) + returned_client = await client.__aenter__() + assert returned_client is client + await client.force_stop() + + @pytest.mark.asyncio + async def test_client_aexit_returns_false(self): + """Test that __aexit__ returns False to propagate exceptions.""" + client = CopilotClient({"cli_path": CLI_PATH}) + await client.start() + result = await client.__aexit__(None, None, None) + assert result is False + + @pytest.mark.asyncio + async def test_session_context_manager_returns_self(self): + """Test that session __aenter__ returns the session instance.""" + client = CopilotClient({"cli_path": CLI_PATH}) + await client.start() + try: + session = await client.create_session() + returned_session = await session.__aenter__() + assert returned_session is session + finally: + await client.force_stop() + + @pytest.mark.asyncio + async def test_session_aexit_returns_false(self): + """Test that session __aexit__ returns False to propagate exceptions.""" + client = CopilotClient({"cli_path": CLI_PATH}) + await client.start() + try: + session = await client.create_session() + result = await session.__aexit__(None, None, None) + assert result is False + finally: + await client.force_stop() +