diff --git a/gptme/tools/rag.py b/gptme/tools/rag.py index 4c4a057fe..4008d506d 100644 --- a/gptme/tools/rag.py +++ b/gptme/tools/rag.py @@ -62,9 +62,8 @@ from pathlib import Path from ..config import get_project_config -from ..message import Message from ..util import get_project_dir -from .base import ConfirmFunc, ToolSpec, ToolUse +from .base import ToolSpec, ToolUse logger = logging.getLogger(__name__) @@ -80,59 +79,50 @@ instructions = """ Use RAG to index and search project documentation. - -Commands: -- index [paths...] - Index documents in specified paths -- search - Search indexed documents -- status - Show index status """ examples = f""" User: Index the current directory Assistant: Let me index the current directory with RAG. -{ToolUse("rag", ["index"], "").to_output()} +{ToolUse("ipython", [], "rag_index()").to_output()} System: Indexed 1 paths User: Search for documentation about functions Assistant: I'll search for function-related documentation. -{ToolUse("rag", ["search", "function", "documentation"], "").to_output()} +{ToolUse("ipython", [], 'rag_search("function documentation")').to_output()} System: ### docs/api.md Functions are documented using docstrings... User: Show index status Assistant: I'll check the current status of the RAG index. -{ToolUse("rag", ["status"], "").to_output()} +{ToolUse("ipython", [], "get_status()").to_output()} System: Index contains 42 documents """ -def execute_rag(code: str, args: list[str], confirm: ConfirmFunc) -> Message: - """Execute RAG commands.""" +def rag_index(*paths: str, glob: str | None = None) -> str: + """Index documents in specified paths.""" + assert indexer is not None, "RAG indexer not initialized" + paths = paths or (".",) + kwargs = {"glob_pattern": glob} if glob else {} + for path in paths: + indexer.index_directory(Path(path), **kwargs) + return f"Indexed {len(paths)} paths" + + +def rag_search(query: str) -> str: + """Search indexed documents.""" + assert indexer is not None, "RAG indexer not initialized" + docs, _ = indexer.search(query) + return "\n\n".join( + f"### {doc.metadata['source']}\n{doc.content[:200]}..." for doc in docs + ) + + +def rag_status() -> str: + """Show index status.""" assert indexer is not None, "RAG indexer not initialized" - command = args[0] if args else "help" - - if command == "help": - return Message("system", "Available commands: index, search, status") - elif command == "index": - paths = args[1:] or ["."] - for path in paths: - indexer.index_directory(Path(path)) - return Message("system", f"Indexed {len(paths)} paths") - elif command == "search": - query = " ".join(args[1:]) - docs, _ = indexer.search(query) - return Message( - "system", - "\n\n".join( - f"### {doc.metadata['source']}\n{doc.content[:200]}..." for doc in docs - ), - ) - elif command == "status": - return Message( - "system", f"Index contains {indexer.collection.count()} documents" - ) - else: - return Message("system", f"Unknown command: {command}") + return f"Index contains {indexer.collection.count()} documents" def init() -> ToolSpec: @@ -141,22 +131,19 @@ def init() -> ToolSpec: return tool project_dir = get_project_dir() - if not project_dir: - return tool + index_path = Path("~/.cache/gptme/rag").expanduser() + collection = "default" + if project_dir and (config := get_project_config(project_dir)): + index_path = Path(config.rag.get("index_path", index_path)).expanduser() + collection = config.rag.get("collection", project_dir.name) + + import gptme_rag # fmt: skip - config = get_project_config(project_dir) - if config: - # Initialize RAG with configuration - global indexer - import gptme_rag # fmt: skip - - indexer = gptme_rag.Indexer( - persist_directory=Path( - config.rag.get("index_path", "~/.cache/gptme/rag") - ).expanduser(), - # TODO: use a better default collection name? (e.g. project name) - collection_name=config.rag.get("collection", "gptme_docs"), - ) + global indexer + indexer = gptme_rag.Indexer( + persist_directory=index_path, + collection_name=collection, + ) return tool @@ -165,8 +152,7 @@ def init() -> ToolSpec: desc="RAG (Retrieval-Augmented Generation) for context-aware assistance", instructions=instructions, examples=examples, - block_types=["rag"], - execute=execute_rag, + functions=[rag_index, rag_search, rag_status], available=_HAS_RAG, init=init, ) diff --git a/gptme/util/cli.py b/gptme/util/cli.py index 3bdc52b35..eea61c046 100644 --- a/gptme/util/cli.py +++ b/gptme/util/cli.py @@ -103,14 +103,14 @@ def context(): @context.command("generate") -@click.argument("path", type=click.Path(exists=True)) -def context_generate(_path: str): - """Generate context from a directory.""" - pass - # from ..context import generate_context # fmt: skip - - # ctx = generate_context(path) - # print(ctx) +@click.argument("query") +def context_generate(query: str): + """Retrieve context for a given query.""" + from ..context import RAGContextProvider # fmt: skip + + provider = RAGContextProvider() + ctx = provider.get_context(query) + print(ctx) @main.group() diff --git a/pyproject.toml b/pyproject.toml index 562d9c80d..fe688887b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ include = ["gptme/server/static/**/*", "media/logo.png"] gptme = "gptme.cli:main" gptme-server = "gptme.server.cli:main" gptme-eval = "gptme.eval.main:main" +gptme-util = "gptme.util.cli:main" gptme-nc = "gptme.ncurses:main" [tool.poetry.dependencies] diff --git a/tests/test_tools_rag.py b/tests/test_tools_rag.py index 1972561d3..1d5091616 100644 --- a/tests/test_tools_rag.py +++ b/tests/test_tools_rag.py @@ -1,14 +1,13 @@ """Tests for the RAG tool.""" -from collections.abc import Generator from dataclasses import replace from unittest.mock import patch import pytest -from gptme import Message from gptme.tools.base import ToolSpec from gptme.tools.rag import _HAS_RAG from gptme.tools.rag import init as init_rag +from gptme.tools.rag import rag_index, rag_search, rag_status @pytest.fixture @@ -35,7 +34,6 @@ def test_rag_tool_init(): def test_rag_tool_init_without_gptme_rag(): """Test RAG tool initialization when gptme-rag is not available.""" - tool = init_rag() with ( patch("gptme.tools.rag._HAS_RAG", False), @@ -47,73 +45,72 @@ def test_rag_tool_init_without_gptme_rag(): assert tool.available is False -def _m2str(tool_execute: Generator[Message, None, None] | Message) -> str: - """Convert a execute() call to a string.""" - if isinstance(tool_execute, Generator): - return tool_execute.send(None).content - elif isinstance(tool_execute, Message): - return tool_execute.content - - -def noconfirm(*args, **kwargs): - return True - - @pytest.mark.skipif(not _HAS_RAG, reason="gptme-rag not installed") -def test_rag_index_command(temp_docs, tmp_path): - """Test the index command.""" +def test_rag_index_function(temp_docs, tmp_path): + """Test the index function.""" with patch("gptme.tools.rag.get_project_config") as mock_config: mock_config.return_value.rag = { "index_path": str(tmp_path), "collection": "test", } - tool = init_rag() - assert tool.execute - result = _m2str(tool.execute("", ["index", str(temp_docs)], noconfirm)) - assert "Indexed" in result + # Initialize RAG + init_rag() - # Check status after indexing - result = _m2str(tool.execute("", ["status"], noconfirm)) - assert "Index contains" in result - assert "2" in result # Should have indexed 2 documents + # Test indexing with specific path + result = rag_index(str(temp_docs)) + assert "Indexed 1 paths" in result + + # Test indexing with default path + # FIXME: this is really slow in the gptme directory, + # since it contains a lot of files (which are in gitignore, but not respected) + result = rag_index(glob="**/*.py") + assert "Indexed 1 paths" in result @pytest.mark.skipif(not _HAS_RAG, reason="gptme-rag not installed") -def test_rag_search_command(temp_docs): - """Test the search command.""" - tool = init_rag() - assert tool.execute - # Index first - _m2str(tool.execute("", ["index", str(temp_docs)], noconfirm)) +def test_rag_search_function(temp_docs, tmp_path): + """Test the search function.""" + with patch("gptme.tools.rag.get_project_config") as mock_config: + mock_config.return_value.rag = { + "index_path": str(tmp_path), + "collection": "test", + } + + # Initialize RAG and index documents + init_rag() + rag_index(str(temp_docs)) - # Search for Python - result = _m2str(tool.execute("", ["search", "Python"], noconfirm)) - assert "doc1.md" in result - assert "Python functions" in result + # Search for Python + result = rag_search("Python") + assert "doc1.md" in result + assert "Python functions" in result - # Search for testing - result = _m2str(tool.execute("", ["search", "testing"], noconfirm)) - assert "doc2.md" in result - assert "testing practices" in result + # Search for testing + result = rag_search("testing") + assert "doc2.md" in result + assert "testing practices" in result @pytest.mark.skipif(not _HAS_RAG, reason="gptme-rag not installed") -def test_rag_help_command(): - """Test the help command.""" - tool = init_rag() - assert tool.execute - result = _m2str(tool.execute("", ["help"], noconfirm)) - assert "Available commands" in result - assert "index" in result - assert "search" in result - assert "status" in result +def test_rag_status_function(temp_docs, tmp_path): + """Test the status function.""" + with patch("gptme.tools.rag.get_project_config") as mock_config: + mock_config.return_value.rag = { + "index_path": str(tmp_path), + "collection": "test", + } + # Initialize RAG + init_rag() -@pytest.mark.skipif(not _HAS_RAG, reason="gptme-rag not installed") -def test_rag_invalid_command(): - """Test invalid command handling.""" - tool = init_rag() - assert tool.execute - result = _m2str(tool.execute("", ["invalid"], noconfirm)) - assert "Unknown command" in result + # Check initial status + result = rag_status() + assert "Index contains" in result + assert "0" in result + + # Index documents and check status again + rag_index(str(temp_docs)) + result = rag_status() + assert "Index contains" in result + assert "2" in result # Should have indexed 2 documents diff --git a/tests/test_util_cli.py b/tests/test_util_cli.py index bff9df3d8..461562408 100644 --- a/tests/test_util_cli.py +++ b/tests/test_util_cli.py @@ -1,9 +1,10 @@ """Tests for the gptme-util CLI.""" +import time from pathlib import Path -from click.testing import CliRunner -import pytest +from click.testing import CliRunner +from gptme.logmanager import ConversationMeta from gptme.util.cli import main @@ -64,8 +65,6 @@ def test_chats_list(tmp_path, mocker): ) # Create ConversationMeta objects for our test conversations - from gptme.logmanager import ConversationMeta - import time conv1 = ConversationMeta( name="2024-01-01-chat-one", @@ -96,7 +95,6 @@ def test_chats_list(tmp_path, mocker): assert "Messages: 2" in result.output # Second chat has 2 messages -@pytest.mark.skip("Waiting for context module PR") def test_context_generate(tmp_path): """Test the context generate command.""" # Create a test file