Skip to content

Commit 416bda2

Browse files
Add support for sync/async base transport properties & methods (#38)
* Add support for sync/async base transport properties & methods Fixes #36 * Use Union for Python <3.10 compatibility
1 parent 1c85dc0 commit 416bda2

File tree

2 files changed

+49
-2
lines changed

2 files changed

+49
-2
lines changed

src/openai_priority_loadbalancer/openai_priority_loadbalancer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Python Standard Library
44
import logging
55
import random
6-
from typing import List
6+
from typing import List, Union
77
from datetime import datetime, MAXYEAR, MINYEAR, timedelta, timezone
88

99
# Third-Party Libraries
@@ -29,7 +29,7 @@ class BaseLoadBalancer():
2929
"""Logically abstracts the BaseLoadBalancer class which should be inherited by the synchronous and asynchronous load balancer classes."""
3030

3131
# Constructor
32-
def __init__(self, transport: httpx.BaseTransport, backends: List[Backend]):
32+
def __init__(self, transport: Union[httpx.BaseTransport, httpx.AsyncBaseTransport], backends: List[Backend]):
3333
# Public instance variables
3434
self.backends = backends
3535

@@ -39,6 +39,12 @@ def __init__(self, transport: httpx.BaseTransport, backends: List[Backend]):
3939
self._available_backends = 1
4040
self._transport = transport
4141

42+
# Magic Methods
43+
44+
# If a method in the BaseTransport or AsyncBaseTransport classes is not found, it will be looked up in base _transport object.
45+
def __getattr__(self, name):
46+
return getattr(self._transport, name)
47+
4248
# "Protected" Methods
4349
def _check_throttling(self) -> None:
4450
"""Check if any backend is throttling and reset if necessary."""

tests/lib/test_openai_priority_loadbalancer.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,26 @@ def test_loadbalancer_handle_4xx_failure(self, client_same_priority):
382382
# Assert that the final response status code was 400
383383
assert response.status_code == 400
384384

385+
@pytest.mark.loadbalancer
386+
def test_loadbalancer_loadbalancer_close(self, client_same_priority):
387+
client = client_same_priority
388+
389+
# Create a mock response for the transport
390+
mock_response = httpx.Response(200)
391+
392+
with patch('httpx.Client.send', return_value = mock_response):
393+
req = client._build_request(create_final_request_options())
394+
response = client._client._transport.handle_request(req)
395+
396+
assert response.status_code == 200
397+
398+
# Assert that the transport is not closed, then close it, then assert that it was closed.
399+
assert client._client._transport.is_closed is False
400+
401+
client._client._transport.close()
402+
403+
assert client._client._transport.is_closed is True
404+
385405
# Asynchronous Tests
386406

387407
class TestAsynchronous:
@@ -591,3 +611,24 @@ async def test_async_loadbalancer_handle_4xx_failure(self, async_client_same_pri
591611

592612
# Assert that the final response status code was 400
593613
assert response.status_code == 400
614+
615+
@pytest.mark.asyncio
616+
@pytest.mark.async_loadbalancer
617+
async def test_async_loadbalancer_close(self, async_client_same_priority):
618+
client = async_client_same_priority
619+
620+
# Create a mock response for the transport
621+
mock_response = httpx.Response(200)
622+
623+
with patch('httpx.AsyncClient.send', return_value = mock_response):
624+
req = client._build_request(create_final_request_options())
625+
response = await client._client._transport.handle_async_request(req)
626+
627+
assert response.status_code == 200
628+
629+
# Assert that the transport is not closed, then close it, then assert that it was closed.
630+
assert client._client._transport.is_closed is False
631+
632+
await client._client._transport.aclose()
633+
634+
assert client._client._transport.is_closed is True

0 commit comments

Comments
 (0)