Skip to content

Commit

Permalink
fix: cohere test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
jamescalam committed Oct 10, 2024
1 parent d307fd2 commit ad6a1e8
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
16 changes: 8 additions & 8 deletions tests/unit/encoders/test_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def cohere_encoder(mocker):

class TestCohereEncoder:
def test_initialization_with_api_key(self, cohere_encoder):
assert cohere_encoder.client is not None, "Client should be initialized"
assert cohere_encoder._client is not None, "Client should be initialized"
assert (
cohere_encoder.name == "embed-english-v3.0"
), "Default name not set correctly"
Expand All @@ -25,38 +25,38 @@ def test_initialization_without_api_key(self, mocker, monkeypatch):
def test_call_method(self, cohere_encoder, mocker):
mock_embed = mocker.MagicMock()
mock_embed.embeddings = [[0.1, 0.2, 0.3]]
cohere_encoder.client.embed.return_value = mock_embed
cohere_encoder._client.embed.return_value = mock_embed

result = cohere_encoder(["test"])
assert isinstance(result, list), "Result should be a list"
assert all(
isinstance(sublist, list) for sublist in result
), "Each item in result should be a list"
cohere_encoder.client.embed.assert_called_once()
cohere_encoder._client.embed.assert_called_once()

def test_returns_list_of_embeddings_for_valid_input(self, cohere_encoder, mocker):
mock_embed = mocker.MagicMock()
mock_embed.embeddings = [[0.1, 0.2, 0.3]]
cohere_encoder.client.embed.return_value = mock_embed
cohere_encoder._client.embed.return_value = mock_embed

result = cohere_encoder(["test"])
assert isinstance(result, list), "Result should be a list"
assert all(
isinstance(sublist, list) for sublist in result
), "Each item in result should be a list"
cohere_encoder.client.embed.assert_called_once()
cohere_encoder._client.embed.assert_called_once()

def test_handles_multiple_inputs_correctly(self, cohere_encoder, mocker):
mock_embed = mocker.MagicMock()
mock_embed.embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
cohere_encoder.client.embed.return_value = mock_embed
cohere_encoder._client.embed.return_value = mock_embed

result = cohere_encoder(["test1", "test2"])
assert isinstance(result, list), "Result should be a list"
assert all(
isinstance(sublist, list) for sublist in result
), "Each item in result should be a list"
cohere_encoder.client.embed.assert_called_once()
cohere_encoder._client.embed.assert_called_once()

def test_raises_value_error_if_api_key_is_none(self, mocker, monkeypatch):
monkeypatch.delenv("COHERE_API_KEY", raising=False)
Expand All @@ -79,7 +79,7 @@ def test_raises_value_error_if_cohere_client_is_not_initialized(self, mocker):

def test_call_method_raises_error_on_api_failure(self, cohere_encoder, mocker):
mocker.patch.object(
cohere_encoder.client, "embed", side_effect=Exception("API call failed")
cohere_encoder._client, "embed", side_effect=Exception("API call failed")
)
with pytest.raises(ValueError):
cohere_encoder(["test"])
8 changes: 4 additions & 4 deletions tests/unit/llms/test_llm_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def cohere_llm(mocker):

class TestCohereLLM:
def test_initialization_with_api_key(self, cohere_llm):
assert cohere_llm.client is not None, "Client should be initialized"
assert cohere_llm._client is not None, "Client should be initialized"
assert cohere_llm.name == "command", "Default name not set correctly"

def test_initialization_without_api_key(self, mocker, monkeypatch):
Expand All @@ -24,12 +24,12 @@ def test_initialization_without_api_key(self, mocker, monkeypatch):
def test_call_method(self, cohere_llm, mocker):
mock_llm = mocker.MagicMock()
mock_llm.text = "test"
cohere_llm.client.chat.return_value = mock_llm
cohere_llm._client.chat.return_value = mock_llm

llm_input = [Message(role="user", content="test")]
result = cohere_llm(llm_input)
assert isinstance(result, str), "Result should be a str"
cohere_llm.client.chat.assert_called_once()
cohere_llm._client.chat.assert_called_once()

def test_raises_value_error_if_cohere_client_fails_to_initialize(self, mocker):
mocker.patch(
Expand All @@ -46,7 +46,7 @@ def test_raises_value_error_if_cohere_client_is_not_initialized(self, mocker):

def test_call_method_raises_error_on_api_failure(self, cohere_llm, mocker):
mocker.patch.object(
cohere_llm.client, "__call__", side_effect=Exception("API call failed")
cohere_llm._client, "__call__", side_effect=Exception("API call failed")
)
with pytest.raises(ValueError):
cohere_llm("test")

0 comments on commit ad6a1e8

Please sign in to comment.