From 2ef8e9306d741c1ae1ed3e7f85f970dbd86d8ba7 Mon Sep 17 00:00:00 2001
From: NolanTrem <34580718+NolanTrem@users.noreply.github.com>
Date: Fri, 15 Nov 2024 16:52:26 -0800
Subject: [PATCH 1/4] Fix Prompt Override

---
 py/core/main/api/v3/retrieval_router.py |  2 ++
 py/core/providers/database/prompt.py    | 25 +++++++++++--------------
 py/core/providers/llm/litellm.py        |  3 +++
 3 files changed, 16 insertions(+), 14 deletions(-)

diff --git a/py/core/main/api/v3/retrieval_router.py b/py/core/main/api/v3/retrieval_router.py
index 5fe4cac72..9311a32e0 100644
--- a/py/core/main/api/v3/retrieval_router.py
+++ b/py/core/main/api/v3/retrieval_router.py
@@ -424,6 +424,8 @@ async def rag_app(
                 auth_user, vector_search_settings
             )
 
+            print(f"Got a task prompt override: {task_prompt_override}")
+
             response = await self.services["retrieval"].rag(
                 query=query,
                 vector_search_settings=vector_search_settings,
diff --git a/py/core/providers/database/prompt.py b/py/core/providers/database/prompt.py
index 526daacc8..875c9752d 100644
--- a/py/core/providers/database/prompt.py
+++ b/py/core/providers/database/prompt.py
@@ -5,7 +5,7 @@
 from dataclasses import dataclass
 from datetime import datetime, timedelta
 from pathlib import Path
-from typing import Any, Generic, Optional, TypeVar, Union
+from typing import Any, Generic, Optional, TypeVar
 
 import yaml
 
@@ -141,20 +141,17 @@ async def get_cached_prompt(
         bypass_cache: bool = False,
     ) -> str:
         """Get a prompt with caching support"""
-        if prompt_override:
-            return prompt_override
-
-        cache_key = self._cache_key(prompt_name, inputs)
+        template = prompt_override or await self._get_prompt_impl(
+            prompt_name, inputs
+        )
 
-        if not bypass_cache:
-            cached = self._prompt_cache.get(cache_key)
-            if cached is not None:
-                logger.debug(f"Cache hit for prompt: {cache_key}")
-                return cached
+        if inputs:
+            try:
+                return template.format(**inputs)
+            except KeyError as e:
+                raise ValueError(f"Missing required input: {e}")
 
-        result = await self._get_prompt_impl(prompt_name, inputs)
-        self._prompt_cache.set(cache_key, result)
-        return result
+        return template
 
     async def get_prompt(  # type: ignore
         self,
@@ -249,7 +246,7 @@ def __init__(
         )
         self.connection_manager = connection_manager
         self.project_name = project_name
-        self.prompts: dict[str, dict[str, Union[str, dict[str, str]]]] = {}
+        self.prompts: dict[str, dict[str, str | dict[str, str]]] = {}
 
     async def _load_prompts(self) -> None:
         """Load prompts from both database and YAML files."""
diff --git a/py/core/providers/llm/litellm.py b/py/core/providers/llm/litellm.py
index 02747f902..f1f9679cf 100644
--- a/py/core/providers/llm/litellm.py
+++ b/py/core/providers/llm/litellm.py
@@ -54,6 +54,8 @@ async def _execute_task(self, task: dict[str, Any]):
         args["messages"] = messages
         args = {**args, **kwargs}
 
+        print(f"Getting completion for: {args}")
+
         return await self.acompletion(**args)
 
     def _execute_task_sync(self, task: dict[str, Any]):
@@ -66,6 +68,7 @@ def _execute_task_sync(self, task: dict[str, Any]):
         args = {**args, **kwargs}
 
         try:
+            print(f"Getting completion for: {args}")
             return self.completion(**args)
         except Exception as e:
             logger.error(f"Sync LiteLLM task execution failed: {str(e)}")

From 16ec71e37643d5f9806782297577bc2edddb413c Mon Sep 17 00:00:00 2001
From: NolanTrem <34580718+NolanTrem@users.noreply.github.com>
Date: Fri, 15 Nov 2024 16:53:51 -0800
Subject: [PATCH 2/4] print

---
 py/core/main/api/v3/retrieval_router.py | 2 --
 py/core/providers/llm/litellm.py        | 3 ---
 2 files changed, 5 deletions(-)

diff --git a/py/core/main/api/v3/retrieval_router.py b/py/core/main/api/v3/retrieval_router.py
index 9311a32e0..5fe4cac72 100644
--- a/py/core/main/api/v3/retrieval_router.py
+++ b/py/core/main/api/v3/retrieval_router.py
@@ -424,8 +424,6 @@ async def rag_app(
                 auth_user, vector_search_settings
             )
 
-            print(f"Got a task prompt override: {task_prompt_override}")
-
             response = await self.services["retrieval"].rag(
                 query=query,
                 vector_search_settings=vector_search_settings,
diff --git a/py/core/providers/llm/litellm.py b/py/core/providers/llm/litellm.py
index f1f9679cf..02747f902 100644
--- a/py/core/providers/llm/litellm.py
+++ b/py/core/providers/llm/litellm.py
@@ -54,8 +54,6 @@ async def _execute_task(self, task: dict[str, Any]):
         args["messages"] = messages
         args = {**args, **kwargs}
 
-        print(f"Getting completion for: {args}")
-
         return await self.acompletion(**args)
 
     def _execute_task_sync(self, task: dict[str, Any]):
@@ -68,7 +66,6 @@ def _execute_task_sync(self, task: dict[str, Any]):
         args = {**args, **kwargs}
 
         try:
-            print(f"Getting completion for: {args}")
             return self.completion(**args)
         except Exception as e:
             logger.error(f"Sync LiteLLM task execution failed: {str(e)}")

From 04ff6685f1e664175c999043749a97bafa2b4999 Mon Sep 17 00:00:00 2001
From: NolanTrem <34580718+NolanTrem@users.noreply.github.com>
Date: Fri, 15 Nov 2024 16:56:57 -0800
Subject: [PATCH 3/4] Caching

---
 py/core/providers/database/prompt.py | 24 ++++++++++++++++++++----
 1 file changed, 20 insertions(+), 4 deletions(-)

diff --git a/py/core/providers/database/prompt.py b/py/core/providers/database/prompt.py
index 875c9752d..980dd13a3 100644
--- a/py/core/providers/database/prompt.py
+++ b/py/core/providers/database/prompt.py
@@ -141,16 +141,32 @@ async def get_cached_prompt(
         bypass_cache: bool = False,
     ) -> str:
         """Get a prompt with caching support"""
-        template = prompt_override or await self._get_prompt_impl(
-            prompt_name, inputs
-        )
+        if prompt_override:
+            template = prompt_override
+        else:
+            cache_key = self._cache_key(prompt_name, inputs)
+            if not bypass_cache:
+                cached = self._prompt_cache.get(cache_key)
+                if cached is not None:
+                    logger.debug(f"Cache hit for prompt: {cache_key}")
+                    return cached
+
+            template = await self._get_prompt_impl(prompt_name, inputs)
 
         if inputs:
             try:
-                return template.format(**inputs)
+                result = template.format(**inputs)
+                if not prompt_override and not bypass_cache:
+                    # Only cache if not using override and cache isn't bypassed
+                    self._prompt_cache.set(cache_key, result)
+                return result
             except KeyError as e:
                 raise ValueError(f"Missing required input: {e}")
 
+        if not prompt_override and not bypass_cache:
+            # Cache the template itself if no inputs
+            self._prompt_cache.set(prompt_name, template)
+
         return template
 
     async def get_prompt(  # type: ignore

From acb0994ddbc1790cf0a01ba9993ae501fff0e9c6 Mon Sep 17 00:00:00 2001
From: NolanTrem <34580718+NolanTrem@users.noreply.github.com>
Date: Fri, 15 Nov 2024 17:12:50 -0800
Subject: [PATCH 4/4] Fix

---
 py/core/providers/database/prompt.py | 44 ++++++++++++----------------
 1 file changed, 18 insertions(+), 26 deletions(-)

diff --git a/py/core/providers/database/prompt.py b/py/core/providers/database/prompt.py
index 980dd13a3..06138fd38 100644
--- a/py/core/providers/database/prompt.py
+++ b/py/core/providers/database/prompt.py
@@ -142,32 +142,24 @@ async def get_cached_prompt(
     ) -> str:
         """Get a prompt with caching support"""
         if prompt_override:
-            template = prompt_override
-        else:
-            cache_key = self._cache_key(prompt_name, inputs)
-            if not bypass_cache:
-                cached = self._prompt_cache.get(cache_key)
-                if cached is not None:
-                    logger.debug(f"Cache hit for prompt: {cache_key}")
-                    return cached
-
-            template = await self._get_prompt_impl(prompt_name, inputs)
-
-        if inputs:
-            try:
-                result = template.format(**inputs)
-                if not prompt_override and not bypass_cache:
-                    # Only cache if not using override and cache isn't bypassed
-                    self._prompt_cache.set(cache_key, result)
-                return result
-            except KeyError as e:
-                raise ValueError(f"Missing required input: {e}")
-
-        if not prompt_override and not bypass_cache:
-            # Cache the template itself if no inputs
-            self._prompt_cache.set(prompt_name, template)
-
-        return template
+            if inputs:
+                try:
+                    return prompt_override.format(**inputs)
+                except KeyError:
+                    return prompt_override
+            return prompt_override
+
+        cache_key = self._cache_key(prompt_name, inputs)
+
+        if not bypass_cache:
+            cached = self._prompt_cache.get(cache_key)
+            if cached is not None:
+                logger.debug(f"Cache hit for prompt: {cache_key}")
+                return cached
+
+        result = await self._get_prompt_impl(prompt_name, inputs)
+        self._prompt_cache.set(cache_key, result)
+        return result
 
     async def get_prompt(  # type: ignore
         self,
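
The net effect of the series: get_cached_prompt now formats a caller-supplied prompt_override directly, falling back to the raw override string when its placeholders don't match the inputs, and never writes an override to the cache, while stored prompts keep the original fetch/format/cache path. Below is a minimal, self-contained sketch of that final control flow. SimplePromptCache, the in-memory _templates dict, and the stubbed _cache_key and _get_prompt_impl are hypothetical stand-ins for R2R's real cache and database lookup, not the actual implementation.

# A minimal sketch of the final get_cached_prompt control flow.
# SimplePromptCache, _templates, _cache_key, and _get_prompt_impl are
# hypothetical stand-ins, not R2R's actual cache or database lookup.
import asyncio
import json
from typing import Any, Optional


class SimplePromptCache:
    """Tiny dict-backed stand-in for the prompt cache."""

    def __init__(self) -> None:
        self._store: dict[str, str] = {}

    def get(self, key: str) -> Optional[str]:
        return self._store.get(key)

    def set(self, key: str, value: str) -> None:
        self._store[key] = value


class PromptHandlerSketch:
    def __init__(self) -> None:
        self._prompt_cache = SimplePromptCache()
        # Stand-in for templates normally loaded from the database/YAML.
        self._templates = {"rag_prompt": "Answer {query} using {context}."}

    def _cache_key(self, name: str, inputs: Optional[dict[str, Any]]) -> str:
        # The key must cover the inputs, because the *formatted* result
        # (not the bare template) is what gets cached.
        return f"{name}:{json.dumps(inputs, sort_keys=True)}"

    async def _get_prompt_impl(
        self, name: str, inputs: Optional[dict[str, Any]]
    ) -> str:
        template = self._templates[name]
        return template.format(**inputs) if inputs else template

    async def get_cached_prompt(
        self,
        prompt_name: str,
        inputs: Optional[dict[str, Any]] = None,
        prompt_override: Optional[str] = None,
        bypass_cache: bool = False,
    ) -> str:
        # Override path: format it if the inputs fit, fall back to the
        # raw override on a placeholder mismatch, never touch the cache.
        if prompt_override:
            if inputs:
                try:
                    return prompt_override.format(**inputs)
                except KeyError:
                    return prompt_override
            return prompt_override

        # Stored-prompt path: fetch, format, and cache the result.
        cache_key = self._cache_key(prompt_name, inputs)
        if not bypass_cache:
            cached = self._prompt_cache.get(cache_key)
            if cached is not None:
                return cached

        result = await self._get_prompt_impl(prompt_name, inputs)
        self._prompt_cache.set(cache_key, result)
        return result


async def main() -> None:
    handler = PromptHandlerSketch()
    inputs = {"query": "q", "context": "c"}
    # First call misses the cache; the second is served from it.
    print(await handler.get_cached_prompt("rag_prompt", inputs))
    print(await handler.get_cached_prompt("rag_prompt", inputs))
    # An override is honored directly and never cached.
    print(await handler.get_cached_prompt("rag_prompt", prompt_override="Say hi."))


asyncio.run(main())

One design choice worth noting in the final patch: an override whose placeholders don't match the supplied inputs is returned as-is rather than raising, so a caller-supplied prompt is always honored even when it isn't a valid format template.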