Skip to content

Commit

Permalink
Merge pull request #142 from dwmorris11/mistralaiapi
Browse files Browse the repository at this point in the history
feat: Added support for MistralAI API. This includes a
  • Loading branch information
jamescalam authored Feb 12, 2024
2 parents 43f9808 + b8a899f commit 8d7579f
Show file tree
Hide file tree
Showing 16 changed files with 1,400 additions and 443 deletions.
1,398 changes: 980 additions & 418 deletions coverage.xml

Large diffs are not rendered by default.

83 changes: 79 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ python = "^3.9"
pydantic = "^2.5.3"
openai = "^1.10.0"
cohere = "^4.32"
mistralai= "^0.0.12"
numpy = "^1.25.2"
colorlog = "^6.8.0"
pyyaml = "^6.0.1"
Expand Down
2 changes: 2 additions & 0 deletions semantic_router/encoders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from semantic_router.encoders.cohere import CohereEncoder
from semantic_router.encoders.fastembed import FastEmbedEncoder
from semantic_router.encoders.huggingface import HuggingFaceEncoder
from semantic_router.encoders.mistral import MistralEncoder
from semantic_router.encoders.openai import OpenAIEncoder
from semantic_router.encoders.tfidf import TfidfEncoder
from semantic_router.encoders.zure import AzureOpenAIEncoder
Expand All @@ -16,4 +17,5 @@
"TfidfEncoder",
"FastEmbedEncoder",
"HuggingFaceEncoder",
"MistralEncoder",
]
57 changes: 57 additions & 0 deletions semantic_router/encoders/mistral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""This file contains the MistralEncoder class which is used to encode text using MistralAI"""
import os
from time import sleep
from typing import List, Optional

from mistralai.client import MistralClient
from mistralai.exceptions import MistralException
from mistralai.models.embeddings import EmbeddingResponse

from semantic_router.encoders import BaseEncoder


class MistralEncoder(BaseEncoder):
"""Class to encode text using MistralAI"""

client: Optional[MistralClient]
type: str = "mistral"

def __init__(
self,
name: Optional[str] = None,
mistralai_api_key: Optional[str] = None,
score_threshold: float = 0.82,
):
if name is None:
name = os.getenv("MISTRAL_MODEL_NAME", "mistral-embed")
super().__init__(name=name, score_threshold=score_threshold)
api_key = mistralai_api_key or os.getenv("MISTRALAI_API_KEY")
if api_key is None:
raise ValueError("Mistral API key not provided")
try:
self.client = MistralClient(api_key=api_key)
except Exception as e:
raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e

def __call__(self, docs: List[str]) -> List[List[float]]:
if self.client is None:
raise ValueError("Mistral client not initialized")
embeds = None
error_message = ""

# Exponential backoff
for _ in range(3):
try:
embeds = self.client.embeddings(model=self.name, input=docs)
if embeds.data:
break
except MistralException as e:
sleep(2**_)
error_message = str(e)
except Exception as e:
raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e

if not embeds or not isinstance(embeds, EmbeddingResponse) or not embeds.data:
raise ValueError(f"No embeddings returned from MistralAI: {error_message}")
embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
return embeddings
10 changes: 9 additions & 1 deletion semantic_router/llms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
from semantic_router.llms.base import BaseLLM
from semantic_router.llms.cohere import CohereLLM
from semantic_router.llms.mistral import MistralAILLM
from semantic_router.llms.openai import OpenAILLM
from semantic_router.llms.openrouter import OpenRouterLLM
from semantic_router.llms.zure import AzureOpenAILLM

__all__ = ["BaseLLM", "OpenAILLM", "OpenRouterLLM", "CohereLLM", "AzureOpenAILLM"]
__all__ = [
"BaseLLM",
"OpenAILLM",
"OpenRouterLLM",
"CohereLLM",
"AzureOpenAILLM",
"MistralAILLM",
]
56 changes: 56 additions & 0 deletions semantic_router/llms/mistral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import os
from typing import List, Optional

from mistralai.client import MistralClient

from semantic_router.llms import BaseLLM
from semantic_router.schema import Message
from semantic_router.utils.logger import logger


class MistralAILLM(BaseLLM):
client: Optional[MistralClient]
temperature: Optional[float]
max_tokens: Optional[int]

def __init__(
self,
name: Optional[str] = None,
mistralai_api_key: Optional[str] = None,
temperature: float = 0.01,
max_tokens: int = 200,
):
if name is None:
name = os.getenv("MISTRALAI_CHAT_MODEL_NAME", "mistral-tiny")
super().__init__(name=name)
api_key = mistralai_api_key or os.getenv("MISTRALAI_API_KEY")
if api_key is None:
raise ValueError("MistralAI API key cannot be 'None'.")
try:
self.client = MistralClient(api_key=api_key)
except Exception as e:
raise ValueError(
f"MistralAI API client failed to initialize. Error: {e}"
) from e
self.temperature = temperature
self.max_tokens = max_tokens

def __call__(self, messages: List[Message]) -> str:
if self.client is None:
raise ValueError("MistralAI client is not initialized.")
try:
completion = self.client.chat(
model=self.name,
messages=[m.to_mistral() for m in messages],
temperature=self.temperature,
max_tokens=self.max_tokens,
)

output = completion.choices[0].message.content

if not output:
raise Exception("No output generated")
return output
except Exception as e:
logger.error(f"LLM error: {e}")
raise Exception(f"LLM error: {e}") from e
7 changes: 7 additions & 0 deletions semantic_router/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
BaseEncoder,
CohereEncoder,
FastEmbedEncoder,
MistralEncoder,
OpenAIEncoder,
)

Expand All @@ -17,6 +18,7 @@ class EncoderType(Enum):
FASTEMBED = "fastembed"
OPENAI = "openai"
COHERE = "cohere"
MISTRAL = "mistral"


class RouteChoice(BaseModel):
Expand All @@ -43,6 +45,8 @@ def __init__(self, type: str, name: Optional[str]):
self.model = OpenAIEncoder(name=name)
elif self.type == EncoderType.COHERE:
self.model = CohereEncoder(name=name)
elif self.type == EncoderType.MISTRAL:
self.model = MistralEncoder(name=name)
else:
raise ValueError

Expand All @@ -65,6 +69,9 @@ def to_cohere(self):
def to_llamacpp(self):
return {"role": self.role, "content": self.content}

def to_mistral(self):
return {"role": self.role, "content": self.content}

def __str__(self):
return f"{self.role}: {self.content}"

Expand Down
1 change: 1 addition & 0 deletions semantic_router/splitters/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List

from pydantic.v1 import BaseModel

from semantic_router.encoders import BaseEncoder


Expand Down
6 changes: 4 additions & 2 deletions semantic_router/splitters/consecutive_sim.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import List
from semantic_router.splitters.base import BaseSplitter
from semantic_router.encoders import BaseEncoder

import numpy as np

from semantic_router.encoders import BaseEncoder
from semantic_router.schema import DocumentSplit
from semantic_router.splitters.base import BaseSplitter


class ConsecutiveSimSplitter(BaseSplitter):
Expand Down
6 changes: 4 additions & 2 deletions semantic_router/splitters/cumulative_sim.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import List
from semantic_router.splitters.base import BaseSplitter

import numpy as np
from semantic_router.schema import DocumentSplit

from semantic_router.encoders import BaseEncoder
from semantic_router.schema import DocumentSplit
from semantic_router.splitters.base import BaseSplitter


class CumulativeSimSplitter(BaseSplitter):
Expand Down
11 changes: 5 additions & 6 deletions semantic_router/text.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from colorama import Fore
from colorama import Style
from typing import List, Literal, Tuple, Union

from colorama import Fore, Style
from pydantic.v1 import BaseModel, Field
from typing import Union, List, Literal, Tuple

from semantic_router.encoders import BaseEncoder
from semantic_router.schema import DocumentSplit, Message
from semantic_router.splitters.consecutive_sim import ConsecutiveSimSplitter
from semantic_router.splitters.cumulative_sim import CumulativeSimSplitter
from semantic_router.encoders import BaseEncoder
from semantic_router.schema import Message
from semantic_router.schema import DocumentSplit

# Define a type alias for the splitter to simplify the annotation
SplitterType = Union[ConsecutiveSimSplitter, CumulativeSimSplitter, None]
Expand Down
Loading

0 comments on commit 8d7579f

Please sign in to comment.