(fix) pass through endpoints - run logging async + use thread pool executor for sync logging callbacks (#6907)

* run pass through logging async

* fix use thread_pool_executor for pass through logging

* test_pass_through_request_logging_failure_with_stream

* fix anthropic pt logging test

* test_pass_through_request_logging_failure
ishaan-jaff authored Nov 26, 2024
1 parent d52aae4 commit 552c0dd
Showing 6 changed files with 201 additions and 33 deletions.
22 changes: 12 additions & 10 deletions litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
@@ -529,16 +529,18 @@ async def pass_through_request( # noqa: PLR0915
response_body: Optional[dict] = get_response_body(response)
passthrough_logging_payload["response_body"] = response_body
end_time = datetime.now()
await pass_through_endpoint_logging.pass_through_async_success_handler(
httpx_response=response,
response_body=response_body,
url_route=str(url),
result="",
start_time=start_time,
end_time=end_time,
logging_obj=logging_obj,
cache_hit=False,
**kwargs,
asyncio.create_task(
pass_through_endpoint_logging.pass_through_async_success_handler(
httpx_response=response,
response_body=response_body,
url_route=str(url),
result="",
start_time=start_time,
end_time=end_time,
logging_obj=logging_obj,
cache_hit=False,
**kwargs,
)
)

return Response(
Expand Down
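The hunk above swaps an awaited call to the success handler for a detached asyncio task, so the proxy returns the upstream response without waiting on logging. A minimal sketch of the fire-and-forget pattern (handler and payload names here are illustrative, not the litellm API):

```python
import asyncio

async def log_success(payload: dict) -> None:
    # stand-in for pass_through_async_success_handler
    await asyncio.sleep(0.5)  # simulate a slow logging backend
    print("logged:", payload)

async def handle_request(payload: dict) -> dict:
    response = {"status": 200, "body": payload}
    # fire-and-forget: schedule logging instead of awaiting it,
    # so logging latency drops off the request path
    asyncio.create_task(log_success(payload))
    return response

async def main() -> None:
    response = await handle_request({"model": "claude-3-5-sonnet-20241022"})
    print("returned before logging finished:", response["status"])
    await asyncio.sleep(1)  # keep the loop alive so the detached task can finish

asyncio.run(main())
```

The trade-off is that a failure inside the detached task no longer propagates to the caller, which is exactly what the new unit tests at the bottom of this commit assert.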
34 changes: 18 additions & 16 deletions litellm/proxy/pass_through_endpoints/streaming_handler.py
@@ -58,15 +58,17 @@ async def chunk_processor(
# After all chunks are processed, handle post-processing
end_time = datetime.now()

await PassThroughStreamingHandler._route_streaming_logging_to_handler(
litellm_logging_obj=litellm_logging_obj,
passthrough_success_handler_obj=passthrough_success_handler_obj,
url_route=url_route,
request_body=request_body or {},
endpoint_type=endpoint_type,
start_time=start_time,
raw_bytes=raw_bytes,
end_time=end_time,
asyncio.create_task(
PassThroughStreamingHandler._route_streaming_logging_to_handler(
litellm_logging_obj=litellm_logging_obj,
passthrough_success_handler_obj=passthrough_success_handler_obj,
url_route=url_route,
request_body=request_body or {},
endpoint_type=endpoint_type,
start_time=start_time,
raw_bytes=raw_bytes,
end_time=end_time,
)
)
except Exception as e:
verbose_proxy_logger.error(f"Error in chunk_processor: {str(e)}")
@@ -108,9 +110,9 @@ async def _route_streaming_logging_to_handler(
all_chunks=all_chunks,
end_time=end_time,
)
standard_logging_response_object = anthropic_passthrough_logging_handler_result[
"result"
]
standard_logging_response_object = (
anthropic_passthrough_logging_handler_result["result"]
)
kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
elif endpoint_type == EndpointType.VERTEX_AI:
vertex_passthrough_logging_handler_result = (
@@ -125,9 +127,9 @@
end_time=end_time,
)
)
standard_logging_response_object = vertex_passthrough_logging_handler_result[
"result"
]
standard_logging_response_object = (
vertex_passthrough_logging_handler_result["result"]
)
kwargs = vertex_passthrough_logging_handler_result["kwargs"]

if standard_logging_response_object is None:
@@ -168,4 +170,4 @@ def _convert_raw_bytes_to_str_lines(raw_bytes: List[bytes]) -> List[str]:
# Split by newlines and filter out empty lines
lines = [line.strip() for line in combined_str.split("\n") if line.strip()]

return lines
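The streaming path gets the same treatment: after the final chunk is forwarded, logging is scheduled as a detached task instead of being awaited. One caveat of a bare asyncio.create_task is that the event loop keeps only a weak reference to the task, and an unretrieved exception surfaces only at garbage collection. A common hardening pattern (not part of this commit) keeps a strong reference and attaches a done callback:

```python
import asyncio

_background_tasks = set()

def fire_and_forget(coro) -> asyncio.Task:
    """Schedule a coroutine without awaiting it, keeping a strong reference."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)  # guard against mid-flight garbage collection
    task.add_done_callback(_on_done)
    return task

def _on_done(task: asyncio.Task) -> None:
    _background_tasks.discard(task)
    if not task.cancelled() and task.exception() is not None:
        # surface the error now, not as a late "Task exception was never retrieved"
        print(f"background logging failed: {task.exception()!r}")
```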
8 changes: 5 additions & 3 deletions litellm/proxy/pass_through_endpoints/success_handler.py
@@ -18,6 +18,7 @@
from litellm.proxy._types import PassThroughEndpointLoggingResultValues
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.types.utils import StandardPassThroughResponseObject
from litellm.utils import executor as thread_pool_executor

from .llm_provider_handlers.anthropic_passthrough_logging_handler import (
AnthropicPassthroughLoggingHandler,
@@ -93,15 +94,16 @@ async def pass_through_async_success_handler(
standard_logging_response_object = StandardPassThroughResponseObject(
response=httpx_response.text
)
threading.Thread(
target=logging_obj.success_handler,
thread_pool_executor.submit(
logging_obj.success_handler,
args=(
standard_logging_response_object,
start_time,
end_time,
cache_hit,
),
).start()
)

await logging_obj.async_success_handler(
result=(
json.dumps(result)
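Spawning a raw threading.Thread per request for the synchronous success_handler creates an unbounded number of short-lived threads; submitting to a shared ThreadPoolExecutor (imported here from litellm.utils) bounds concurrency and reuses workers. Note that concurrent.futures.ThreadPoolExecutor.submit takes the callable's arguments positionally, as submit(fn, *args, **kwargs), unlike threading.Thread, which expects an args= tuple. A minimal sketch with an illustrative callback and an arbitrary pool size:

```python
import time
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=32)  # shared, bounded pool

def success_handler(result: dict, start_time: float, end_time: float, cache_hit: bool) -> None:
    # stand-in for a blocking logging callback
    time.sleep(0.1)
    print("logged:", result, "cache_hit:", cache_hit)

# arguments after the callable are passed through positionally
future = executor.submit(success_handler, {"ok": True}, 0.0, 0.1, False)
future.result()  # demo only; the proxy would not block on the future
```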
3 changes: 2 additions & 1 deletion litellm/proxy/proxy_config.yaml
@@ -21,4 +21,5 @@ router_settings:
redis_password: os.environ/REDIS_PASSWORD

litellm_settings:
callbacks: ["prometheus"]
success_callback: ["langfuse"]
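The config change registers Langfuse as a success callback alongside the existing Prometheus callback. A litellm_settings block maps onto litellm's module-level settings; a minimal sketch, assuming the documented callback API:

```python
import litellm

# equivalent of the litellm_settings addition in proxy_config.yaml
litellm.success_callback = ["langfuse"]  # invoked after each successful call
```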
8 changes: 6 additions & 2 deletions tests/pass_through_tests/test_anthropic_passthrough.py
@@ -141,7 +141,9 @@ async def test_anthropic_basic_completion_with_headers():
), "Start time should be before end time"

# Metadata assertions
assert log_entry["cache_hit"] == "False", "Cache should be off"
assert (
str(log_entry["cache_hit"]).lower() != "true"
), "Cache should be off"
assert log_entry["request_tags"] == [
"test-tag-1",
"test-tag-2",
@@ -251,7 +253,9 @@ async def test_anthropic_streaming_with_headers():
), "Start time should be before end time"

# Metadata assertions
assert log_entry["cache_hit"] == "False", "Cache should be off"
assert (
str(log_entry["cache_hit"]).lower() != "true"
), "Cache should be off"
assert log_entry["request_tags"] == [
"test-tag-stream-1",
"test-tag-stream-2",
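Both header tests loosen the cache assertion: rather than requiring the exact string "False", they only require that the logged value not read as true, which tolerates cache_hit arriving as a boolean, a string, or None depending on how the payload was serialized. A quick illustration of what the new assertion accepts:

```python
# every non-true encoding of cache_hit passes the loosened assertion
for value in (False, "False", "false", None):
    assert str(value).lower() != "true"
```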
159 changes: 158 additions & 1 deletion tests/pass_through_unit_tests/test_pass_through_unit_tests.py
@@ -3,11 +3,13 @@
import sys
from datetime import datetime
from unittest.mock import AsyncMock, Mock, patch, MagicMock
from typing import Optional

sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path

import fastapi
import httpx
import pytest
import litellm
@@ -21,6 +23,9 @@
PassThroughStreamingHandler,
)

from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
pass_through_request,
)
from fastapi import Request
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
@@ -33,9 +38,21 @@
@pytest.fixture
def mock_request():
# Create a mock request with headers
class QueryParams:
def __init__(self):
self._dict = {}

class MockRequest:
def __init__(self, headers=None):
def __init__(
self, headers=None, method="POST", request_body: Optional[dict] = None
):
self.headers = headers or {}
self.query_params = QueryParams()
self.method = method
self.request_body = request_body or {}

async def body(self) -> bytes:
return bytes(json.dumps(self.request_body), "utf-8")

return MockRequest

@@ -163,3 +180,143 @@ def test_init_kwargs_with_tags_in_header(mock_request, mock_user_api_key_dict):
metadata = result["litellm_params"]["metadata"]
print("metadata", metadata)
assert metadata["tags"] == ["tag1", "tag2"]


anthropic_request_body = {
"model": "claude-3-5-sonnet-20241022",
"max_tokens": 256,
"messages": [{"role": "user", "content": "Hello, world tell me 2 sentences "}],
"litellm_metadata": {"tags": ["hi", "hello"]},
}


@pytest.mark.asyncio
async def test_pass_through_request_logging_failure(
mock_request, mock_user_api_key_dict
):
"""
Test that pass_through_request still returns a response even if logging raises an Exception
"""

# Mock the logging handler to raise an error
async def mock_logging_failure(*args, **kwargs):
raise Exception("Logging failed!")

# Create a mock response
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.headers = {"content-type": "application/json"}

# Add mock content
mock_response._content = b'{"mock": "response"}'

async def mock_aread():
return mock_response._content

mock_response.aread = mock_aread

# Patch both the logging handler and the httpx client
with patch(
"litellm.proxy.pass_through_endpoints.pass_through_endpoints.PassThroughEndpointLogging.pass_through_async_success_handler",
new=mock_logging_failure,
), patch(
"httpx.AsyncClient.send",
return_value=mock_response,
), patch(
"httpx.AsyncClient.request",
return_value=mock_response,
):
request = mock_request(
headers={}, method="POST", request_body=anthropic_request_body
)
response = await pass_through_request(
request=request,
target="https://exampleopenaiendpoint-production.up.railway.app/v1/messages",
custom_headers={},
user_api_key_dict=mock_user_api_key_dict,
)

# Assert response was returned successfully despite logging failure
assert response.status_code == 200

# Verify we got the mock response content
if hasattr(response, "body"):
content = response.body
else:
content = await response.aread()

assert content == b'{"mock": "response"}'


@pytest.mark.asyncio
async def test_pass_through_request_logging_failure_with_stream(
mock_request, mock_user_api_key_dict
):
"""
Test that pass_through_request still returns a response even if logging raises an Exception
"""

# Mock the logging handler to raise an error
async def mock_logging_failure(*args, **kwargs):
raise Exception("Logging failed!")

# Create a mock response
mock_response = AsyncMock()
mock_response.status_code = 200

# Add headers property to mock response
mock_response.headers = {
"content-type": "application/json", # Not streaming
}

# Create mock chunks for streaming
mock_chunks = [b'{"chunk": 1}', b'{"chunk": 2}']
mock_response.body_iterator = AsyncMock()
mock_response.body_iterator.__aiter__.return_value = mock_chunks

# Add aread method to mock response
mock_response._content = b'{"mock": "response"}'

async def mock_aread():
return mock_response._content

mock_response.aread = mock_aread

# Patch both the logging handler and the httpx client
with patch(
"litellm.proxy.pass_through_endpoints.streaming_handler.PassThroughStreamingHandler._route_streaming_logging_to_handler",
new=mock_logging_failure,
), patch(
"httpx.AsyncClient.send",
return_value=mock_response,
), patch(
"httpx.AsyncClient.request",
return_value=mock_response,
):
request = mock_request(
headers={}, method="POST", request_body=anthropic_request_body
)
response = await pass_through_request(
request=request,
target="https://exampleopenaiendpoint-production.up.railway.app/v1/messages",
custom_headers={},
user_api_key_dict=mock_user_api_key_dict,
)

# Assert response was returned successfully despite logging failure
assert response.status_code == 200

# For non-streaming responses, we can access the content directly
if hasattr(response, "body"):
content = response.body
else:
# For streaming responses, we need to read the chunks
chunks = []
async for chunk in response.body_iterator:
chunks.append(chunk)
content = b"".join(chunks)

# Verify we got some response content
assert content is not None
if isinstance(content, bytes):
assert len(content) > 0
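Both tests pivot on the same behavior: because logging is now scheduled as a detached task, an exception raised inside it never reaches the request path. A stripped-down sketch of that isolation property (illustrative names, not the proxy's API), using pytest-asyncio as the tests above do:

```python
import asyncio
import pytest

async def failing_logger() -> None:
    raise RuntimeError("Logging failed!")

async def handle_request() -> dict:
    # fire-and-forget: the exception stays inside the detached task
    asyncio.create_task(failing_logger())
    return {"status": 200}

@pytest.mark.asyncio
async def test_response_survives_logging_failure():
    response = await handle_request()
    assert response["status"] == 200  # request path is unaffected
```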
