-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for the code-engine context provider
- Loading branch information
Showing
10 changed files
with
184 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
4 changes: 4 additions & 0 deletions
4
code_editing/agents/context_providers/code_engine/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from code_editing.agents.context_providers.code_engine.manager import CodeEngineManager | ||
from code_editing.agents.context_providers.code_engine.tools import ASTGetFileFunctionNames, ASTGetFileFunctionCode | ||
|
||
__all__ = ["ASTGetFileFunctionNames", "ASTGetFileFunctionCode", "CodeEngineManager"] |
68 changes: 68 additions & 0 deletions
68
code_editing/agents/context_providers/code_engine/api_tool.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from typing import Any, Type, Callable | ||
|
||
import code_engine_client | ||
from langchain_core.tools import ToolException | ||
from pydantic import BaseModel, ValidationError | ||
|
||
from code_editing.agents.context_providers.code_engine import CodeEngineManager | ||
from code_editing.agents.tools.base_tool import CEBaseTool | ||
|
||
|
||
def hack_pydantic_langchain(T: Type[BaseModel]) -> Type[BaseModel]: | ||
class FixedModel(T): | ||
def dict(self, *args, **kwargs): | ||
return super().dict(by_alias=False, *args, **kwargs) | ||
|
||
@classmethod | ||
def schema(cls, *args, **kwargs): | ||
return super().schema(by_alias=False, *args, **kwargs) | ||
|
||
return FixedModel | ||
|
||
|
||
def tool_from_api( | ||
func: Callable[[code_engine_client.ApiClient, Any], Any], | ||
input_type: Type[BaseModel], | ||
tool_name: str, | ||
tool_description: str, | ||
post_process: Callable[[Any], Any] = None, | ||
): | ||
new_input_type = hack_pydantic_langchain(input_type) | ||
|
||
class ToolFromApi(CEBaseTool): | ||
name = tool_name | ||
description = tool_description | ||
args_schema = new_input_type | ||
code_engine: CodeEngineManager = None | ||
|
||
def __init__(self, **kwargs): | ||
super().__init__(**kwargs) | ||
self.args_schema = new_input_type | ||
|
||
if self.dry_run: | ||
return | ||
|
||
if self.code_engine is None: | ||
raise ValueError("CodeEngineManager is required") | ||
|
||
def _run(self, **kwargs) -> Any: | ||
try: | ||
inp = input_type.parse_obj(kwargs) | ||
res = func(self.code_engine.api_client, inp) | ||
# Post process the result | ||
if post_process: | ||
res = post_process(res) | ||
# Display the result | ||
if res is None: | ||
return "Success" | ||
return str(res) | ||
except ValidationError as e: | ||
raise ToolException(f"Invalid input: {e}") | ||
except Exception as e: | ||
raise ToolException(f"Error in {tool_name}: {e}") | ||
|
||
@property | ||
def short_name(self) -> str: | ||
return 'api' | ||
|
||
return ToolFromApi |
47 changes: 47 additions & 0 deletions
47
code_editing/agents/context_providers/code_engine/manager.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import socket | ||
import subprocess | ||
import time | ||
|
||
import code_engine_client | ||
|
||
from code_editing.agents.context_providers.context_provider import ContextProvider | ||
|
||
|
||
def get_free_port(): | ||
sock = socket.socket() | ||
sock.bind(('', 0)) | ||
return sock.getsockname()[1] | ||
|
||
|
||
class CodeEngineManager(ContextProvider): | ||
def __init__(self, binary_path: str, repo_path: str, **kwargs): | ||
self.binary_path = binary_path | ||
self.repo_path = repo_path | ||
self.server_process = None | ||
self.api_client = None | ||
self.start_server() | ||
self.set_working_dir(repo_path) | ||
|
||
def start_server(self): | ||
port = get_free_port() | ||
command = [self.binary_path, f"-port={port}"] | ||
self.server_process = subprocess.Popen(command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) | ||
|
||
server_url = f"http://localhost:{port}" | ||
self.api_client = code_engine_client.ApiClient(configuration=code_engine_client.Configuration(host=server_url)) | ||
|
||
# Wait for server to start | ||
time.sleep(5) | ||
|
||
def set_working_dir(self, working_dir: str): | ||
fs_api = code_engine_client.FileSystemApiApi(self.api_client) | ||
set_working_dir_request = code_engine_client.FileSystemApiSetWorkingDirRequest(working_dir=working_dir) | ||
fs_api.file_system_set_working_dir_post(set_working_dir_request) | ||
|
||
def __del__(self): | ||
self.shutdown_server() | ||
|
||
def shutdown_server(self): | ||
if self.server_process: | ||
self.server_process.terminate() | ||
self.server_process = None |
16 changes: 16 additions & 0 deletions
16
code_editing/agents/context_providers/code_engine/tools.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import code_engine_client | ||
|
||
from code_editing.agents.context_providers.code_engine.api_tool import tool_from_api | ||
|
||
ASTGetFileFunctionNames = tool_from_api( | ||
lambda api, inp: code_engine_client.AstApiApi(api).ast_get_file_functions_names_post(inp), | ||
code_engine_client.AstApiGetFileFunctionsNamesRequest, | ||
"get-file-functions", | ||
"Get functions from a file", | ||
) | ||
ASTGetFileFunctionCode = tool_from_api( | ||
lambda api, inp: code_engine_client.AstApiApi(api).ast_get_file_function_code_post(inp), | ||
code_engine_client.AstApiGetFileFunctionCodeRequest, | ||
"get-file-functions-code", | ||
"Get code of functions from a file", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
defaults: | ||
- /tools@ce_ast_get_file_function_names: ce_ast_get_file_function_names | ||
- /tools@ce_ast_get_file_function_code: ce_ast_get_file_function_code |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters