Skip to content

Commit

Permalink
Add support for sync/async base transport properties & methods (#38)
Browse files Browse the repository at this point in the history
* Add support for sync/async base transport properties & methods

Fixes #36

* Use Union for Python <3.10 compatibility
  • Loading branch information
simonkurtz-MSFT authored Oct 17, 2024
1 parent 1c85dc0 commit 416bda2
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/openai_priority_loadbalancer/openai_priority_loadbalancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Python Standard Library
import logging
import random
from typing import List
from typing import List, Union
from datetime import datetime, MAXYEAR, MINYEAR, timedelta, timezone

# Third-Party Libraries
Expand All @@ -29,7 +29,7 @@ class BaseLoadBalancer():
"""Logically abstracts the BaseLoadBalancer class which should be inherited by the synchronous and asynchronous load balancer classes."""

# Constructor
def __init__(self, transport: httpx.BaseTransport, backends: List[Backend]):
def __init__(self, transport: Union[httpx.BaseTransport, httpx.AsyncBaseTransport], backends: List[Backend]):
# Public instance variables
self.backends = backends

Expand All @@ -39,6 +39,12 @@ def __init__(self, transport: httpx.BaseTransport, backends: List[Backend]):
self._available_backends = 1
self._transport = transport

# Magic Methods

# If a method in the BaseTransport or AsyncBaseTransport classes is not found, it will be looked up in base _transport object.
def __getattr__(self, name):
return getattr(self._transport, name)

# "Protected" Methods
def _check_throttling(self) -> None:
"""Check if any backend is throttling and reset if necessary."""
Expand Down
41 changes: 41 additions & 0 deletions tests/lib/test_openai_priority_loadbalancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,26 @@ def test_loadbalancer_handle_4xx_failure(self, client_same_priority):
# Assert that the final response status code was 400
assert response.status_code == 400

@pytest.mark.loadbalancer
def test_loadbalancer_loadbalancer_close(self, client_same_priority):
client = client_same_priority

# Create a mock response for the transport
mock_response = httpx.Response(200)

with patch('httpx.Client.send', return_value = mock_response):
req = client._build_request(create_final_request_options())
response = client._client._transport.handle_request(req)

assert response.status_code == 200

# Assert that the transport is not closed, then close it, then assert that it was closed.
assert client._client._transport.is_closed is False

client._client._transport.close()

assert client._client._transport.is_closed is True

# Asynchronous Tests

class TestAsynchronous:
Expand Down Expand Up @@ -591,3 +611,24 @@ async def test_async_loadbalancer_handle_4xx_failure(self, async_client_same_pri

# Assert that the final response status code was 400
assert response.status_code == 400

@pytest.mark.asyncio
@pytest.mark.async_loadbalancer
async def test_async_loadbalancer_close(self, async_client_same_priority):
client = async_client_same_priority

# Create a mock response for the transport
mock_response = httpx.Response(200)

with patch('httpx.AsyncClient.send', return_value = mock_response):
req = client._build_request(create_final_request_options())
response = await client._client._transport.handle_async_request(req)

assert response.status_code == 200

# Assert that the transport is not closed, then close it, then assert that it was closed.
assert client._client._transport.is_closed is False

await client._client._transport.aclose()

assert client._client._transport.is_closed is True

0 comments on commit 416bda2

Please sign in to comment.