-
Notifications
You must be signed in to change notification settings - Fork 238
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #142 from dwmorris11/mistralaiapi
feat: Added support for MistralAI API. This includes a
- Loading branch information
Showing
16 changed files
with
1,400 additions
and
443 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
"""This file contains the MistralEncoder class which is used to encode text using MistralAI""" | ||
import os | ||
from time import sleep | ||
from typing import List, Optional | ||
|
||
from mistralai.client import MistralClient | ||
from mistralai.exceptions import MistralException | ||
from mistralai.models.embeddings import EmbeddingResponse | ||
|
||
from semantic_router.encoders import BaseEncoder | ||
|
||
|
||
class MistralEncoder(BaseEncoder):
    """Encoder producing dense embeddings via the MistralAI embeddings API.

    The model name falls back to the ``MISTRAL_MODEL_NAME`` environment
    variable (default ``"mistral-embed"``) and the API key to
    ``MISTRALAI_API_KEY`` when not passed explicitly.
    """

    # Client is Optional so a failed/absent initialization is representable.
    client: Optional[MistralClient]
    type: str = "mistral"

    def __init__(
        self,
        name: Optional[str] = None,
        mistralai_api_key: Optional[str] = None,
        score_threshold: float = 0.82,
    ):
        """Initialize the encoder and the MistralAI client.

        :param name: Embedding model name; defaults to the
            ``MISTRAL_MODEL_NAME`` env var, then ``"mistral-embed"``.
        :param mistralai_api_key: API key; defaults to ``MISTRALAI_API_KEY``.
        :param score_threshold: Similarity threshold forwarded to BaseEncoder.
        :raises ValueError: If no API key is available or the client fails
            to initialize.
        """
        if name is None:
            name = os.getenv("MISTRAL_MODEL_NAME", "mistral-embed")
        super().__init__(name=name, score_threshold=score_threshold)
        api_key = mistralai_api_key or os.getenv("MISTRALAI_API_KEY")
        if api_key is None:
            raise ValueError("Mistral API key not provided")
        try:
            self.client = MistralClient(api_key=api_key)
        except Exception as e:
            raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e

    def __call__(self, docs: List[str]) -> List[List[float]]:
        """Embed a batch of documents.

        :param docs: Texts to embed.
        :return: One embedding vector per document, in input order.
        :raises ValueError: If the client is not initialized or no embeddings
            are returned after all retries.
        """
        if self.client is None:
            raise ValueError("Mistral client not initialized")
        embeds = None
        error_message = ""

        # Retry transient API errors with exponential backoff (1s, 2s).
        max_retries = 3
        for attempt in range(max_retries):
            try:
                embeds = self.client.embeddings(model=self.name, input=docs)
                if embeds.data:
                    break
            except MistralException as e:
                # Fix: record the error before sleeping (it was previously set
                # only after the sleep), and do not sleep after the final
                # attempt since no retry follows.
                error_message = str(e)
                if attempt < max_retries - 1:
                    sleep(2**attempt)
            except Exception as e:
                raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e

        if not embeds or not isinstance(embeds, EmbeddingResponse) or not embeds.data:
            raise ValueError(f"No embeddings returned from MistralAI: {error_message}")
        return [embeds_obj.embedding for embeds_obj in embeds.data]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,15 @@ | ||
# LLM wrapper implementations exposed by the semantic_router.llms subpackage.
from semantic_router.llms.base import BaseLLM
from semantic_router.llms.cohere import CohereLLM
from semantic_router.llms.mistral import MistralAILLM
from semantic_router.llms.openai import OpenAILLM
from semantic_router.llms.openrouter import OpenRouterLLM
from semantic_router.llms.zure import AzureOpenAILLM

# Explicit public API for `from semantic_router.llms import *`.
__all__ = [
    "BaseLLM",
    "OpenAILLM",
    "OpenRouterLLM",
    "CohereLLM",
    "AzureOpenAILLM",
    "MistralAILLM",
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import os | ||
from typing import List, Optional | ||
|
||
from mistralai.client import MistralClient | ||
|
||
from semantic_router.llms import BaseLLM | ||
from semantic_router.schema import Message | ||
from semantic_router.utils.logger import logger | ||
|
||
|
||
class MistralAILLM(BaseLLM):
    """LLM wrapper around the MistralAI chat-completions API.

    The model name falls back to the ``MISTRALAI_CHAT_MODEL_NAME`` environment
    variable (default ``"mistral-tiny"``) and the API key to
    ``MISTRALAI_API_KEY`` when not passed explicitly.
    """

    # Client is Optional so a failed/absent initialization is representable.
    client: Optional[MistralClient]
    temperature: Optional[float]
    max_tokens: Optional[int]

    def __init__(
        self,
        name: Optional[str] = None,
        mistralai_api_key: Optional[str] = None,
        temperature: float = 0.01,
        max_tokens: int = 200,
    ):
        """Initialize the LLM and the MistralAI client.

        :param name: Chat model name; defaults to the
            ``MISTRALAI_CHAT_MODEL_NAME`` env var, then ``"mistral-tiny"``.
        :param mistralai_api_key: API key; defaults to ``MISTRALAI_API_KEY``.
        :param temperature: Sampling temperature passed to the chat endpoint.
        :param max_tokens: Completion-length cap passed to the chat endpoint.
        :raises ValueError: If no API key is available or the client fails
            to initialize.
        """
        if name is None:
            name = os.getenv("MISTRALAI_CHAT_MODEL_NAME", "mistral-tiny")
        super().__init__(name=name)
        api_key = mistralai_api_key or os.getenv("MISTRALAI_API_KEY")
        if api_key is None:
            raise ValueError("MistralAI API key cannot be 'None'.")
        try:
            self.client = MistralClient(api_key=api_key)
        except Exception as e:
            raise ValueError(
                f"MistralAI API client failed to initialize. Error: {e}"
            ) from e
        self.temperature = temperature
        self.max_tokens = max_tokens

    def __call__(self, messages: List[Message]) -> str:
        """Generate a chat completion and return its text content.

        :param messages: Conversation history; each item is converted via
            ``Message.to_mistral()``.
        :return: The first choice's message content.
        :raises ValueError: If the client is not initialized.
        :raises Exception: Wrapping any API failure, or when the API returns
            an empty completion; failures are logged via ``logger``.
        """
        if self.client is None:
            raise ValueError("MistralAI client is not initialized.")
        try:
            completion = self.client.chat(
                model=self.name,
                messages=[m.to_mistral() for m in messages],
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )
            output = completion.choices[0].message.content
        except Exception as e:
            logger.error(f"LLM error: {e}")
            raise Exception(f"LLM error: {e}") from e
        # Fix: the empty-output check previously raised *inside* the try
        # block, so it was caught by this method's own handler and re-wrapped.
        # Check outside the try instead, keeping the same final message.
        if not output:
            logger.error("LLM error: No output generated")
            raise Exception("LLM error: No output generated")
        return output
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
from typing import List | ||
|
||
from pydantic.v1 import BaseModel | ||
|
||
from semantic_router.encoders import BaseEncoder | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.