Skip to content

Commit

Permalink
Add support for the code-engine context provider
Browse files Browse the repository at this point in the history
  • Loading branch information
waleko committed Jul 23, 2024
1 parent 13080a8 commit fc70677
Show file tree
Hide file tree
Showing 10 changed files with 184 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def build(self, run_overview_manager: RunOverviewManager, *args, **kwargs):

def get_llm_retrieval_tools(self, retrieval_helper):
# Find the code search tool
search_tools = [t for t in self._tools if "search" in t.name]
search_tools = self._tools

@tool
def add_to_context(file_name: str, start_line: int, end_line: int):
Expand Down
4 changes: 4 additions & 0 deletions code_editing/agents/context_providers/code_engine/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from code_editing.agents.context_providers.code_engine.manager import CodeEngineManager
from code_editing.agents.context_providers.code_engine.tools import ASTGetFileFunctionNames, ASTGetFileFunctionCode

__all__ = ["ASTGetFileFunctionNames", "ASTGetFileFunctionCode", "CodeEngineManager"]
68 changes: 68 additions & 0 deletions code_editing/agents/context_providers/code_engine/api_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import Any, Type, Callable

import code_engine_client
from langchain_core.tools import ToolException
from pydantic import BaseModel, ValidationError

from code_editing.agents.context_providers.code_engine import CodeEngineManager
from code_editing.agents.tools.base_tool import CEBaseTool


def hack_pydantic_langchain(T: Type[BaseModel]) -> Type[BaseModel]:
class FixedModel(T):
def dict(self, *args, **kwargs):
return super().dict(by_alias=False, *args, **kwargs)

@classmethod
def schema(cls, *args, **kwargs):
return super().schema(by_alias=False, *args, **kwargs)

return FixedModel


def tool_from_api(
func: Callable[[code_engine_client.ApiClient, Any], Any],
input_type: Type[BaseModel],
tool_name: str,
tool_description: str,
post_process: Callable[[Any], Any] = None,
):
new_input_type = hack_pydantic_langchain(input_type)

class ToolFromApi(CEBaseTool):
name = tool_name
description = tool_description
args_schema = new_input_type
code_engine: CodeEngineManager = None

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.args_schema = new_input_type

if self.dry_run:
return

if self.code_engine is None:
raise ValueError("CodeEngineManager is required")

def _run(self, **kwargs) -> Any:
try:
inp = input_type.parse_obj(kwargs)
res = func(self.code_engine.api_client, inp)
# Post process the result
if post_process:
res = post_process(res)
# Display the result
if res is None:
return "Success"
return str(res)
except ValidationError as e:
raise ToolException(f"Invalid input: {e}")
except Exception as e:
raise ToolException(f"Error in {tool_name}: {e}")

@property
def short_name(self) -> str:
return 'api'

return ToolFromApi
47 changes: 47 additions & 0 deletions code_editing/agents/context_providers/code_engine/manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import socket
import subprocess
import time

import code_engine_client

from code_editing.agents.context_providers.context_provider import ContextProvider


def get_free_port():
sock = socket.socket()
sock.bind(('', 0))
return sock.getsockname()[1]


class CodeEngineManager(ContextProvider):
def __init__(self, binary_path: str, repo_path: str, **kwargs):
self.binary_path = binary_path
self.repo_path = repo_path
self.server_process = None
self.api_client = None
self.start_server()
self.set_working_dir(repo_path)

def start_server(self):
port = get_free_port()
command = [self.binary_path, f"-port={port}"]
self.server_process = subprocess.Popen(command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

server_url = f"http://localhost:{port}"
self.api_client = code_engine_client.ApiClient(configuration=code_engine_client.Configuration(host=server_url))

# Wait for server to start
time.sleep(5)

def set_working_dir(self, working_dir: str):
fs_api = code_engine_client.FileSystemApiApi(self.api_client)
set_working_dir_request = code_engine_client.FileSystemApiSetWorkingDirRequest(working_dir=working_dir)
fs_api.file_system_set_working_dir_post(set_working_dir_request)

def __del__(self):
self.shutdown_server()

def shutdown_server(self):
if self.server_process:
self.server_process.terminate()
self.server_process = None
16 changes: 16 additions & 0 deletions code_editing/agents/context_providers/code_engine/tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import code_engine_client

from code_editing.agents.context_providers.code_engine.api_tool import tool_from_api

ASTGetFileFunctionNames = tool_from_api(
lambda api, inp: code_engine_client.AstApiApi(api).ast_get_file_functions_names_post(inp),
code_engine_client.AstApiGetFileFunctionsNamesRequest,
"get-file-functions",
"Get functions from a file",
)
ASTGetFileFunctionCode = tool_from_api(
lambda api, inp: code_engine_client.AstApiApi(api).ast_get_file_function_code_post(inp),
code_engine_client.AstApiGetFileFunctionCodeRequest,
"get-file-functions-code",
"Get code of functions from a file",
)
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ class BM25RetrievalConfig(RetrievalConfig):
_target_: str = f"{CE_CLASSES_ROOT_PKG}.agents.context_providers.retrieval.BM25Retrieval"


@dataclass
class CodeEngineConfig(ContextConfig):
_target_: str = f"{CE_CLASSES_ROOT_PKG}.agents.context_providers.code_engine.CodeEngineManager"
binary_path: str = MISSING


@dataclass
class AiderRepoMapConfig(ContextConfig):
_target_: str = f"{CE_CLASSES_ROOT_PKG}.agents.context_providers.aider.AiderRepoMap"
Expand All @@ -47,4 +53,5 @@ class AiderRepoMapConfig(ContextConfig):
cs.store(name="acr_search", group="context", node=ACRSearchManagerConfig)
cs.store(name="faiss", group="context", node=FaissRetrievalConfig)
cs.store(name="bm25", group="context", node=BM25RetrievalConfig)
cs.store(name="code_engine", group="context", node=CodeEngineConfig)
cs.store(name="aider", group="context", node=AiderRepoMapConfig)
14 changes: 14 additions & 0 deletions code_editing/configs/agents/tools_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,16 @@ class ACRSearchCodeInFileConfig(ToolConfig):
_target_: str = f"{CE_CLASSES_ROOT_PKG}.agents.tools.ACRSearchCodeInFile"


@dataclass
class ASTGetFileFunctionNamesConfig(ToolConfig):
_target_: str = f"{CE_CLASSES_ROOT_PKG}.agents.context_providers.code_engine.ASTGetFileFunctionNames"


@dataclass
class ASTGetFileFunctionCodeConfig(ToolConfig):
_target_: str = f"{CE_CLASSES_ROOT_PKG}.agents.context_providers.code_engine.ASTGetFileFunctionCode"


@dataclass
class ACRShowDefinitionConfig(ToolConfig):
_target_: str = f"{CE_CLASSES_ROOT_PKG}.agents.tools.ACRShowDefinition"
Expand All @@ -87,3 +97,7 @@ class AiderRepoMapConfig(ToolConfig):
cs.store(name="acr_show_definition", group="tools", node=ACRShowDefinitionConfig)

cs.store(name="repo_map", group="tools", node=AiderRepoMapConfig)

# All Code Engine tool options
cs.store(name="ce_ast_get_file_function_names", group="tools", node=ASTGetFileFunctionNamesConfig)
cs.store(name="ce_ast_get_file_function_code", group="tools", node=ASTGetFileFunctionCodeConfig)
3 changes: 3 additions & 0 deletions code_editing/scripts/conf/tools/code_engine.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
defaults:
- /tools@ce_ast_get_file_function_names: ce_ast_get_file_function_names
- /tools@ce_ast_get_file_function_code: ce_ast_get_file_function_code
21 changes: 21 additions & 0 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ scikit-learn = "^1.5.0"
scipy = "^1.13.1"


[tool.poetry.group.code-engine.dependencies]
code-engine-client = {git = "https://github.com/waleko/code-engine-client.git"}

[tool.poetry.group.aider.dependencies]
grep-ast = "0.3.2"
tree-sitter = "0.21.3"
Expand Down

0 comments on commit fc70677

Please sign in to comment.