Skip to content

Commit

Permalink
Add aider context retrieval + tool tracking rework
Browse files Browse the repository at this point in the history
  • Loading branch information
waleko committed Jun 22, 2024
1 parent 91e92e3 commit de21236
Show file tree
Hide file tree
Showing 18 changed files with 879 additions and 31 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from code_editing.agents.collect_edit.context_collectors.aider import AiderRetrieval
from code_editing.agents.collect_edit.context_collectors.as_is_retrieval import AsIsRetrieval
from code_editing.agents.collect_edit.context_collectors.auto_code_rover import ACRRetrieval
from code_editing.agents.collect_edit.context_collectors.llm_cycle_retrieval import LLMCycleRetrieval
Expand All @@ -12,4 +13,5 @@
"LLMFixedCtxRetrieval",
"ACRRetrieval",
"MyACRRetrieval",
"AiderRetrieval",
]
47 changes: 47 additions & 0 deletions code_editing/agents/collect_edit/context_collectors/aider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import logging
import os.path
from operator import itemgetter

from langchain_core.runnables import RunnableLambda

from code_editing.agents.collect_edit.editors.util import TagParser
from code_editing.agents.context_providers.aider import AiderRepoMap
from code_editing.agents.graph_factory import GraphFactory
from code_editing.agents.run import RunOverviewManager
from code_editing.agents.tools.common import read_file_full
from code_editing.agents.utils import PromptWrapper

logger = logging.getLogger(__name__)


class AiderRetrieval(GraphFactory):
    """Context collector that lets the LLM choose files from an Aider repo map.

    The built chain renders ``select_prompt`` with the repository map and the
    user instruction, sends it to the LLM, extracts every ``<file>...</file>``
    tag from the reply and marks every line of each selected file as viewed.
    """

    name = "aider_retrieval"

    def __init__(self, select_prompt: PromptWrapper, **kwargs):
        """Store the prompt used for file selection.

        Args:
            select_prompt: Prompt wrapper rendered with the ``repo_map`` and
                ``instruction`` inputs.
            **kwargs: Accepted for configuration compatibility; not used here.
        """
        super().__init__()
        self.select_prompt = select_prompt

    def build(self, run_overview_manager: RunOverviewManager, *args, **kwargs):
        """Build the runnable chain that collects context for one run.

        Args:
            run_overview_manager: Source of the ``"aider"`` context provider.
            *args: Unused.
            **kwargs: Unused.

        Returns:
            A runnable whose final output is
            ``{"collected_context": {file: [1, ..., line_count]}}``.
        """
        # noinspection PyTypeChecker
        aider: AiderRepoMap = run_overview_manager.get_ctx_provider("aider")
        # The map is computed once at build time and reused for every invocation.
        repo_map = aider.get_repo_map()

        def to_viewed_lines(state: dict):
            """Turn the parsed tag contents into a file -> line numbers mapping."""
            # Tag contents come verbatim from the LLM reply and often carry
            # surrounding whitespace or newlines; strip them so valid paths are
            # not reported as missing. ``dict.fromkeys`` de-duplicates while
            # keeping the order in which the model listed the files, which
            # makes the result deterministic.
            files = dict.fromkeys(match.strip() for match in state["matches"])
            files.pop("", None)  # an empty tag would resolve to the repo root

            viewed_lines = {}
            for file in files:
                full_path = os.path.join(aider.repo_path, file)
                # Require a regular file: a directory exists on disk but
                # cannot be read as text.
                if not os.path.isfile(full_path):
                    logger.warning(f"File {full_path} does not exist")
                    continue
                line_cnt = len(read_file_full(full_path).split("\n"))
                # Line numbers are 1-based and cover the whole file.
                viewed_lines[file] = list(range(1, line_cnt + 1))
            return {"collected_context": viewed_lines}

        return (
            {"repo_map": lambda _: repo_map, "instruction": itemgetter("instruction")}
            | self.select_prompt.as_runnable()
            | self._llm
            | TagParser(tag="file")
            | RunnableLambda(to_viewed_lines, name="Convert to viewed lines")
        )
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from code_editing.agents.context_providers.acr_search.search_manage import SearchManager
from code_editing.agents.context_providers.acr_search.search_utils import to_relative_path
from code_editing.agents.graph_factory import GraphFactory
from code_editing.agents.run import RunOverviewManager
from code_editing.agents.run import RunOverviewManager, ToolUseStatus

SYSTEM_PROMPT = """You are a software developer maintaining a large project.
You are working on an issue submitted to your project.
Expand Down Expand Up @@ -166,8 +166,6 @@ def __init__(self, *args, max_tries: int = 5, use_show_definition: bool = False,
# remove corresponding line
self.prompt = remove_unwanted_lines(self.prompt, "show_definition")
self.proxy_prompt = remove_unwanted_lines(self.proxy_prompt, "show_definition")
print(self.prompt)
print(self.proxy_prompt)

def proxy_run(self, text: str) -> Optional[dict]:
messages = [SystemMessage(self.proxy_prompt)]
Expand Down Expand Up @@ -261,12 +259,14 @@ def do_search(state):
func_name, func_args = parse_function_invocation(api_call)
function = getattr(search_manager, func_name)
try:
run_overview_manager.add_tool_use(func_name)
run_overview_manager.log_tool_use(func_name, ToolUseStatus.CALL)
res, summary, ok = function(*func_args)
if not ok:
run_overview_manager.add_tool_failure(func_name)
if ok:
run_overview_manager.log_tool_use(func_name, ToolUseStatus.OK)
else:
run_overview_manager.log_tool_use(func_name, ToolUseStatus.FAIL)
except Exception:
run_overview_manager.add_tool_error(func_name)
run_overview_manager.log_tool_use(func_name, ToolUseStatus.THROWN)
raise
tool_output += f"Result of {func_name}({', '.join(func_args)}):\n{res}\n"
except Exception as e:
Expand Down
19 changes: 19 additions & 0 deletions code_editing/agents/collect_edit/editors/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,22 @@ def _type(self) -> str:

key: Optional[str] = None
pattern: str = ""


class TagParser(BaseOutputParser):
    """Output parser that extracts the contents of every ``<tag>...</tag>`` pair."""

    def __init__(self, tag: str):
        """Create a parser for the given tag name (without angle brackets)."""
        super().__init__()
        self.tag = tag

    def parse(self, text: str) -> dict:
        """Return all non-overlapping tag contents found in ``text``.

        Args:
            text: Raw text to search.

        Returns:
            ``{"matches": [...]}`` with the inner text of each tag pair in
            order of appearance; the list is empty when no pair is found.
        """
        # Escape the tag so regex metacharacters in its name are matched
        # literally instead of altering the pattern.
        tag = re.escape(self.tag)
        # Non-greedy + DOTALL: each match ends at the nearest closing tag and
        # may span multiple lines.
        matches = re.findall(f"<{tag}>(.*?)</{tag}>", text, re.DOTALL)
        return {"matches": matches}

    def get_format_instructions(self) -> str:
        """Describe the expected output format for inclusion in a prompt."""
        return f'The output should be a text with the tag "<{self.tag}>...</{self.tag}>".'

    @property
    def _type(self) -> str:
        """Identifier string for this parser type."""
        return "tag_parser"

    # Name of the tag to extract; set by ``__init__``.
    tag: str = ""
3 changes: 3 additions & 0 deletions code_editing/agents/context_providers/aider/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from code_editing.agents.context_providers.aider.aider import AiderRepoMap

__all__ = ["AiderRepoMap"]
19 changes: 19 additions & 0 deletions code_editing/agents/context_providers/aider/aider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from code_editing.agents.context_providers.aider.repo_map import RepoMap, find_src_files
from code_editing.agents.context_providers.context_provider import ContextProvider


class AiderRepoMap(ContextProvider):
    """Context provider that produces a repository map via ``RepoMap``."""

    def __init__(self, repo_path: str, data_path: str):
        """Remember the paths and configure the underlying ``RepoMap``.

        Args:
            repo_path: Repository root; passed to ``RepoMap`` as ``root`` and
                later scanned for source files.
            data_path: Stored on the instance; not otherwise used in this class.
        """
        self.repo_path, self.data_path = repo_path, data_path

        def estimate_tokens(text):
            # Whitespace-separated word count, integer-divided by 4.
            # NOTE(review): unusual token estimate — confirm it is intended.
            return len(text.split()) // 4

        self.rm = RepoMap(map_tokens=1024, root=repo_path, token_count=estimate_tokens)

    def get_repo_map(self) -> str:
        """Return the map built from the source files found under ``repo_path``."""
        source_files = find_src_files(self.repo_path)
        # The first argument is always an empty list here.
        return self.rm.get_repo_map([], source_files)
Loading

0 comments on commit de21236

Please sign in to comment.