Add mistral support (#697)
ahuang11 authored Sep 20, 2024
1 parent 584b2e6 commit e84976e
Showing 1 changed file with 130 additions and 5 deletions.
135 changes: 130 additions & 5 deletions lumen/ai/llm.py
@@ -1,5 +1,7 @@
from __future__ import annotations

import os

from functools import partial

import panel as pn
@@ -22,6 +24,8 @@ class Llm(param.Parameterized):
    # Allows defining a dictionary of default models.
    model_kwargs = param.Dict(default={})

    _supports_model_stream = True

    __abstract = True

    def _get_model_kwargs(self, model_key):
@@ -77,6 +81,16 @@ async def stream(
        model_key: str = "default",
        **kwargs,
    ):
        if response_model and not self._supports_model_stream:
            yield await self.invoke(
                messages,
                system=system,
                response_model=response_model,
                model_key=model_key,
                **kwargs,
            )
            return

        string = ""
        chunks = await self.invoke(
            messages,
Expand Down Expand Up @@ -105,7 +119,7 @@ async def stream(
            yield getattr(chunk, field) if field is not None else chunk

    async def run_client(self, model_key, messages, **kwargs):
        client = self.get_client(model_key, kwargs.get("response_model"))
        client = self.get_client(model_key, **kwargs)
        return await client(messages=messages, **kwargs)


@@ -132,7 +146,7 @@ class Llama(Llm):
    def _client_kwargs(self):
        return {"temperature": self.temperature}

    def get_client(self, model_key: str, response_model: BaseModel | None = None):
    def get_client(self, model_key: str, response_model: BaseModel | None = None, **kwargs):
        if client_callable := pn.state.cache.get(model_key):
            return client_callable
        from huggingface_hub import hf_hub_download
@@ -163,7 +177,7 @@ def get_client(self, model_key: str, response_model: BaseModel | None = None):
        return client_callable

    async def run_client(self, model_key, messages, **kwargs):
        client = self.get_client(model_key)
        client = self.get_client(model_key, **kwargs)
        return await client(messages=messages, **kwargs)


@@ -188,7 +202,7 @@ class OpenAI(Llm):
    def _client_kwargs(self):
        return {"temperature": self.temperature}

    def get_client(self, model_key: str, response_model: BaseModel | None = None):
    def get_client(self, model_key: str, response_model: BaseModel | None = None, **kwargs):
        import openai

        model_kwargs = self._get_model_kwargs(model_key)
@@ -228,7 +242,7 @@ class AzureOpenAI(Llm):
    def _client_kwargs(self):
        return {"temperature": self.temperature}

    def get_client(self, model_key: str, response_model: BaseModel | None = None):
    def get_client(self, model_key: str, response_model: BaseModel | None = None, **kwargs):
        import openai

        model_kwargs = self._get_model_kwargs(model_key)
@@ -258,3 +272,114 @@ class AILauncher(OpenAI):
"default": {"model": "gpt-3.5-turbo"},
"reasoning": {"model": "gpt-4-turbo-preview"},
})


class MistralAI(Llm):

    api_key = param.String(default=os.getenv("MISTRAL_API_KEY"))

    mode = param.Selector(default=Mode.MISTRAL_TOOLS, objects=[Mode.JSON_SCHEMA, Mode.MISTRAL_TOOLS])

    temperature = param.Number(default=0.7, bounds=(0, 1), constant=True)

    model_kwargs = param.Dict(default={
        "default": {"model": "mistral-small-latest"},
        "reasoning": {"model": "mistral-large-latest"},
    })

    _supports_model_stream = False  # instructor doesn't work with Mistral's streaming

    @property
    def _client_kwargs(self):
        return {"temperature": self.temperature}

    def get_client(self, model_key: str, response_model: BaseModel | None = None, **kwargs):
        from mistralai import Mistral

        async def llm_chat_non_stream_async(*args, **kwargs):
            response = await llm.chat.complete_async(*args, **kwargs)
            return response.choices[0].message.content

        model_kwargs = self._get_model_kwargs(model_key)
        model = model_kwargs.pop("model")

        llm = Mistral(api_key=self.api_key)
        if response_model:
            # can't use from_mistral due to new mistral API
            # https://github.com/jxnl/instructor/issues/969
            return patch(
                create=partial(llm.chat.complete_async, model=model),
                mode=self.mode,
            )

        stream = kwargs.get("stream", False)
        if stream:
            return partial(llm.chat.stream_async, model=model)
        else:
            return partial(llm_chat_non_stream_async, model=model)

    @classmethod
    def _get_delta(cls, chunk):
        if chunk.data.choices:
            return chunk.data.choices[0].delta.content or ""
        return ""

    async def invoke(
        self,
        messages: list | str,
        system: str = "",
        response_model: BaseModel | None = None,
        allow_partial: bool = False,
        model_key: str = "default",
        **input_kwargs,
    ) -> BaseModel:
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]

        if messages[0]["role"] == "assistant":
            # Mistral cannot start with assistant
            messages = messages[1:]

        return await super().invoke(
            messages,
            system,
            response_model,
            allow_partial,
            model_key,
            **input_kwargs,
        )



class AzureMistralAI(MistralAI):

    api_key = param.String(default=os.getenv("AZURE_API_KEY"))

    azure_endpoint = param.String(default=os.getenv("AZURE_ENDPOINT"))

    model_kwargs = param.Dict(default={
        "default": {"model": "azureai"},
    })

    def get_client(self, model_key: str, response_model: BaseModel | None = None, **kwargs):
        from mistralai_azure import MistralAzure

        async def llm_chat_non_stream_async(*args, **kwargs):
            response = await llm.chat.complete_async(*args, **kwargs)
            return response.choices[0].message.content

        model_kwargs = self._get_model_kwargs(model_key)
        model = model_kwargs.pop("model")

        llm = MistralAzure(azure_api_key=self.api_key, azure_endpoint=self.azure_endpoint)
        if response_model:
            return patch(
                create=partial(llm.chat.complete_async, model=model),
                mode=self.mode,
            )

        stream = kwargs.get("stream", False)
        if stream:
            return partial(llm.chat.stream_async, model=model)
        else:
            return partial(llm_chat_non_stream_async, model=model)
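
Below is a minimal, hypothetical usage sketch of the new class; it is not part of the commit. It assumes MistralAI is importable from lumen.ai.llm, that MISTRAL_API_KEY is set in the environment, and that invoke() and stream() return and yield values the way the base Llm class above suggests; the Answer model and the prompt strings are made up for illustration.

import asyncio

from pydantic import BaseModel

from lumen.ai.llm import MistralAI


class Answer(BaseModel):
    summary: str


async def main():
    # Hypothetical usage sketch; api_key defaults to os.getenv("MISTRAL_API_KEY").
    llm = MistralAI()

    # Plain-text completion against the "default" model (mistral-small-latest);
    # without a response_model, get_client() returns a wrapper around
    # Mistral's chat.complete_async that extracts the message content.
    text = await llm.invoke("What does Lumen do?")
    print(text)

    # Structured output: get_client() patches the async Mistral client with
    # instructor. Because _supports_model_stream is False, stream() falls back
    # to a single invoke() and yields the parsed result once.
    async for result in llm.stream("Summarize Lumen in one sentence.", response_model=Answer):
        print(result)


asyncio.run(main())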
