Commit 4d47a0c

use interceptor class
ahuang11 committed Oct 16, 2024
1 parent e4c43b0 commit 4d47a0c
Showing 1 changed file with 18 additions and 27 deletions.
45 changes: 18 additions & 27 deletions lumen/ai/llm.py
@@ -24,7 +24,7 @@ class Llm(param.Parameterized):
 
     use_logfire = param.Boolean(default=False)
 
-    interceptor_path = param.String(default=None)
+    interceptor = param.ClassSelector(default=None, class_=OpenAIInterceptor)
 
     # Allows defining a dictionary of default models.
     model_kwargs = param.Dict(default={})

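The switch from param.String to param.ClassSelector moves validation onto the parameter itself: a ClassSelector only accepts instances of the declared class (or None, since the default is None). A minimal sketch of that behavior, using an invented stand-in class rather than the real OpenAIInterceptor:

    import param

    class Widget:  # stand-in for OpenAIInterceptor
        pass

    class Example(param.Parameterized):
        # Accepts Widget instances, or None because the default is None;
        # anything else raises a ValueError on assignment.
        widget = param.ClassSelector(default=None, class_=Widget)

    ex = Example(widget=Widget())  # OK
    ex.widget = None               # OK
    # ex.widget = "widget.db"     # ValueError: not a Widget instance
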
@@ -35,7 +35,6 @@ class Llm(param.Parameterized):
 
     def __init__(self, **params):
         super().__init__(**params)
-        self._interceptor = None
 
     def _get_model_kwargs(self, model_key):
         if model_key in self.model_kwargs:

@@ -229,17 +228,15 @@ def get_client(self, model_key: str, response_model: BaseModel | None = None, **kwargs):
             model_kwargs["organization"] = self.organization
         llm = openai.AsyncOpenAI(**model_kwargs)
 
-        if self.interceptor_path:
-            if self._interceptor is None:
-                self._interceptor = OpenAIInterceptor(db_path=self.interceptor_path)
-            self._interceptor.patch_client(llm, mode="store_inputs")
+        if self.interceptor:
+            self.interceptor.patch_client(llm, mode="store_inputs")
 
         if response_model:
             llm = patch(llm)
 
-        if self.interceptor_path:
+        if self.interceptor:
             # must be called after instructor
-            self._interceptor.patch_client_response(llm)
+            self.interceptor.patch_client_response(llm)
 
         client_callable = partial(llm.chat.completions.create, model=model)
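The two-step patching around instructor's patch() is order-sensitive: patch_client wraps the raw client before instructor replaces create, while patch_client_response (per the inline comment) must run afterwards. A hypothetical sketch of what a client-patching interceptor of this shape might look like; RecordingInterceptor and its storage are assumptions for illustration, not the actual OpenAIInterceptor:

    import functools

    class RecordingInterceptor:
        """Hypothetical stand-in for OpenAIInterceptor."""

        def __init__(self):
            self.stored_inputs = []

        def patch_client(self, client, mode="store_inputs"):
            # Wrap the OpenAI-style create method so the request payload
            # is recorded before the call goes out.
            original_create = client.chat.completions.create

            @functools.wraps(original_create)
            async def wrapped_create(*args, **kwargs):
                if mode == "store_inputs":
                    self.stored_inputs.append(kwargs.get("messages"))
                return await original_create(*args, **kwargs)

            client.chat.completions.create = wrapped_create
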

@@ -278,17 +275,15 @@ def get_client(self, model_key: str, response_model: BaseModel | None = None, **kwargs):
             model_kwargs["azure_endpoint"] = self.azure_endpoint
         llm = openai.AsyncAzureOpenAI(**model_kwargs)
 
-        if self.interceptor_path:
-            if self._interceptor is None:
-                self._interceptor = OpenAIInterceptor(db_path=self.interceptor_path)
-            self._interceptor.patch_client(llm, mode="store_inputs")
+        if self.interceptor:
+            self.interceptor.patch_client(llm, mode="store_inputs")
 
         if response_model:
             llm = patch(llm)
 
-        if self.interceptor_path:
+        if self.interceptor:
             # must be called after instructor
-            self._interceptor.patch_client_response(llm)
+            self.interceptor.patch_client_response(llm)
 
         client_callable = partial(llm.chat.completions.create, model=model)
         return client_callable

@@ -337,16 +332,14 @@ def get_client(self, model_key: str, response_model: BaseModel | None = None, **kwargs):
         llm.chat.completions = SimpleNamespace(create=None)  # make it like OpenAI for simplicity
         llm.chat.completions.create = llm.chat.stream_async if stream else llm.chat.complete_async
 
-        if self.interceptor_path:
-            if self._interceptor is None:
-                self._interceptor = OpenAIInterceptor(db_path=self.interceptor_path)
-            self._interceptor.patch_client(llm, mode="store_inputs")
+        if self.interceptor:
+            self.interceptor.patch_client(llm, mode="store_inputs")
 
         if response_model:
             llm = patch(llm)
 
-        if self.interceptor_path:
-            self._interceptor.patch_client_response(llm)
+        if self.interceptor:
+            self.interceptor.patch_client_response(llm)
 
         client_callable = partial(llm.chat.completions.create, model=model)
         return client_callable

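The SimpleNamespace lines in this hunk adapt a non-OpenAI client to the OpenAI call path, so the rest of get_client stays provider-agnostic. A small sketch of the trick in isolation; DummyClient is invented for illustration:

    from types import SimpleNamespace

    class DummyClient:  # invented stand-in for a non-OpenAI client
        async def complete_async(self, **kwargs):
            return "response"

    llm = DummyClient()
    # Graft on an OpenAI-shaped interface so callers can uniformly use
    # llm.chat.completions.create(...)
    llm.chat = SimpleNamespace(completions=SimpleNamespace(create=llm.complete_async))
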
@@ -410,16 +403,14 @@ async def llm_chat_non_stream_async(*args, **kwargs):
         llm.chat.completions = SimpleNamespace(create=None)  # make it like OpenAI for simplicity
         llm.chat.completions.create = llm.chat.stream_async if stream else llm.chat.complete_async
 
-        if self.interceptor_path:
-            if self._interceptor is None:
-                self._interceptor = OpenAIInterceptor(db_path=self.interceptor_path)
-            self._interceptor.patch_client(llm, mode="store_inputs")
+        if self.interceptor:
+            self.interceptor.patch_client(llm, mode="store_inputs")
 
         if response_model:
             llm = patch(llm)
 
-        if self.interceptor_path:
-            self._interceptor.patch_client_response(llm)
+        if self.interceptor:
+            self.interceptor.patch_client_response(llm)
 
         client_callable = partial(llm.chat.completions.create, model=model)
         return client_callable

@@ -445,7 +436,7 @@ def _client_kwargs(self):
         return {"temperature": self.temperature, "max_tokens": 1024}
 
     def get_client(self, model_key: str, response_model: BaseModel | None = None, **kwargs):
-        if self.interceptor_path:
+        if self.interceptor:
            raise NotImplementedError("Interceptors are not supported for AnthropicAI.")
 
        from anthropic import AsyncAnthropic
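
Taken together, callers now construct the interceptor themselves and pass the instance in, instead of handing the Llm a database path. A usage sketch under assumed names; the OpenAI subclass and the OpenAIInterceptor(db_path=...) signature are inferred from the removed lines, and the interceptor import path is a guess:

    from lumen.ai.llm import OpenAI                     # assumed Llm subclass
    from lumen.ai.interceptor import OpenAIInterceptor  # assumed import path

    # Before this commit: llm = OpenAI(interceptor_path="messages.db")
    # After: build the interceptor yourself and hand over the instance.
    interceptor = OpenAIInterceptor(db_path="messages.db")
    llm = OpenAI(interceptor=interceptor)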
