diff --git a/lumen/ai/llm.py b/lumen/ai/llm.py
index 0056ecd5..c7d9ad3e 100644
--- a/lumen/ai/llm.py
+++ b/lumen/ai/llm.py
@@ -24,7 +24,7 @@ class Llm(param.Parameterized):
 
     use_logfire = param.Boolean(default=False)
 
-    interceptor_path = param.String(default=None)
+    interceptor = param.ClassSelector(default=None, class_=OpenAIInterceptor)
 
     # Allows defining a dictionary of default models.
     model_kwargs = param.Dict(default={})
@@ -35,7 +35,6 @@ class Llm(param.Parameterized):
 
     def __init__(self, **params):
         super().__init__(**params)
-        self._interceptor = None
 
     def _get_model_kwargs(self, model_key):
         if model_key in self.model_kwargs:
@@ -229,17 +228,15 @@ def get_client(self, model_key: str, response_model: BaseModel | None = None, **
             model_kwargs["organization"] = self.organization
 
         llm = openai.AsyncOpenAI(**model_kwargs)
-        if self.interceptor_path:
-            if self._interceptor is None:
-                self._interceptor = OpenAIInterceptor(db_path=self.interceptor_path)
-            self._interceptor.patch_client(llm, mode="store_inputs")
+        if self.interceptor:
+            self.interceptor.patch_client(llm, mode="store_inputs")
 
         if response_model:
             llm = patch(llm)
 
-        if self.interceptor_path:
+        if self.interceptor:
             # must be called after instructor
-            self._interceptor.patch_client_response(llm)
+            self.interceptor.patch_client_response(llm)
 
         client_callable = partial(llm.chat.completions.create, model=model)
 
@@ -278,17 +275,15 @@ def get_client(self, model_key: str, response_model: BaseModel | None = None, **
             model_kwargs["azure_endpoint"] = self.azure_endpoint
 
         llm = openai.AsyncAzureOpenAI(**model_kwargs)
-        if self.interceptor_path:
-            if self._interceptor is None:
-                self._interceptor = OpenAIInterceptor(db_path=self.interceptor_path)
-            self._interceptor.patch_client(llm, mode="store_inputs")
+        if self.interceptor:
+            self.interceptor.patch_client(llm, mode="store_inputs")
 
         if response_model:
             llm = patch(llm)
 
-        if self.interceptor_path:
+        if self.interceptor:
             # must be called after instructor
-            self._interceptor.patch_client_response(llm)
+            self.interceptor.patch_client_response(llm)
 
         client_callable = partial(llm.chat.completions.create, model=model)
         return client_callable
@@ -337,16 +332,14 @@ def get_client(self, model_key: str, response_model: BaseModel | None = None, **
         llm.chat.completions = SimpleNamespace(create=None)  # make it like OpenAI for simplicity
         llm.chat.completions.create = llm.chat.stream_async if stream else llm.chat.complete_async
 
-        if self.interceptor_path:
-            if self._interceptor is None:
-                self._interceptor = OpenAIInterceptor(db_path=self.interceptor_path)
-            self._interceptor.patch_client(llm, mode="store_inputs")
+        if self.interceptor:
+            self.interceptor.patch_client(llm, mode="store_inputs")
 
         if response_model:
             llm = patch(llm)
 
-        if self.interceptor_path:
-            self._interceptor.patch_client_response(llm)
+        if self.interceptor:
+            self.interceptor.patch_client_response(llm)
 
         client_callable = partial(llm.chat.completions.create, model=model)
         return client_callable
@@ -410,16 +403,14 @@ async def llm_chat_non_stream_async(*args, **kwargs):
         llm.chat.completions = SimpleNamespace(create=None)  # make it like OpenAI for simplicity
         llm.chat.completions.create = llm.chat.stream_async if stream else llm.chat.complete_async
 
-        if self.interceptor_path:
-            if self._interceptor is None:
-                self._interceptor = OpenAIInterceptor(db_path=self.interceptor_path)
-            self._interceptor.patch_client(llm, mode="store_inputs")
+        if self.interceptor:
+            self.interceptor.patch_client(llm, mode="store_inputs")
 
         if response_model:
             llm = patch(llm)
 
-        if self.interceptor_path:
-            self._interceptor.patch_client_response(llm)
+        if self.interceptor:
+            self.interceptor.patch_client_response(llm)
 
         client_callable = partial(llm.chat.completions.create, model=model)
         return client_callable
@@ -445,7 +436,7 @@ def _client_kwargs(self):
         return {"temperature": self.temperature, "max_tokens": 1024}
 
     def get_client(self, model_key: str, response_model: BaseModel | None = None, **kwargs):
-        if self.interceptor_path:
+        if self.interceptor:
            raise NotImplementedError("Interceptors are not supported for AnthropicAI.")
 
         from anthropic import AsyncAnthropic
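
Migration note: a minimal usage sketch of the new parameter. The import path and the OpenAI-backed subclass name are assumptions inferred from this diff, not confirmed by it; the db_path argument is taken from the removed OpenAIInterceptor(db_path=...) construction above.

# Hypothetical caller code illustrating the migration; `lumen.ai.interceptor`
# and the `OpenAI` subclass name are assumed, and `db_path` comes from the
# removed construction in the old code.
from lumen.ai.interceptor import OpenAIInterceptor  # assumed module path
from lumen.ai.llm import OpenAI                     # assumed subclass name

# Before this change: llm = OpenAI(interceptor_path="llm_messages.db")
# After: the caller constructs the interceptor and passes the instance in.
interceptor = OpenAIInterceptor(db_path="llm_messages.db")
llm = OpenAI(interceptor=interceptor)

Passing the instance rather than a path moves construction out of get_client, which also appears to let one interceptor be shared across clients and providers, with param.ClassSelector validating the type at assignment.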