Skip to content

Commit

Permalink
test_pass_through_request_logging_failure_with_stream
Browse files Browse the repository at this point in the history
  • Loading branch information
ishaan-jaff committed Nov 26, 2024
1 parent 068f1af commit 904ece6
Showing 1 changed file with 100 additions and 1 deletion.
101 changes: 100 additions & 1 deletion tests/pass_through_unit_tests/test_pass_through_unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import sys
from datetime import datetime
from unittest.mock import AsyncMock, Mock, patch, MagicMock
from typing import Optional

sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path

import fastapi
import httpx
import pytest
import litellm
Expand All @@ -21,6 +23,9 @@
PassThroughStreamingHandler,
)

from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
pass_through_request,
)
from fastapi import Request
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
Expand All @@ -33,9 +38,21 @@
@pytest.fixture
def mock_request():
# Create a mock request with headers
class QueryParams:
def __init__(self):
self._dict = {}

class MockRequest:
def __init__(self, headers=None):
def __init__(
self, headers=None, method="POST", request_body: Optional[dict] = None
):
self.headers = headers or {}
self.query_params = QueryParams()
self.method = method
self.request_body = request_body or {}

async def body(self) -> bytes:
return bytes(json.dumps(self.request_body), "utf-8")

return MockRequest

Expand Down Expand Up @@ -163,3 +180,85 @@ def test_init_kwargs_with_tags_in_header(mock_request, mock_user_api_key_dict):
metadata = result["litellm_params"]["metadata"]
print("metadata", metadata)
assert metadata["tags"] == ["tag1", "tag2"]


athropic_request_body = {
"model": "claude-3-5-sonnet-20241022",
"max_tokens": 256,
"messages": [{"role": "user", "content": "Hello, world tell me 2 sentences "}],
"litellm_metadata": {"tags": ["hi", "hello"]},
}


@pytest.mark.asyncio
async def test_pass_through_request_logging_failure(
mock_request, mock_user_api_key_dict
):
"""
Test that pass_through_request still returns a response even if logging raises an Exception
"""

# Mock the logging handler to raise an error
async def mock_logging_failure(*args, **kwargs):
raise Exception("Logging failed!")

# Patch only the logging handler
with patch(
"litellm.proxy.pass_through_endpoints.pass_through_endpoints.PassThroughEndpointLogging.pass_through_async_success_handler",
new=mock_logging_failure,
):
request = mock_request(
headers={}, method="POST", request_body=athropic_request_body
)
response = await pass_through_request(
request=request,
target="https://exampleopenaiendpoint-production.up.railway.app/v1/messages",
custom_headers={},
user_api_key_dict=mock_user_api_key_dict,
)

# Assert response was returned successfully despite logging failure
assert response.status_code == 200
print("response", response)
print(vars(response))


@pytest.mark.asyncio
async def test_pass_through_request_logging_failure_with_stream(
mock_request, mock_user_api_key_dict
):
"""
Test that pass_through_request still returns a response even if logging raises an Exception
"""

# Mock the logging handler to raise an error
async def mock_logging_failure(*args, **kwargs):
raise Exception("Logging failed!")

athropic_request_body["stream"] = True
# Patch only the logging handler
with patch(
"litellm.proxy.pass_through_endpoints.streaming_handler.PassThroughStreamingHandler._route_streaming_logging_to_handler",
new=mock_logging_failure,
):
request = mock_request(
headers={}, method="POST", request_body=athropic_request_body
)
response = await pass_through_request(
request=request,
target="https://exampleopenaiendpoint-production.up.railway.app/v1/messages",
custom_headers={},
user_api_key_dict=mock_user_api_key_dict,
)

# Assert response was returned successfully despite logging failure
assert response.status_code == 200

print(vars(response))
print(dir(response))
body_iterator = response.body_iterator
async for chunk in body_iterator:
assert chunk

print("response", response)
print(vars(response))

0 comments on commit 904ece6

Please sign in to comment.