diff --git a/src/openai_priority_loadbalancer/openai_priority_loadbalancer.py b/src/openai_priority_loadbalancer/openai_priority_loadbalancer.py index 324e68d..736098c 100644 --- a/src/openai_priority_loadbalancer/openai_priority_loadbalancer.py +++ b/src/openai_priority_loadbalancer/openai_priority_loadbalancer.py @@ -3,7 +3,7 @@ # Python Standard Library import logging import random -from typing import List +from typing import List, Union from datetime import datetime, MAXYEAR, MINYEAR, timedelta, timezone # Third-Party Libraries @@ -29,7 +29,7 @@ class BaseLoadBalancer(): """Logically abstracts the BaseLoadBalancer class which should be inherited by the synchronous and asynchronous load balancer classes.""" # Constructor - def __init__(self, transport: httpx.BaseTransport, backends: List[Backend]): + def __init__(self, transport: Union[httpx.BaseTransport, httpx.AsyncBaseTransport], backends: List[Backend]): # Public instance variables self.backends = backends @@ -39,6 +39,12 @@ def __init__(self, transport: httpx.BaseTransport, backends: List[Backend]): self._available_backends = 1 self._transport = transport + # Magic Methods + + # If a method in the BaseTransport or AsyncBaseTransport classes is not found, it will be looked up in base _transport object. + def __getattr__(self, name): + return getattr(self._transport, name) + # "Protected" Methods def _check_throttling(self) -> None: """Check if any backend is throttling and reset if necessary.""" diff --git a/tests/lib/test_openai_priority_loadbalancer.py b/tests/lib/test_openai_priority_loadbalancer.py index 681b598..bae260c 100644 --- a/tests/lib/test_openai_priority_loadbalancer.py +++ b/tests/lib/test_openai_priority_loadbalancer.py @@ -382,6 +382,26 @@ def test_loadbalancer_handle_4xx_failure(self, client_same_priority): # Assert that the final response status code was 400 assert response.status_code == 400 + @pytest.mark.loadbalancer + def test_loadbalancer_loadbalancer_close(self, client_same_priority): + client = client_same_priority + + # Create a mock response for the transport + mock_response = httpx.Response(200) + + with patch('httpx.Client.send', return_value = mock_response): + req = client._build_request(create_final_request_options()) + response = client._client._transport.handle_request(req) + + assert response.status_code == 200 + + # Assert that the transport is not closed, then close it, then assert that it was closed. + assert client._client._transport.is_closed is False + + client._client._transport.close() + + assert client._client._transport.is_closed is True + # Asynchronous Tests class TestAsynchronous: @@ -591,3 +611,24 @@ async def test_async_loadbalancer_handle_4xx_failure(self, async_client_same_pri # Assert that the final response status code was 400 assert response.status_code == 400 + + @pytest.mark.asyncio + @pytest.mark.async_loadbalancer + async def test_async_loadbalancer_close(self, async_client_same_priority): + client = async_client_same_priority + + # Create a mock response for the transport + mock_response = httpx.Response(200) + + with patch('httpx.AsyncClient.send', return_value = mock_response): + req = client._build_request(create_final_request_options()) + response = await client._client._transport.handle_async_request(req) + + assert response.status_code == 200 + + # Assert that the transport is not closed, then close it, then assert that it was closed. + assert client._client._transport.is_closed is False + + await client._client._transport.aclose() + + assert client._client._transport.is_closed is True