diff --git a/tests/pass_through_unit_tests/test_pass_through_unit_tests.py b/tests/pass_through_unit_tests/test_pass_through_unit_tests.py index c55bdc7a83c6..287a44e698a9 100644 --- a/tests/pass_through_unit_tests/test_pass_through_unit_tests.py +++ b/tests/pass_through_unit_tests/test_pass_through_unit_tests.py @@ -3,11 +3,13 @@ import sys from datetime import datetime from unittest.mock import AsyncMock, Mock, patch, MagicMock +from typing import Optional sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path +import fastapi import httpx import pytest import litellm @@ -21,6 +23,9 @@ PassThroughStreamingHandler, ) +from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( + pass_through_request, +) from fastapi import Request from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( @@ -33,9 +38,21 @@ @pytest.fixture def mock_request(): # Create a mock request with headers + class QueryParams: + def __init__(self): + self._dict = {} + class MockRequest: - def __init__(self, headers=None): + def __init__( + self, headers=None, method="POST", request_body: Optional[dict] = None + ): self.headers = headers or {} + self.query_params = QueryParams() + self.method = method + self.request_body = request_body or {} + + async def body(self) -> bytes: + return bytes(json.dumps(self.request_body), "utf-8") return MockRequest @@ -163,3 +180,85 @@ def test_init_kwargs_with_tags_in_header(mock_request, mock_user_api_key_dict): metadata = result["litellm_params"]["metadata"] print("metadata", metadata) assert metadata["tags"] == ["tag1", "tag2"] + + +athropic_request_body = { + "model": "claude-3-5-sonnet-20241022", + "max_tokens": 256, + "messages": [{"role": "user", "content": "Hello, world tell me 2 sentences "}], + "litellm_metadata": {"tags": ["hi", "hello"]}, +} + + +@pytest.mark.asyncio +async def test_pass_through_request_logging_failure( + mock_request, mock_user_api_key_dict +): + """ + Test that pass_through_request still returns a response even if logging raises an Exception + """ + + # Mock the logging handler to raise an error + async def mock_logging_failure(*args, **kwargs): + raise Exception("Logging failed!") + + # Patch only the logging handler + with patch( + "litellm.proxy.pass_through_endpoints.pass_through_endpoints.PassThroughEndpointLogging.pass_through_async_success_handler", + new=mock_logging_failure, + ): + request = mock_request( + headers={}, method="POST", request_body=athropic_request_body + ) + response = await pass_through_request( + request=request, + target="https://exampleopenaiendpoint-production.up.railway.app/v1/messages", + custom_headers={}, + user_api_key_dict=mock_user_api_key_dict, + ) + + # Assert response was returned successfully despite logging failure + assert response.status_code == 200 + print("response", response) + print(vars(response)) + + +@pytest.mark.asyncio +async def test_pass_through_request_logging_failure_with_stream( + mock_request, mock_user_api_key_dict +): + """ + Test that pass_through_request still returns a response even if logging raises an Exception + """ + + # Mock the logging handler to raise an error + async def mock_logging_failure(*args, **kwargs): + raise Exception("Logging failed!") + + athropic_request_body["stream"] = True + # Patch only the logging handler + with patch( + "litellm.proxy.pass_through_endpoints.streaming_handler.PassThroughStreamingHandler._route_streaming_logging_to_handler", + new=mock_logging_failure, + ): + request = mock_request( + headers={}, method="POST", request_body=athropic_request_body + ) + response = await pass_through_request( + request=request, + target="https://exampleopenaiendpoint-production.up.railway.app/v1/messages", + custom_headers={}, + user_api_key_dict=mock_user_api_key_dict, + ) + + # Assert response was returned successfully despite logging failure + assert response.status_code == 200 + + print(vars(response)) + print(dir(response)) + body_iterator = response.body_iterator + async for chunk in body_iterator: + assert chunk + + print("response", response) + print(vars(response))