From f4b6f02ee005899294b6c0f0fe45331ccf0b4cf6 Mon Sep 17 00:00:00 2001 From: Michael Young <2496187+Miyou@users.noreply.github.com> Date: Wed, 5 Feb 2025 16:09:51 +0100 Subject: [PATCH] feat: much improved RAG, added LLM post-processing of results (#435) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: support new gptme-rag arg --format * fix: only run rag on last user message, not all * feat: make RagConfig class for typed rag config * feat(rag): enable post-processing of context with LLM * feat: use --print-relevance flag on rag_search * fix(rag): use print-relevance when not post-processing * feat(rag): add config params workspace_only and paths * fix: better loading of typed ragconfig * refactor(rag): move get rag context to function * fix: use `--score` instead of old `--print-relevance` --------- Co-authored-by: Erik Bjäreholt --- gptme/config.py | 53 ++++++++++++++++++-- gptme/tools/rag.py | 112 +++++++++++++++++++++++++++++++----------- gptme/util/context.py | 2 +- 3 files changed, 134 insertions(+), 33 deletions(-) diff --git a/gptme/config.py b/gptme/config.py index 069da941..ed4609d3 100644 --- a/gptme/config.py +++ b/gptme/config.py @@ -37,6 +37,48 @@ def dict(self) -> dict: } +default_post_process_prompt = """ +You are an intelligent knowledge retrieval assistant designed to analyze context chunks and extract relevant information based on user queries. Your primary goal is to provide accurate and helpful information while adhering to specific guidelines. + +You will be provided with a user query inside tags and a list of potentially relevant context chunks inside tags. + +When a user submits a query, follow these steps: + +1. Analyze the user's query carefully to identify key concepts and requirements. + +2. Search through the provided context chunks for relevant information. + +3. If you find relevant information: + a. Extract the most pertinent parts. + b. Summarize the relevant context inside tags. + c. Output the exact relevant context chunks, including the complete ... tags. + +4. If you cannot find any relevant information, respond with exactly: "No relevant context found". + +Important guidelines: +- Do not make assumptions beyond the available data. +- Maintain objectivity in source selection. +- When returning context chunks, include the entire content of the tag. Do not modify or truncate it in any way. +- Ensure that you're providing complete information from the chunks, not partial or summarized versions within the tags. +- When no relevant context is found, do not return anything other than exactly "No relevant context found". +- Do not output anything else than the and tags. + +Please provide your response, starting with the summary and followed by the relevant chunks (if any). +""" + + +@dataclass +class RagConfig: + enabled: bool = False + max_tokens: int | None = None + min_relevance: float | None = None + post_process: bool = True + post_process_model: str | None = None + post_process_prompt: str = default_post_process_prompt + workspace_only: bool = True + paths: list[str] = field(default_factory=list) + + @dataclass class ProjectConfig: """Project-level configuration, such as which files to include in the context by default.""" @@ -44,7 +86,7 @@ class ProjectConfig: base_prompt: str | None = None prompt: str | None = None files: list[str] = field(default_factory=list) - rag: dict = field(default_factory=dict) + rag: RagConfig = field(default_factory=RagConfig) ABOUT_ACTIVITYWATCH = """ActivityWatch is a free and open-source automated time-tracker that helps you track how you spend your time on your devices.""" @@ -146,8 +188,13 @@ def get_project_config(workspace: Path | None) -> ProjectConfig | None: ) # load project config with open(project_config_path) as f: - project_config = tomlkit.load(f) - return ProjectConfig(**project_config) # type: ignore + config_data = dict(tomlkit.load(f)) + + # Handle RAG config conversion before creating ProjectConfig + if "rag" in config_data: + config_data["rag"] = RagConfig(**config_data["rag"]) # type: ignore + + return ProjectConfig(**config_data) # type: ignore return None diff --git a/gptme/tools/rag.py b/gptme/tools/rag.py index 1e412e2f..9cd15690 100644 --- a/gptme/tools/rag.py +++ b/gptme/tools/rag.py @@ -15,6 +15,11 @@ [rag] enabled = true + post_process = false # Whether to post-process the context with an LLM to extract the most relevant information + post_process_model = "openai/gpt-4o-mini" # Which model to use for post-processing + post_process_prompt = "" # Optional prompt to use for post-processing (overrides default prompt) + workspace_only = true # Whether to only search in the workspace directory, or the whole RAG index + paths = [] # List of paths to include in the RAG index. Has no effect if workspace_only is true. .. rubric:: Features @@ -36,9 +41,10 @@ from functools import lru_cache from pathlib import Path -from ..config import get_project_config +from ..config import RagConfig, get_project_config from ..message import Message from ..util import get_project_dir +from ..llm import _chat_complete from .base import ToolSpec, ToolUse logger = logging.getLogger(__name__) @@ -106,7 +112,7 @@ def rag_search(query: str, return_full: bool = False) -> str: cmd = ["gptme-rag", "search", query] if return_full: # shows full context of the search results - cmd.append("--show-context") + cmd.extend(["--format", "full", "--score"]) result = _run_rag_cmd(cmd) return result.stdout.strip() @@ -129,7 +135,7 @@ def init() -> ToolSpec: # Check project configuration project_dir = get_project_dir() if project_dir and (config := get_project_config(project_dir)): - enabled = config.rag.get("enabled", False) + enabled = config.rag.enabled if not enabled: logger.debug("RAG not enabled in the project configuration") return replace(tool, available=False) @@ -140,41 +146,89 @@ def init() -> ToolSpec: return tool -def rag_enhance_messages(messages: list[Message]) -> list[Message]: +def get_rag_context( + query: str, + rag_config: RagConfig, + workspace: Path | None = None, +) -> Message: + """Get relevant context chunks from RAG for the user query.""" + + should_post_process = ( + rag_config.post_process and rag_config.post_process_model is not None + ) + + cmd = [ + "gptme-rag", + "search", + query, + ] + if workspace and rag_config.workspace_only: + cmd.append(workspace.as_posix()) + elif rag_config.paths: + cmd.extend(rag_config.paths) + if not should_post_process: + cmd.append("--score") + cmd.extend(["--format", "full"]) + + if rag_config.max_tokens: + cmd.extend(["--max-tokens", str(rag_config.max_tokens)]) + if rag_config.min_relevance: + cmd.extend(["--min-relevance", str(rag_config.min_relevance)]) + rag_result = _run_rag_cmd(cmd).stdout + + # Post-process the context with an LLM (if enabled) + if should_post_process: + post_process_msgs = [ + Message(role="system", content=rag_config.post_process_prompt), + Message(role="system", content=rag_result), + Message( + role="user", + content=f"\n{query}\n", + ), + ] + start = time.monotonic() + rag_result = _chat_complete( + messages=post_process_msgs, + model=rag_config.post_process_model, # type: ignore + tools=[], + ) + logger.info(f"Ran RAG post-process in {time.monotonic() - start:.2f}s") + + # Create the context message + msg = Message( + role="system", + content=f"Relevant context retrieved using `gptme-rag search`:\n\n{rag_result}", + hide=True, + ) + return msg + + +def rag_enhance_messages( + messages: list[Message], workspace: Path | None = None +) -> list[Message]: """Enhance messages with context from RAG.""" if not _has_gptme_rag(): return messages # Load config config = get_project_config(Path.cwd()) - rag_config = config.rag if config and config.rag else {} + rag_config = config.rag if config and config.rag else RagConfig() - if not rag_config.get("enabled", False): + if not rag_config.enabled: return messages - enhanced_messages = [] - for msg in messages: - if msg.role == "user": - try: - # Get context using gptme-rag CLI - cmd = ["gptme-rag", "search", msg.content, "--show-context"] - if max_tokens := rag_config.get("max_tokens"): - cmd.extend(["--max-tokens", str(max_tokens)]) - if min_relevance := rag_config.get("min_relevance"): - cmd.extend(["--min-relevance", str(min_relevance)]) - enhanced_messages.append( - Message( - role="system", - content=f"Relevant context retrieved using `gptme-rag search`:\n\n{_run_rag_cmd(cmd).stdout}", - hide=True, - ) - ) - except Exception as e: - logger.warning(f"Error getting context: {e}") - - enhanced_messages.append(msg) - - return enhanced_messages + last_msg = messages[-1] if messages else None + if last_msg and last_msg.role == "user": + try: + # Get context using gptme-rag CLI + msg = get_rag_context(last_msg.content, rag_config, workspace) + + # Append context message right before the last user message + messages.insert(-1, msg) + except Exception as e: + logger.warning(f"Error getting context: {e}") + + return messages tool = ToolSpec( diff --git a/gptme/util/context.py b/gptme/util/context.py index 614fc83b..8acdf9b2 100644 --- a/gptme/util/context.py +++ b/gptme/util/context.py @@ -351,7 +351,7 @@ def enrich_messages_with_context( msgs = copy(msgs) # First enhance messages with context, if gptme-rag is available - msgs = rag_enhance_messages(msgs) + msgs = rag_enhance_messages(msgs, workspace) msgs = [ append_file_content(msg, workspace, check_modified=use_fresh_context())