Skip to content

Commit

Permalink
add tests for llms
Browse files Browse the repository at this point in the history
  • Loading branch information
ashraq1455 committed Jan 6, 2024
1 parent 5141474 commit 548bd40
Show file tree
Hide file tree
Showing 7 changed files with 226 additions and 3 deletions.
3 changes: 2 additions & 1 deletion semantic_router/llms/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pydantic import BaseModel
from semantic_router.schema import Message


class BaseLLM(BaseModel):
Expand All @@ -7,5 +8,5 @@ class BaseLLM(BaseModel):
class Config:
arbitrary_types_allowed = True

def __call__(self, prompt) -> str | None:
def __call__(self, messages: list[Message]) -> str | None:
raise NotImplementedError("Subclasses must implement this method")
16 changes: 16 additions & 0 deletions tests/unit/llms/test_llm_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pytest

from semantic_router.llms import BaseLLM


class TestBaseLLM:
@pytest.fixture
def base_llm(self):
return BaseLLM(name="TestLLM")

def test_base_llm_initialization(self, base_llm):
assert base_llm.name == "TestLLM", "Initialization of name failed"

def test_base_llm_call_method_not_implemented(self, base_llm):
with pytest.raises(NotImplementedError):
base_llm("test")
52 changes: 52 additions & 0 deletions tests/unit/llms/test_llm_cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import pytest

from semantic_router.llms import Cohere
from semantic_router.schema import Message


@pytest.fixture
def cohere_llm(mocker):
mocker.patch("cohere.Client")
return Cohere(cohere_api_key="test_api_key")


class TestCohereLLM:
def test_initialization_with_api_key(self, cohere_llm):
assert cohere_llm.client is not None, "Client should be initialized"
assert cohere_llm.name == "command", "Default name not set correctly"

def test_initialization_without_api_key(self, mocker, monkeypatch):
monkeypatch.delenv("COHERE_API_KEY", raising=False)
mocker.patch("cohere.Client")
with pytest.raises(ValueError):
Cohere()

def test_call_method(self, cohere_llm, mocker):
mock_llm = mocker.MagicMock()
mock_llm.text = "test"
cohere_llm.client.chat.return_value = mock_llm

llm_input = [Message(role="user", content="test")]
result = cohere_llm(llm_input)
assert isinstance(result, str), "Result should be a str"
cohere_llm.client.chat.assert_called_once()

def test_raises_value_error_if_cohere_client_fails_to_initialize(self, mocker):
mocker.patch(
"cohere.Client", side_effect=Exception("Failed to initialize client")
)
with pytest.raises(ValueError):
Cohere(cohere_api_key="test_api_key")

def test_raises_value_error_if_cohere_client_is_not_initialized(self, mocker):
mocker.patch("cohere.Client", return_value=None)
llm = Cohere(cohere_api_key="test_api_key")
with pytest.raises(ValueError):
llm("test")

def test_call_method_raises_error_on_api_failure(self, cohere_llm, mocker):
mocker.patch.object(
cohere_llm.client, "__call__", side_effect=Exception("API call failed")
)
with pytest.raises(ValueError):
cohere_llm("test")
55 changes: 55 additions & 0 deletions tests/unit/llms/test_llm_openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import pytest
from semantic_router.llms import OpenAI
from semantic_router.schema import Message


@pytest.fixture
def openai_llm(mocker):
mocker.patch("openai.Client")
return OpenAI(openai_api_key="test_api_key")


class TestOpenAILLM:
def test_openai_llm_init_with_api_key(self, openai_llm):
assert openai_llm.client is not None, "Client should be initialized"
assert openai_llm.name == "gpt-3.5-turbo", "Default name not set correctly"

def test_openai_llm_init_success(self, mocker):
mocker.patch("os.getenv", return_value="fake-api-key")
llm = OpenAI()
assert llm.client is not None

def test_openai_llm_init_without_api_key(self, mocker):
mocker.patch("os.getenv", return_value=None)
with pytest.raises(ValueError) as _:
OpenAI()

def test_openai_llm_call_uninitialized_client(self, openai_llm):
# Set the client to None to simulate an uninitialized client
openai_llm.client = None
with pytest.raises(ValueError) as e:
llm_input = [Message(role="user", content="test")]
openai_llm(llm_input)
assert "OpenAI client is not initialized." in str(e.value)

def test_openai_llm_init_exception(self, mocker):
mocker.patch("os.getenv", return_value="fake-api-key")
mocker.patch("openai.OpenAI", side_effect=Exception("Initialization error"))
with pytest.raises(ValueError) as e:
OpenAI()
assert (
"OpenAI API client failed to initialize. Error: Initialization error"
in str(e.value)
)

def test_openai_llm_call_success(self, openai_llm, mocker):
mock_completion = mocker.MagicMock()
mock_completion.choices[0].message.content = "test"

mocker.patch("os.getenv", return_value="fake-api-key")
mocker.patch.object(
openai_llm.client.chat.completions, "create", return_value=mock_completion
)
llm_input = [Message(role="user", content="test")]
output = openai_llm(llm_input)
assert output == "test"
59 changes: 59 additions & 0 deletions tests/unit/llms/test_llm_openrouter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import pytest
from semantic_router.llms import OpenRouter
from semantic_router.schema import Message


@pytest.fixture
def openrouter_llm(mocker):
mocker.patch("openai.Client")
return OpenRouter(openrouter_api_key="test_api_key")


class TestOpenRouterLLM:
def test_openrouter_llm_init_with_api_key(self, openrouter_llm):
assert openrouter_llm.client is not None, "Client should be initialized"
assert (
openrouter_llm.name == "mistralai/mistral-7b-instruct"
), "Default name not set correctly"

def test_openrouter_llm_init_success(self, mocker):
mocker.patch("os.getenv", return_value="fake-api-key")
llm = OpenRouter()
assert llm.client is not None

def test_openrouter_llm_init_without_api_key(self, mocker):
mocker.patch("os.getenv", return_value=None)
with pytest.raises(ValueError) as _:
OpenRouter()

def test_openrouter_llm_call_uninitialized_client(self, openrouter_llm):
# Set the client to None to simulate an uninitialized client
openrouter_llm.client = None
with pytest.raises(ValueError) as e:
llm_input = [Message(role="user", content="test")]
openrouter_llm(llm_input)
assert "OpenRouter client is not initialized." in str(e.value)

def test_openrouter_llm_init_exception(self, mocker):
mocker.patch("os.getenv", return_value="fake-api-key")
mocker.patch("openai.OpenAI", side_effect=Exception("Initialization error"))
with pytest.raises(ValueError) as e:
OpenRouter()
assert (
"OpenRouter API client failed to initialize. Error: Initialization error"
in str(e.value)
)

def test_openrouter_llm_call_success(self, openrouter_llm, mocker):
mock_completion = mocker.MagicMock()
mock_completion.choices[0].message.content = "test"

mocker.patch("os.getenv", return_value="fake-api-key")
mocker.patch.object(
openrouter_llm.client.chat.completions,
"create",
return_value=mock_completion,
)
llm_input = [Message(role="user", content="test")]
output = openrouter_llm(llm_input)
assert output == "test"
17 changes: 16 additions & 1 deletion tests/unit/test_route.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from unittest.mock import patch # , AsyncMock

# import pytest
import pytest
from semantic_router.llms import BaseLLM
from semantic_router.route import Route, is_valid

Expand Down Expand Up @@ -61,6 +61,21 @@ def __call__(self, prompt):


class TestRoute:
def test_value_error_in_route_call(self):
function_schema = {"name": "test_function", "type": "function"}

route = Route(
name="test_function",
utterances=["utterance1", "utterance2"],
function_schema=function_schema,
)

with pytest.raises(
ValueError,
match="LLM is required for dynamic routes. Please ensure the 'llm' is set.",
):
route("test_query")

def test_generate_dynamic_route(self):
mock_llm = MockLLM(name="test")
function_schema = {"name": "test_function", "type": "function"}
Expand Down
27 changes: 26 additions & 1 deletion tests/unit/test_schema.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import pytest

from pydantic import ValidationError
from semantic_router.schema import (
CohereEncoder,
Encoder,
EncoderType,
Message,
OpenAIEncoder,
)

Expand Down Expand Up @@ -38,3 +39,27 @@ def test_encoder_call_method(self, mocker):
encoder = Encoder(type="openai", name="test-engine")
result = encoder(["test"])
assert result == [0.1, 0.2, 0.3]


class TestMessageDataclass:
def test_message_creation(self):
message = Message(role="user", content="Hello!")
assert message.role == "user"
assert message.content == "Hello!"

with pytest.raises(ValidationError):
Message(user_role="invalid_role", message="Hello!")

def test_message_to_openai(self):
message = Message(role="user", content="Hello!")
openai_format = message.to_openai()
assert openai_format == {"role": "user", "content": "Hello!"}

message = Message(role="invalid_role", content="Hello!")
with pytest.raises(ValueError):
message.to_openai()

def test_message_to_cohere(self):
message = Message(role="user", content="Hello!")
cohere_format = message.to_cohere()
assert cohere_format == {"role": "user", "message": "Hello!"}

0 comments on commit 548bd40

Please sign in to comment.