From ad6a1e8074bd084c1a9a5cb64845f6f7f86b51cd Mon Sep 17 00:00:00 2001 From: James Briggs Date: Thu, 10 Oct 2024 23:26:09 +0200 Subject: [PATCH] fix: cohere test cases --- tests/unit/encoders/test_cohere.py | 16 ++++++++-------- tests/unit/llms/test_llm_cohere.py | 8 ++++---- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/unit/encoders/test_cohere.py b/tests/unit/encoders/test_cohere.py index 0f7607af..b4d81d24 100644 --- a/tests/unit/encoders/test_cohere.py +++ b/tests/unit/encoders/test_cohere.py @@ -11,7 +11,7 @@ def cohere_encoder(mocker): class TestCohereEncoder: def test_initialization_with_api_key(self, cohere_encoder): - assert cohere_encoder.client is not None, "Client should be initialized" + assert cohere_encoder._client is not None, "Client should be initialized" assert ( cohere_encoder.name == "embed-english-v3.0" ), "Default name not set correctly" @@ -25,38 +25,38 @@ def test_initialization_without_api_key(self, mocker, monkeypatch): def test_call_method(self, cohere_encoder, mocker): mock_embed = mocker.MagicMock() mock_embed.embeddings = [[0.1, 0.2, 0.3]] - cohere_encoder.client.embed.return_value = mock_embed + cohere_encoder._client.embed.return_value = mock_embed result = cohere_encoder(["test"]) assert isinstance(result, list), "Result should be a list" assert all( isinstance(sublist, list) for sublist in result ), "Each item in result should be a list" - cohere_encoder.client.embed.assert_called_once() + cohere_encoder._client.embed.assert_called_once() def test_returns_list_of_embeddings_for_valid_input(self, cohere_encoder, mocker): mock_embed = mocker.MagicMock() mock_embed.embeddings = [[0.1, 0.2, 0.3]] - cohere_encoder.client.embed.return_value = mock_embed + cohere_encoder._client.embed.return_value = mock_embed result = cohere_encoder(["test"]) assert isinstance(result, list), "Result should be a list" assert all( isinstance(sublist, list) for sublist in result ), "Each item in result should be a list" - cohere_encoder.client.embed.assert_called_once() + cohere_encoder._client.embed.assert_called_once() def test_handles_multiple_inputs_correctly(self, cohere_encoder, mocker): mock_embed = mocker.MagicMock() mock_embed.embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] - cohere_encoder.client.embed.return_value = mock_embed + cohere_encoder._client.embed.return_value = mock_embed result = cohere_encoder(["test1", "test2"]) assert isinstance(result, list), "Result should be a list" assert all( isinstance(sublist, list) for sublist in result ), "Each item in result should be a list" - cohere_encoder.client.embed.assert_called_once() + cohere_encoder._client.embed.assert_called_once() def test_raises_value_error_if_api_key_is_none(self, mocker, monkeypatch): monkeypatch.delenv("COHERE_API_KEY", raising=False) @@ -79,7 +79,7 @@ def test_raises_value_error_if_cohere_client_is_not_initialized(self, mocker): def test_call_method_raises_error_on_api_failure(self, cohere_encoder, mocker): mocker.patch.object( - cohere_encoder.client, "embed", side_effect=Exception("API call failed") + cohere_encoder._client, "embed", side_effect=Exception("API call failed") ) with pytest.raises(ValueError): cohere_encoder(["test"]) diff --git a/tests/unit/llms/test_llm_cohere.py b/tests/unit/llms/test_llm_cohere.py index aaf8a7e5..dc72931c 100644 --- a/tests/unit/llms/test_llm_cohere.py +++ b/tests/unit/llms/test_llm_cohere.py @@ -12,7 +12,7 @@ def cohere_llm(mocker): class TestCohereLLM: def test_initialization_with_api_key(self, cohere_llm): - assert cohere_llm.client is not None, "Client should be initialized" + assert cohere_llm._client is not None, "Client should be initialized" assert cohere_llm.name == "command", "Default name not set correctly" def test_initialization_without_api_key(self, mocker, monkeypatch): @@ -24,12 +24,12 @@ def test_initialization_without_api_key(self, mocker, monkeypatch): def test_call_method(self, cohere_llm, mocker): mock_llm = mocker.MagicMock() mock_llm.text = "test" - cohere_llm.client.chat.return_value = mock_llm + cohere_llm._client.chat.return_value = mock_llm llm_input = [Message(role="user", content="test")] result = cohere_llm(llm_input) assert isinstance(result, str), "Result should be a str" - cohere_llm.client.chat.assert_called_once() + cohere_llm._client.chat.assert_called_once() def test_raises_value_error_if_cohere_client_fails_to_initialize(self, mocker): mocker.patch( @@ -46,7 +46,7 @@ def test_raises_value_error_if_cohere_client_is_not_initialized(self, mocker): def test_call_method_raises_error_on_api_failure(self, cohere_llm, mocker): mocker.patch.object( - cohere_llm.client, "__call__", side_effect=Exception("API call failed") + cohere_llm._client, "__call__", side_effect=Exception("API call failed") ) with pytest.raises(ValueError): cohere_llm("test")