-
Notifications
You must be signed in to change notification settings - Fork 239
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
5141474
commit 548bd40
Showing
7 changed files
with
226 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import pytest | ||
|
||
from semantic_router.llms import BaseLLM | ||
|
||
|
||
class TestBaseLLM: | ||
@pytest.fixture | ||
def base_llm(self): | ||
return BaseLLM(name="TestLLM") | ||
|
||
def test_base_llm_initialization(self, base_llm): | ||
assert base_llm.name == "TestLLM", "Initialization of name failed" | ||
|
||
def test_base_llm_call_method_not_implemented(self, base_llm): | ||
with pytest.raises(NotImplementedError): | ||
base_llm("test") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import pytest | ||
|
||
from semantic_router.llms import Cohere | ||
from semantic_router.schema import Message | ||
|
||
|
||
@pytest.fixture | ||
def cohere_llm(mocker): | ||
mocker.patch("cohere.Client") | ||
return Cohere(cohere_api_key="test_api_key") | ||
|
||
|
||
class TestCohereLLM: | ||
def test_initialization_with_api_key(self, cohere_llm): | ||
assert cohere_llm.client is not None, "Client should be initialized" | ||
assert cohere_llm.name == "command", "Default name not set correctly" | ||
|
||
def test_initialization_without_api_key(self, mocker, monkeypatch): | ||
monkeypatch.delenv("COHERE_API_KEY", raising=False) | ||
mocker.patch("cohere.Client") | ||
with pytest.raises(ValueError): | ||
Cohere() | ||
|
||
def test_call_method(self, cohere_llm, mocker): | ||
mock_llm = mocker.MagicMock() | ||
mock_llm.text = "test" | ||
cohere_llm.client.chat.return_value = mock_llm | ||
|
||
llm_input = [Message(role="user", content="test")] | ||
result = cohere_llm(llm_input) | ||
assert isinstance(result, str), "Result should be a str" | ||
cohere_llm.client.chat.assert_called_once() | ||
|
||
def test_raises_value_error_if_cohere_client_fails_to_initialize(self, mocker): | ||
mocker.patch( | ||
"cohere.Client", side_effect=Exception("Failed to initialize client") | ||
) | ||
with pytest.raises(ValueError): | ||
Cohere(cohere_api_key="test_api_key") | ||
|
||
def test_raises_value_error_if_cohere_client_is_not_initialized(self, mocker): | ||
mocker.patch("cohere.Client", return_value=None) | ||
llm = Cohere(cohere_api_key="test_api_key") | ||
with pytest.raises(ValueError): | ||
llm("test") | ||
|
||
def test_call_method_raises_error_on_api_failure(self, cohere_llm, mocker): | ||
mocker.patch.object( | ||
cohere_llm.client, "__call__", side_effect=Exception("API call failed") | ||
) | ||
with pytest.raises(ValueError): | ||
cohere_llm("test") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import pytest | ||
from semantic_router.llms import OpenAI | ||
from semantic_router.schema import Message | ||
|
||
|
||
@pytest.fixture | ||
def openai_llm(mocker): | ||
mocker.patch("openai.Client") | ||
return OpenAI(openai_api_key="test_api_key") | ||
|
||
|
||
class TestOpenAILLM: | ||
def test_openai_llm_init_with_api_key(self, openai_llm): | ||
assert openai_llm.client is not None, "Client should be initialized" | ||
assert openai_llm.name == "gpt-3.5-turbo", "Default name not set correctly" | ||
|
||
def test_openai_llm_init_success(self, mocker): | ||
mocker.patch("os.getenv", return_value="fake-api-key") | ||
llm = OpenAI() | ||
assert llm.client is not None | ||
|
||
def test_openai_llm_init_without_api_key(self, mocker): | ||
mocker.patch("os.getenv", return_value=None) | ||
with pytest.raises(ValueError) as _: | ||
OpenAI() | ||
|
||
def test_openai_llm_call_uninitialized_client(self, openai_llm): | ||
# Set the client to None to simulate an uninitialized client | ||
openai_llm.client = None | ||
with pytest.raises(ValueError) as e: | ||
llm_input = [Message(role="user", content="test")] | ||
openai_llm(llm_input) | ||
assert "OpenAI client is not initialized." in str(e.value) | ||
|
||
def test_openai_llm_init_exception(self, mocker): | ||
mocker.patch("os.getenv", return_value="fake-api-key") | ||
mocker.patch("openai.OpenAI", side_effect=Exception("Initialization error")) | ||
with pytest.raises(ValueError) as e: | ||
OpenAI() | ||
assert ( | ||
"OpenAI API client failed to initialize. Error: Initialization error" | ||
in str(e.value) | ||
) | ||
|
||
def test_openai_llm_call_success(self, openai_llm, mocker): | ||
mock_completion = mocker.MagicMock() | ||
mock_completion.choices[0].message.content = "test" | ||
|
||
mocker.patch("os.getenv", return_value="fake-api-key") | ||
mocker.patch.object( | ||
openai_llm.client.chat.completions, "create", return_value=mock_completion | ||
) | ||
llm_input = [Message(role="user", content="test")] | ||
output = openai_llm(llm_input) | ||
assert output == "test" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import pytest | ||
from semantic_router.llms import OpenRouter | ||
from semantic_router.schema import Message | ||
|
||
|
||
@pytest.fixture | ||
def openrouter_llm(mocker): | ||
mocker.patch("openai.Client") | ||
return OpenRouter(openrouter_api_key="test_api_key") | ||
|
||
|
||
class TestOpenRouterLLM: | ||
def test_openrouter_llm_init_with_api_key(self, openrouter_llm): | ||
assert openrouter_llm.client is not None, "Client should be initialized" | ||
assert ( | ||
openrouter_llm.name == "mistralai/mistral-7b-instruct" | ||
), "Default name not set correctly" | ||
|
||
def test_openrouter_llm_init_success(self, mocker): | ||
mocker.patch("os.getenv", return_value="fake-api-key") | ||
llm = OpenRouter() | ||
assert llm.client is not None | ||
|
||
def test_openrouter_llm_init_without_api_key(self, mocker): | ||
mocker.patch("os.getenv", return_value=None) | ||
with pytest.raises(ValueError) as _: | ||
OpenRouter() | ||
|
||
def test_openrouter_llm_call_uninitialized_client(self, openrouter_llm): | ||
# Set the client to None to simulate an uninitialized client | ||
openrouter_llm.client = None | ||
with pytest.raises(ValueError) as e: | ||
llm_input = [Message(role="user", content="test")] | ||
openrouter_llm(llm_input) | ||
assert "OpenRouter client is not initialized." in str(e.value) | ||
|
||
def test_openrouter_llm_init_exception(self, mocker): | ||
mocker.patch("os.getenv", return_value="fake-api-key") | ||
mocker.patch("openai.OpenAI", side_effect=Exception("Initialization error")) | ||
with pytest.raises(ValueError) as e: | ||
OpenRouter() | ||
assert ( | ||
"OpenRouter API client failed to initialize. Error: Initialization error" | ||
in str(e.value) | ||
) | ||
|
||
def test_openrouter_llm_call_success(self, openrouter_llm, mocker): | ||
mock_completion = mocker.MagicMock() | ||
mock_completion.choices[0].message.content = "test" | ||
|
||
mocker.patch("os.getenv", return_value="fake-api-key") | ||
mocker.patch.object( | ||
openrouter_llm.client.chat.completions, | ||
"create", | ||
return_value=mock_completion, | ||
) | ||
llm_input = [Message(role="user", content="test")] | ||
output = openrouter_llm(llm_input) | ||
assert output == "test" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters