Skip to content

Commit

Permalink
Merge pull request #448 from aurelio-labs/james/remove-cohere-default
Browse files Browse the repository at this point in the history
feat: move cohere to optional dep
  • Loading branch information
jamescalam authored Oct 10, 2024
2 parents 667eded + ad6a1e8 commit 0874205
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 43 deletions.
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
project = "Semantic Router"
copyright = "2024, Aurelio AI"
author = "Aurelio AI"
release = "0.0.71"
release = "0.0.72"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
Expand Down
27 changes: 14 additions & 13 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "semantic-router"
version = "0.0.71"
version = "0.0.72"
description = "Super fast semantic router for AI decision making"
authors = ["Aurelio AI <hello@aurelio.ai>"]
readme = "README.md"
Expand All @@ -11,7 +11,7 @@ license = "MIT"
python = ">=3.9,<3.13"
pydantic = "^2.5.3"
openai = ">=1.10.0,<2.0.0"
cohere = ">=5.9.4,<6.00"
cohere = {version = ">=5.9.4,<6.00", optional = true}
mistralai= {version = ">=0.0.12,<0.1.0", optional = true}
numpy = "^1.25.2"
colorlog = "^6.8.0"
Expand Down Expand Up @@ -52,6 +52,7 @@ bedrock = ["boto3", "botocore"]
postgres = ["psycopg2"]
fastembed = ["fastembed"]
docs = ["sphinx", "sphinxawesome-theme"]
cohere = ["cohere"]

[tool.poetry.group.dev.dependencies]
ipykernel = "^6.25.0"
Expand Down
2 changes: 1 addition & 1 deletion semantic_router/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@

__all__ = ["RouteLayer", "HybridRouteLayer", "Route", "LayerConfig"]

__version__ = "0.0.71"
__version__ = "0.0.72"
38 changes: 30 additions & 8 deletions semantic_router/encoders/cohere.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import os
from typing import List, Optional
from typing import Any, List, Optional

import cohere
from cohere.types.embed_response import EmbeddingsByTypeEmbedResponse
from pydantic.v1 import PrivateAttr

from semantic_router.encoders import BaseEncoder
from semantic_router.utils.defaults import EncoderDefault


class CohereEncoder(BaseEncoder):
client: Optional[cohere.Client] = None
_client: Any = PrivateAttr()
_embed_type: Any = PrivateAttr()
type: str = "cohere"
input_type: Optional[str] = "search_query"

Expand All @@ -28,25 +28,47 @@ def __init__(
input_type=input_type, # type: ignore
)
self.input_type = input_type
self._client = self._initialize_client(cohere_api_key)

def _initialize_client(self, cohere_api_key: Optional[str] = None):
"""Initializes the Cohere client.
:param cohere_api_key: The API key for the Cohere client, can also
be set via the COHERE_API_KEY environment variable.
:return: An instance of the Cohere client.
"""
try:
import cohere
from cohere.types.embed_response import EmbeddingsByTypeEmbedResponse

self._embed_type = EmbeddingsByTypeEmbedResponse
except ImportError:
raise ImportError(
"Please install Cohere to use CohereEncoder. "
"You can install it with: "
"`pip install 'semantic-router[cohere]'`"
)
cohere_api_key = cohere_api_key or os.getenv("COHERE_API_KEY")
if cohere_api_key is None:
raise ValueError("Cohere API key cannot be 'None'.")
try:
self.client = cohere.Client(cohere_api_key)
client = cohere.Client(cohere_api_key)
except Exception as e:
raise ValueError(
f"Cohere API client failed to initialize. Error: {e}"
) from e
return client

def __call__(self, docs: List[str]) -> List[List[float]]:
if self.client is None:
if self._client is None:
raise ValueError("Cohere client is not initialized.")
try:
embeds = self.client.embed(
embeds = self._client.embed(
texts=docs, input_type=self.input_type, model=self.name
)
# Check for unsupported type.
if isinstance(embeds, EmbeddingsByTypeEmbedResponse):
if isinstance(embeds, self._embed_type):
raise NotImplementedError(
"Handling of EmbedByTypeResponseEmbeddings is not implemented."
)
Expand Down
24 changes: 18 additions & 6 deletions semantic_router/llms/cohere.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import os
from typing import List, Optional
from typing import Any, List, Optional

import cohere
from pydantic.v1 import PrivateAttr

from semantic_router.llms import BaseLLM
from semantic_router.schema import Message


class CohereLLM(BaseLLM):
client: Optional[cohere.Client] = None
_client: Any = PrivateAttr()

def __init__(
self,
Expand All @@ -18,21 +18,33 @@ def __init__(
if name is None:
name = os.getenv("COHERE_CHAT_MODEL_NAME", "command")
super().__init__(name=name)
self._client = self._initialize_client(cohere_api_key)

def _initialize_client(self, cohere_api_key: Optional[str] = None):
try:
import cohere
except ImportError:
raise ImportError(
"Please install Cohere to use CohereLLM. "
"You can install it with: "
"`pip install 'semantic-router[cohere]'`"
)
cohere_api_key = cohere_api_key or os.getenv("COHERE_API_KEY")
if cohere_api_key is None:
raise ValueError("Cohere API key cannot be 'None'.")
try:
self.client = cohere.Client(cohere_api_key)
client = cohere.Client(cohere_api_key)
except Exception as e:
raise ValueError(
f"Cohere API client failed to initialize. Error: {e}"
) from e
return client

def __call__(self, messages: List[Message]) -> str:
if self.client is None:
if self._client is None:
raise ValueError("Cohere client is not initialized.")
try:
completion = self.client.chat(
completion = self._client.chat(
model=self.name,
chat_history=[m.to_cohere() for m in messages[:-1]],
message=messages[-1].content,
Expand Down
16 changes: 8 additions & 8 deletions tests/unit/encoders/test_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def cohere_encoder(mocker):

class TestCohereEncoder:
def test_initialization_with_api_key(self, cohere_encoder):
assert cohere_encoder.client is not None, "Client should be initialized"
assert cohere_encoder._client is not None, "Client should be initialized"
assert (
cohere_encoder.name == "embed-english-v3.0"
), "Default name not set correctly"
Expand All @@ -25,38 +25,38 @@ def test_initialization_without_api_key(self, mocker, monkeypatch):
def test_call_method(self, cohere_encoder, mocker):
mock_embed = mocker.MagicMock()
mock_embed.embeddings = [[0.1, 0.2, 0.3]]
cohere_encoder.client.embed.return_value = mock_embed
cohere_encoder._client.embed.return_value = mock_embed

result = cohere_encoder(["test"])
assert isinstance(result, list), "Result should be a list"
assert all(
isinstance(sublist, list) for sublist in result
), "Each item in result should be a list"
cohere_encoder.client.embed.assert_called_once()
cohere_encoder._client.embed.assert_called_once()

def test_returns_list_of_embeddings_for_valid_input(self, cohere_encoder, mocker):
mock_embed = mocker.MagicMock()
mock_embed.embeddings = [[0.1, 0.2, 0.3]]
cohere_encoder.client.embed.return_value = mock_embed
cohere_encoder._client.embed.return_value = mock_embed

result = cohere_encoder(["test"])
assert isinstance(result, list), "Result should be a list"
assert all(
isinstance(sublist, list) for sublist in result
), "Each item in result should be a list"
cohere_encoder.client.embed.assert_called_once()
cohere_encoder._client.embed.assert_called_once()

def test_handles_multiple_inputs_correctly(self, cohere_encoder, mocker):
mock_embed = mocker.MagicMock()
mock_embed.embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
cohere_encoder.client.embed.return_value = mock_embed
cohere_encoder._client.embed.return_value = mock_embed

result = cohere_encoder(["test1", "test2"])
assert isinstance(result, list), "Result should be a list"
assert all(
isinstance(sublist, list) for sublist in result
), "Each item in result should be a list"
cohere_encoder.client.embed.assert_called_once()
cohere_encoder._client.embed.assert_called_once()

def test_raises_value_error_if_api_key_is_none(self, mocker, monkeypatch):
monkeypatch.delenv("COHERE_API_KEY", raising=False)
Expand All @@ -79,7 +79,7 @@ def test_raises_value_error_if_cohere_client_is_not_initialized(self, mocker):

def test_call_method_raises_error_on_api_failure(self, cohere_encoder, mocker):
mocker.patch.object(
cohere_encoder.client, "embed", side_effect=Exception("API call failed")
cohere_encoder._client, "embed", side_effect=Exception("API call failed")
)
with pytest.raises(ValueError):
cohere_encoder(["test"])
8 changes: 4 additions & 4 deletions tests/unit/llms/test_llm_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def cohere_llm(mocker):

class TestCohereLLM:
def test_initialization_with_api_key(self, cohere_llm):
assert cohere_llm.client is not None, "Client should be initialized"
assert cohere_llm._client is not None, "Client should be initialized"
assert cohere_llm.name == "command", "Default name not set correctly"

def test_initialization_without_api_key(self, mocker, monkeypatch):
Expand All @@ -24,12 +24,12 @@ def test_initialization_without_api_key(self, mocker, monkeypatch):
def test_call_method(self, cohere_llm, mocker):
mock_llm = mocker.MagicMock()
mock_llm.text = "test"
cohere_llm.client.chat.return_value = mock_llm
cohere_llm._client.chat.return_value = mock_llm

llm_input = [Message(role="user", content="test")]
result = cohere_llm(llm_input)
assert isinstance(result, str), "Result should be a str"
cohere_llm.client.chat.assert_called_once()
cohere_llm._client.chat.assert_called_once()

def test_raises_value_error_if_cohere_client_fails_to_initialize(self, mocker):
mocker.patch(
Expand All @@ -46,7 +46,7 @@ def test_raises_value_error_if_cohere_client_is_not_initialized(self, mocker):

def test_call_method_raises_error_on_api_failure(self, cohere_llm, mocker):
mocker.patch.object(
cohere_llm.client, "__call__", side_effect=Exception("API call failed")
cohere_llm._client, "__call__", side_effect=Exception("API call failed")
)
with pytest.raises(ValueError):
cohere_llm("test")

0 comments on commit 0874205

Please sign in to comment.