From de21236de0ed10dca27ffb143b03510318d33744 Mon Sep 17 00:00:00 2001 From: Alexander Kovrigin Date: Sat, 22 Jun 2024 23:21:32 +0200 Subject: [PATCH] Add aider context retrieval + tool tracking rework --- .../context_collectors/__init__.py | 2 + .../collect_edit/context_collectors/aider.py | 47 ++ .../context_collectors/auto_code_rover.py | 14 +- .../agents/collect_edit/editors/util.py | 19 + .../context_providers/aider/__init__.py | 3 + .../agents/context_providers/aider/aider.py | 19 + .../context_providers/aider/repo_map.py | 577 ++++++++++++++++++ code_editing/agents/run.py | 36 +- code_editing/agents/tools/base_tool.py | 9 +- .../collect_edit/context_collectors_config.py | 8 + .../context_providers/context_config.py | 6 + code_editing/configs/inference_config.py | 1 + code_editing/scripts/common.py | 5 +- code_editing/scripts/conf/aider.yaml | 12 + code_editing/scripts/conf/context/all.yaml | 1 + code_editing/utils/git_utils.py | 1 - poetry.lock | 144 ++++- pyproject.toml | 6 + 18 files changed, 879 insertions(+), 31 deletions(-) create mode 100644 code_editing/agents/collect_edit/context_collectors/aider.py create mode 100644 code_editing/agents/context_providers/aider/__init__.py create mode 100644 code_editing/agents/context_providers/aider/aider.py create mode 100644 code_editing/agents/context_providers/aider/repo_map.py create mode 100644 code_editing/scripts/conf/aider.yaml diff --git a/code_editing/agents/collect_edit/context_collectors/__init__.py b/code_editing/agents/collect_edit/context_collectors/__init__.py index cd2789c..e0d212e 100644 --- a/code_editing/agents/collect_edit/context_collectors/__init__.py +++ b/code_editing/agents/collect_edit/context_collectors/__init__.py @@ -1,3 +1,4 @@ +from code_editing.agents.collect_edit.context_collectors.aider import AiderRetrieval from code_editing.agents.collect_edit.context_collectors.as_is_retrieval import AsIsRetrieval from code_editing.agents.collect_edit.context_collectors.auto_code_rover import ACRRetrieval from code_editing.agents.collect_edit.context_collectors.llm_cycle_retrieval import LLMCycleRetrieval @@ -12,4 +13,5 @@ "LLMFixedCtxRetrieval", "ACRRetrieval", "MyACRRetrieval", + "AiderRetrieval", ] diff --git a/code_editing/agents/collect_edit/context_collectors/aider.py b/code_editing/agents/collect_edit/context_collectors/aider.py new file mode 100644 index 0000000..71728b6 --- /dev/null +++ b/code_editing/agents/collect_edit/context_collectors/aider.py @@ -0,0 +1,47 @@ +import logging +import os.path +from operator import itemgetter + +from langchain_core.runnables import RunnableLambda + +from code_editing.agents.collect_edit.editors.util import TagParser +from code_editing.agents.context_providers.aider import AiderRepoMap +from code_editing.agents.graph_factory import GraphFactory +from code_editing.agents.run import RunOverviewManager +from code_editing.agents.tools.common import read_file_full +from code_editing.agents.utils import PromptWrapper + +logger = logging.getLogger(__name__) + + +class AiderRetrieval(GraphFactory): + name = "aider_retrieval" + + def __init__(self, select_prompt: PromptWrapper, **kwargs): + super().__init__() + self.select_prompt = select_prompt + + def build(self, run_overview_manager: RunOverviewManager, *args, **kwargs): + # noinspection PyTypeChecker + aider: AiderRepoMap = run_overview_manager.get_ctx_provider("aider") + repo_map = aider.get_repo_map() + + def to_viewed_lines(state: dict): + files = set(state["matches"]) + viewed_lines = {} + for file in files: + full_path = os.path.join(aider.repo_path, file) + if not os.path.exists(full_path): + logger.warning(f"File {full_path} does not exist") + continue + line_cnt = len(read_file_full(full_path).split("\n")) + viewed_lines[file] = list(range(1, line_cnt + 1)) + return {"collected_context": viewed_lines} + + return ( + {"repo_map": lambda _: repo_map, "instruction": itemgetter("instruction")} + | self.select_prompt.as_runnable() + | self._llm + | TagParser(tag="file") + | RunnableLambda(to_viewed_lines, name="Convert to viewed lines") + ) diff --git a/code_editing/agents/collect_edit/context_collectors/auto_code_rover.py b/code_editing/agents/collect_edit/context_collectors/auto_code_rover.py index 1f4bba9..eff53bc 100644 --- a/code_editing/agents/collect_edit/context_collectors/auto_code_rover.py +++ b/code_editing/agents/collect_edit/context_collectors/auto_code_rover.py @@ -14,7 +14,7 @@ from code_editing.agents.context_providers.acr_search.search_manage import SearchManager from code_editing.agents.context_providers.acr_search.search_utils import to_relative_path from code_editing.agents.graph_factory import GraphFactory -from code_editing.agents.run import RunOverviewManager +from code_editing.agents.run import RunOverviewManager, ToolUseStatus SYSTEM_PROMPT = """You are a software developer maintaining a large project. You are working on an issue submitted to your project. @@ -166,8 +166,6 @@ def __init__(self, *args, max_tries: int = 5, use_show_definition: bool = False, # remove corresponding line self.prompt = remove_unwanted_lines(self.prompt, "show_definition") self.proxy_prompt = remove_unwanted_lines(self.proxy_prompt, "show_definition") - print(self.prompt) - print(self.proxy_prompt) def proxy_run(self, text: str) -> Optional[dict]: messages = [SystemMessage(self.proxy_prompt)] @@ -261,12 +259,14 @@ def do_search(state): func_name, func_args = parse_function_invocation(api_call) function = getattr(search_manager, func_name) try: - run_overview_manager.add_tool_use(func_name) + run_overview_manager.log_tool_use(func_name, ToolUseStatus.CALL) res, summary, ok = function(*func_args) - if not ok: - run_overview_manager.add_tool_failure(func_name) + if ok: + run_overview_manager.log_tool_use(func_name, ToolUseStatus.OK) + else: + run_overview_manager.log_tool_use(func_name, ToolUseStatus.FAIL) except Exception: - run_overview_manager.add_tool_error(func_name) + run_overview_manager.log_tool_use(func_name, ToolUseStatus.THROWN) raise tool_output += f"Result of {func_name}({', '.join(func_args)}):\n{res}\n" except Exception as e: diff --git a/code_editing/agents/collect_edit/editors/util.py b/code_editing/agents/collect_edit/editors/util.py index 6632440..9ea0f84 100644 --- a/code_editing/agents/collect_edit/editors/util.py +++ b/code_editing/agents/collect_edit/editors/util.py @@ -93,3 +93,22 @@ def _type(self) -> str: key: Optional[str] = None pattern: str = "" + + +class TagParser(BaseOutputParser): + def __init__(self, tag: str): + super().__init__() + self.tag = tag + + def parse(self, text: str) -> dict: + matches = re.findall(f"<{self.tag}>(.*?)", text, re.DOTALL) + return {"matches": matches} + + def get_format_instructions(self) -> str: + return f'The output should be a text with the tag "<{self.tag}>...".' + + @property + def _type(self) -> str: + return "tag_parser" + + tag: str = "" diff --git a/code_editing/agents/context_providers/aider/__init__.py b/code_editing/agents/context_providers/aider/__init__.py new file mode 100644 index 0000000..9d39a00 --- /dev/null +++ b/code_editing/agents/context_providers/aider/__init__.py @@ -0,0 +1,3 @@ +from code_editing.agents.context_providers.aider.aider import AiderRepoMap + +__all__ = ["AiderRepoMap"] diff --git a/code_editing/agents/context_providers/aider/aider.py b/code_editing/agents/context_providers/aider/aider.py new file mode 100644 index 0000000..5a9469a --- /dev/null +++ b/code_editing/agents/context_providers/aider/aider.py @@ -0,0 +1,19 @@ +from code_editing.agents.context_providers.aider.repo_map import RepoMap, find_src_files +from code_editing.agents.context_providers.context_provider import ContextProvider + + +class AiderRepoMap(ContextProvider): + def __init__(self, repo_path: str, data_path: str): + self.repo_path = repo_path + self.data_path = data_path + + self.rm = RepoMap( + map_tokens=1024, + root=repo_path, + token_count=lambda x: len(x.split()) // 4, + ) + + def get_repo_map(self) -> str: + fnames = find_src_files(self.repo_path) + repo_map = self.rm.get_repo_map([], fnames) + return repo_map diff --git a/code_editing/agents/context_providers/aider/repo_map.py b/code_editing/agents/context_providers/aider/repo_map.py new file mode 100644 index 0000000..81b8a48 --- /dev/null +++ b/code_editing/agents/context_providers/aider/repo_map.py @@ -0,0 +1,577 @@ +# Original source: https://github.com/paul-gauthier/aider/blob/0d9150c77b355a18d8cd1995c02cc2e65b965a84/aider/repomap.py +import colorsys +import os +import random +import warnings +from collections import Counter, defaultdict, namedtuple +from pathlib import Path + +import networkx as nx +from grep_ast import TreeContext, filename_to_lang +from pygments.lexers import guess_lexer_for_filename +from pygments.token import Token +from pygments.util import ClassNotFound + +# tree_sitter is throwing a FutureWarning +warnings.simplefilter("ignore", category=FutureWarning) +from tree_sitter_languages import get_language, get_parser # noqa: E402 + +Tag = namedtuple("Tag", "rel_fname fname line name kind".split()) + + +class InputOutput: + num_error_outputs = 0 + num_user_asks = 0 + + def __init__( + self, + pretty=True, + yes=False, + input_history_file=None, + chat_history_file=None, + input=None, + output=None, + user_input_color="blue", + tool_output_color=None, + tool_error_color="red", + encoding="utf-8", + dry_run=False, + llm_history_file=None, + ): + no_color = os.environ.get("NO_COLOR") + if no_color is not None and no_color != "": + pretty = False + + self.user_input_color = user_input_color if pretty else None + self.tool_output_color = tool_output_color if pretty else None + self.tool_error_color = tool_error_color if pretty else None + + self.input = input + self.output = output + + self.pretty = pretty + if self.output: + self.pretty = False + + self.yes = yes + + self.input_history_file = input_history_file + self.llm_history_file = llm_history_file + if chat_history_file is not None: + self.chat_history_file = Path(chat_history_file) + else: + self.chat_history_file = None + + self.encoding = encoding + self.dry_run = dry_run + + def read_text(self, filename): + + try: + with open(str(filename), "r", encoding=self.encoding) as f: + return f.read() + except FileNotFoundError: + self.tool_error(f"{filename}: file not found error") + return + except IsADirectoryError: + self.tool_error(f"{filename}: is a directory") + return + except UnicodeError as e: + self.tool_error(f"{filename}: {e}") + self.tool_error("Use --encoding to set the unicode encoding.") + return + + def write_text(self, filename, content): + if self.dry_run: + return + with open(str(filename), "w", encoding=self.encoding) as f: + f.write(content) + + def tool_error(self, message="", strip=True): + pass + + +class RepoMap: + CACHE_VERSION = 3 + + cache_missing = False + + warned_files = set() + + def __init__( + self, + map_tokens=1024, + root=None, + token_count=None, + repo_content_prefix=None, + verbose=False, + max_context_window=None, + ): + self.io = InputOutput() + self.verbose = verbose + + if not root: + root = os.getcwd() + self.root = root + + self.load_tags_cache() + + self.max_map_tokens = map_tokens + self.max_context_window = max_context_window + + self.token_count = token_count + self.repo_content_prefix = repo_content_prefix + + def get_repo_map(self, chat_files, other_files, mentioned_fnames=None, mentioned_idents=None): + if self.max_map_tokens <= 0: + return + if not other_files: + return + if not mentioned_fnames: + mentioned_fnames = set() + if not mentioned_idents: + mentioned_idents = set() + + max_map_tokens = self.max_map_tokens + + # With no files in the chat, give a bigger view of the entire repo + MUL = 16 + padding = 4096 + if max_map_tokens and self.max_context_window: + target = min(max_map_tokens * MUL, self.max_context_window - padding) + else: + target = 0 + if not chat_files and self.max_context_window and target > 0: + max_map_tokens = target + + try: + files_listing = self.get_ranked_tags_map( + chat_files, other_files, max_map_tokens, mentioned_fnames, mentioned_idents + ) + except RecursionError: + self.io.tool_error("Disabling repo map, git repo too large?") + self.max_map_tokens = 0 + return + + if not files_listing: + return + + num_tokens = self.token_count(files_listing) + if self.verbose: + self.io.tool_output(f"Repo-map: {num_tokens / 1024:.1f} k-tokens") + + if chat_files: + other = "other " + else: + other = "" + + if self.repo_content_prefix: + repo_content = self.repo_content_prefix.format(other=other) + else: + repo_content = "" + + repo_content += files_listing + + return repo_content + + def get_rel_fname(self, fname): + return os.path.relpath(fname, self.root) + + def split_path(self, path): + path = os.path.relpath(path, self.root) + return [path + ":"] + + def load_tags_cache(self): + self.cache_missing = True + self.TAGS_CACHE = {} + + def save_tags_cache(self): + pass + + def get_mtime(self, fname): + try: + return os.path.getmtime(fname) + except FileNotFoundError: + self.io.tool_error(f"File not found error: {fname}") + + def get_tags(self, fname, rel_fname): + # Check if the file is in the cache and if the modification time has not changed + file_mtime = self.get_mtime(fname) + if file_mtime is None: + return [] + + cache_key = fname + if cache_key in self.TAGS_CACHE and self.TAGS_CACHE[cache_key]["mtime"] == file_mtime: + return self.TAGS_CACHE[cache_key]["data"] + + # miss! + + data = list(self.get_tags_raw(fname, rel_fname)) + + # Update the cache + self.TAGS_CACHE[cache_key] = {"mtime": file_mtime, "data": data} + self.save_tags_cache() + return data + + def get_tags_raw(self, fname, rel_fname): + lang = filename_to_lang(fname) + if not lang: + return + + language = get_language(lang) + parser = get_parser(lang) + + # Load the tags queries + if lang == "python": + query_scm = """(class_definition + name: (identifier) @name.definition.class) @definition.class + + (function_definition + name: (identifier) @name.definition.function) @definition.function + + (call + function: [ + (identifier) @name.reference.call + (attribute + attribute: (identifier) @name.reference.call) + ]) @reference.call""" + else: + return + + code = self.io.read_text(fname) + if not code: + return + tree = parser.parse(bytes(code, "utf-8")) + + # Run the tags queries + query = language.query(query_scm) + captures = query.captures(tree.root_node) + + captures = list(captures) + + saw = set() + for node, tag in captures: + if tag.startswith("name.definition."): + kind = "def" + elif tag.startswith("name.reference."): + kind = "ref" + else: + continue + + saw.add(kind) + + result = Tag( + rel_fname=rel_fname, + fname=fname, + name=node.text.decode("utf-8"), + kind=kind, + line=node.start_point[0], + ) + + yield result + + if "ref" in saw: + return + if "def" not in saw: + return + + # We saw defs, without any refs + # Some tags files only provide defs (cpp, for example) + # Use pygments to backfill refs + + try: + lexer = guess_lexer_for_filename(fname, code) + except ClassNotFound: + return + + tokens = list(lexer.get_tokens(code)) + tokens = [token[1] for token in tokens if token[0] in Token.Name] + + for token in tokens: + yield Tag( + rel_fname=rel_fname, + fname=fname, + name=token, + kind="ref", + line=-1, + ) + + def get_ranked_tags(self, chat_fnames, other_fnames, mentioned_fnames, mentioned_idents): + defines = defaultdict(set) + references = defaultdict(list) + definitions = defaultdict(set) + + personalization = dict() + + fnames = set(chat_fnames).union(set(other_fnames)) + chat_rel_fnames = set() + + fnames = sorted(fnames) + + # Default personalization for unspecified files is 1/num_nodes + # https://networkx.org/documentation/stable/_modules/networkx/algorithms/link_analysis/pagerank_alg.html#pagerank + personalize = 10 / len(fnames) + + self.cache_missing = False + + for fname in fnames: + if not Path(fname).is_file(): + if fname not in self.warned_files: + if Path(fname).exists(): + self.io.tool_error(f"Repo-map can't include {fname}, it is not a normal file") + else: + self.io.tool_error(f"Repo-map can't include {fname}, it no longer exists") + + self.warned_files.add(fname) + continue + + # dump(fname) + rel_fname = self.get_rel_fname(fname) + + if fname in chat_fnames: + personalization[rel_fname] = personalize + chat_rel_fnames.add(rel_fname) + + if fname in mentioned_fnames: + personalization[rel_fname] = personalize + + tags = list(self.get_tags(fname, rel_fname)) + if tags is None: + continue + + for tag in tags: + if tag.kind == "def": + defines[tag.name].add(rel_fname) + key = (rel_fname, tag.name) + definitions[key].add(tag) + + if tag.kind == "ref": + references[tag.name].append(rel_fname) + + ## + # dump(defines) + # dump(references) + # dump(personalization) + + if not references: + references = dict((k, list(v)) for k, v in defines.items()) + + idents = set(defines.keys()).intersection(set(references.keys())) + + G = nx.MultiDiGraph() + + for ident in idents: + definers = defines[ident] + if ident in mentioned_idents: + mul = 10 + else: + mul = 1 + for referencer, num_refs in Counter(references[ident]).items(): + for definer in definers: + # if referencer == definer: + # continue + G.add_edge(referencer, definer, weight=mul * num_refs, ident=ident) + + if not references: + pass + + if personalization: + pers_args = dict(personalization=personalization, dangling=personalization) + else: + pers_args = dict() + + try: + ranked = nx.pagerank(G, weight="weight", **pers_args) + except ZeroDivisionError: + return [] + + # distribute the rank from each source node, across all of its out edges + ranked_definitions = defaultdict(float) + for src in G.nodes: + src_rank = ranked[src] + total_weight = sum(data["weight"] for _src, _dst, data in G.out_edges(src, data=True)) + # dump(src, src_rank, total_weight) + for _src, dst, data in G.out_edges(src, data=True): + data["rank"] = src_rank * data["weight"] / total_weight + ident = data["ident"] + ranked_definitions[(dst, ident)] += data["rank"] + + ranked_tags = [] + ranked_definitions = sorted(ranked_definitions.items(), reverse=True, key=lambda x: x[1]) + + # dump(ranked_definitions) + + for (fname, ident), rank in ranked_definitions: + # print(f"{rank:.03f} {fname} {ident}") + if fname in chat_rel_fnames: + continue + ranked_tags += list(definitions.get((fname, ident), [])) + + rel_other_fnames_without_tags = set(self.get_rel_fname(fname) for fname in other_fnames) + + fnames_already_included = set(rt[0] for rt in ranked_tags) + + top_rank = sorted([(rank, node) for (node, rank) in ranked.items()], reverse=True) + for rank, fname in top_rank: + if fname in rel_other_fnames_without_tags: + rel_other_fnames_without_tags.remove(fname) + if fname not in fnames_already_included: + ranked_tags.append((fname,)) + + for fname in rel_other_fnames_without_tags: + ranked_tags.append((fname,)) + + return ranked_tags + + def get_ranked_tags_map( + self, + chat_fnames, + other_fnames=None, + max_map_tokens=None, + mentioned_fnames=None, + mentioned_idents=None, + ): + if not other_fnames: + other_fnames = list() + if not max_map_tokens: + max_map_tokens = self.max_map_tokens + if not mentioned_fnames: + mentioned_fnames = set() + if not mentioned_idents: + mentioned_idents = set() + + ranked_tags = self.get_ranked_tags(chat_fnames, other_fnames, mentioned_fnames, mentioned_idents) + + num_tags = len(ranked_tags) + lower_bound = 0 + upper_bound = num_tags + best_tree = None + best_tree_tokens = 0 + + chat_rel_fnames = [self.get_rel_fname(fname) for fname in chat_fnames] + + # Guess a small starting number to help with giant repos + middle = min(max_map_tokens // 25, num_tags) + + self.tree_cache = dict() + + while lower_bound <= upper_bound: + tree = self.to_tree(ranked_tags[:middle], chat_rel_fnames) + num_tokens = self.token_count(tree) + + if num_tokens < max_map_tokens and num_tokens > best_tree_tokens: + best_tree = tree + best_tree_tokens = num_tokens + + if num_tokens < max_map_tokens: + lower_bound = middle + 1 + else: + upper_bound = middle - 1 + + middle = (lower_bound + upper_bound) // 2 + + return best_tree + + tree_cache = dict() + + def render_tree(self, abs_fname, rel_fname, lois): + key = (rel_fname, tuple(sorted(lois))) + + if key in self.tree_cache: + return self.tree_cache[key] + + code = self.io.read_text(abs_fname) or "" + if not code.endswith("\n"): + code += "\n" + + context = TreeContext( + rel_fname, + code, + color=False, + line_number=False, + child_context=False, + last_line=False, + margin=0, + mark_lois=False, + loi_pad=0, + # header_max=30, + show_top_of_file_parent_scope=False, + ) + + context.add_lines_of_interest(lois) + context.add_context() + res = context.format() + self.tree_cache[key] = res + return res + + def to_tree(self, tags, chat_rel_fnames): + if not tags: + return "" + + tags = [tag for tag in tags if tag[0] not in chat_rel_fnames] + tags = sorted(tags) + + cur_fname = None + cur_abs_fname = None + lois = None + output = "" + + # add a bogus tag at the end so we trip the this_fname != cur_fname... + dummy_tag = (None,) + for tag in tags + [dummy_tag]: + this_rel_fname = tag[0] + + # ... here ... to output the final real entry in the list + if this_rel_fname != cur_fname: + if lois is not None: + output += "\n" + output += cur_fname + ":\n" + output += self.render_tree(cur_abs_fname, cur_fname, lois) + lois = None + elif cur_fname: + output += "\n" + cur_fname + "\n" + if type(tag) is Tag: + lois = [] + cur_abs_fname = tag.fname + cur_fname = this_rel_fname + + if lois is not None: + lois.append(tag.line) + + # truncate long lines, in case we get minified js or something else crazy + output = "\n".join([line[:100] for line in output.splitlines()]) + "\n" + + return output + + +def find_src_files(directory): + if not os.path.isdir(directory): + return [directory] + + src_files = [] + for root, dirs, files in os.walk(directory): + for file in files: + if ".git" in root: # skip git files + continue + src_files.append(os.path.join(root, file)) + return src_files + + +def get_random_color(): + hue = random.random() + r, g, b = [int(x * 255) for x in colorsys.hsv_to_rgb(hue, 1, 0.75)] + res = f"#{r:02x}{g:02x}{b:02x}" + return res + + +def get_supported_languages_md(): + from grep_ast.parsers import PARSERS + + res = "" + data = sorted((lang, ex) for ex, lang in PARSERS.items()) + for lang, ext in data: + res += "" + res += f'{lang:20}\n' + res += f'{ext:20}\n' + res += "" + return res diff --git a/code_editing/agents/run.py b/code_editing/agents/run.py index 46e0bfd..7e3b8fe 100644 --- a/code_editing/agents/run.py +++ b/code_editing/agents/run.py @@ -1,16 +1,23 @@ import collections -from dataclasses import asdict, dataclass -from typing import Dict +from enum import Enum +from typing import Dict, TypedDict from code_editing.agents.context_providers.context_provider import ContextProvider from code_editing.utils import wandb_utils -@dataclass -class ToolInfo: - calls: int = 0 - errors: int = 0 - failures: int = 0 +class ToolInfo(TypedDict): + calls: int + errors: int + failures: int + + +# enum class: calls, errors, failures +class ToolUseStatus(Enum): + CALL = "calls" + OK = "success" + FAIL = "failures" + THROWN = "errors" class RunOverviewManager: @@ -24,22 +31,17 @@ def __init__( self.data_path = data_path self.context_providers = context_providers self.metadata = {"repo_path": repo_path, "data_path": data_path, "context_providers": context_providers.keys()} - self.tools_info = collections.defaultdict(ToolInfo) + self.tools_info = collections.defaultdict(dict) self.start_ms = wandb_utils.get_current_ms() - def add_tool_use(self, tool_name): - self.tools_info[tool_name].calls += 1 - - def add_tool_error(self, tool_name): - self.tools_info[tool_name].errors += 1 - - def add_tool_failure(self, tool_name): - self.tools_info[tool_name].failures += 1 + def log_tool_use(self, tool_name, status: ToolUseStatus): + self.tools_info.setdefault(tool_name, {}).setdefault(status, 0) + self.tools_info[tool_name][status] += 1 def get_run_summary(self): end_ms = wandb_utils.get_current_ms() return { - "tools": {k: asdict(v) for k, v in self.tools_info.items()}, + "tools": self.tools_info, "start_ms": self.start_ms, "end_ms": end_ms, "duration_ms": end_ms - self.start_ms, diff --git a/code_editing/agents/tools/base_tool.py b/code_editing/agents/tools/base_tool.py index 328d655..59a1458 100644 --- a/code_editing/agents/tools/base_tool.py +++ b/code_editing/agents/tools/base_tool.py @@ -4,7 +4,7 @@ from langchain_core.tools import BaseTool, ToolException -from code_editing.agents.run import RunOverviewManager +from code_editing.agents.run import RunOverviewManager, ToolUseStatus class CEBaseTool(BaseTool, ABC): @@ -43,17 +43,18 @@ def __init__(self, run_overview_manager: RunOverviewManager = None, dry_run: boo def _run(self, *args: Any, **kwargs: Any) -> Any: # Track tool usage - self.run_overview_manager.add_tool_use(self.name) + self.run_overview_manager.log_tool_use(self.name, ToolUseStatus.CALL) try: # Run the tool self._run_tool(*args, **kwargs) + self.run_overview_manager.log_tool_use(self.name, ToolUseStatus.OK) except ToolException: # Track tool failure - self.run_overview_manager.add_tool_failure(self.name) + self.run_overview_manager.log_tool_use(self.name, ToolUseStatus.FAIL) raise except Exception: # Track tool error - self.run_overview_manager.add_tool_error(self.name) + self.run_overview_manager.log_tool_use(self.name, ToolUseStatus.THROWN) raise @abstractmethod diff --git a/code_editing/configs/agents/collect_edit/context_collectors_config.py b/code_editing/configs/agents/collect_edit/context_collectors_config.py index 0cd8ebc..295c4e8 100644 --- a/code_editing/configs/agents/collect_edit/context_collectors_config.py +++ b/code_editing/configs/agents/collect_edit/context_collectors_config.py @@ -52,6 +52,12 @@ class MyACRRetrievalConfig(ContextCollectorsConfig): _target_: str = f"{CE_CLASSES_ROOT_PKG}.agents.collect_edit.context_collectors.MyACRRetrieval" +@dataclass +class AiderRetrievalConfig(ContextCollectorsConfig): + _target_: str = f"{CE_CLASSES_ROOT_PKG}.agents.collect_edit.context_collectors.AiderRetrieval" + select_prompt: UserPromptConfig = field(default_factory=default_user_prompt("jbr-code-editing/repomap-search")) + + cs = ConfigStore.instance() cs.store(name="as_is_retrieval", group="graph/context_collector", node=AsIsRetrievalConfig) cs.store(name="llm_retrieval", group="graph/context_collector", node=LLMRetrievalConfig) @@ -60,3 +66,5 @@ class MyACRRetrievalConfig(ContextCollectorsConfig): cs.store(name="acr", group="graph/context_collector", node=ACRRetrievalConfig) cs.store(name="my_acr", group="graph/context_collector", node=MyACRRetrievalConfig) + +cs.store(name="aider", group="graph/context_collector", node=AiderRetrievalConfig) diff --git a/code_editing/configs/agents/context_providers/context_config.py b/code_editing/configs/agents/context_providers/context_config.py index 5ed8b08..82485b0 100644 --- a/code_editing/configs/agents/context_providers/context_config.py +++ b/code_editing/configs/agents/context_providers/context_config.py @@ -37,8 +37,14 @@ class BM25RetrievalConfig(RetrievalConfig): _target_: str = f"{CE_CLASSES_ROOT_PKG}.agents.context_providers.retrieval.BM25Retrieval" +@dataclass +class AiderRepoMapConfig(ContextConfig): + _target_: str = f"{CE_CLASSES_ROOT_PKG}.agents.context_providers.aider.AiderRepoMap" + + cs = ConfigStore.instance() cs.store(name="context", node=ContextConfig) cs.store(name="acr_search", group="context", node=ACRSearchManagerConfig) cs.store(name="faiss", group="context", node=FaissRetrievalConfig) cs.store(name="bm25", group="context", node=BM25RetrievalConfig) +cs.store(name="aider", group="context", node=AiderRepoMapConfig) diff --git a/code_editing/configs/inference_config.py b/code_editing/configs/inference_config.py index e9ee0e2..ea4c305 100644 --- a/code_editing/configs/inference_config.py +++ b/code_editing/configs/inference_config.py @@ -13,3 +13,4 @@ class InferenceConfig: output_path: Optional[str] = None wandb: WandbConfig = field(default_factory=WandbConfig) num_tries: int = 5 + skip_empty_diffs: bool = True diff --git a/code_editing/scripts/common.py b/code_editing/scripts/common.py index 38b7363..9f96d77 100644 --- a/code_editing/scripts/common.py +++ b/code_editing/scripts/common.py @@ -76,7 +76,7 @@ def process_datapoint(i): row_info = f"{data.repo}@{data.base_hash[:8]}" res = task.result() y_pred = res["prediction"] + "\n" - if y_pred.strip() == "": + if y_pred.strip() == "" and inference_config.skip_empty_diffs: raise ValueError("Empty prediction") except Exception as e: if "Empty prediction" in str(e): @@ -102,6 +102,9 @@ def process_datapoint(i): for k, v in new_run_summary["tools"][tool].items(): run_summary.setdefault("tools", {}).setdefault(tool, {}).setdefault(k, 0) run_summary["tools"][tool][k] += v + new_run_summary.pop("tools") + # upd rest + run_summary.update(new_run_summary) # log wandb.log(run_summary) # Add the result to the dataframe diff --git a/code_editing/scripts/conf/aider.yaml b/code_editing/scripts/conf/aider.yaml new file mode 100644 index 0000000..a76a2fe --- /dev/null +++ b/code_editing/scripts/conf/aider.yaml @@ -0,0 +1,12 @@ +defaults: + - agent_common + - graph: collect_edit + - graph/context_collector: aider + - graph/editor: simple_editor + - _self_ + +graph: + only_collect: true + +inference: + skip_empty_diffs: false diff --git a/code_editing/scripts/conf/context/all.yaml b/code_editing/scripts/conf/context/all.yaml index d3bf192..1d9fafc 100644 --- a/code_editing/scripts/conf/context/all.yaml +++ b/code_editing/scripts/conf/context/all.yaml @@ -1,3 +1,4 @@ defaults: - acr_search@search_manager - retrieval_helper@retrieval_helper + - aider@aider diff --git a/code_editing/utils/git_utils.py b/code_editing/utils/git_utils.py index 879a387..7b07e20 100644 --- a/code_editing/utils/git_utils.py +++ b/code_editing/utils/git_utils.py @@ -212,7 +212,6 @@ def apply_patch_unsafe(repo, patch: str): # --recount: Fix line numbers in the patch repo.git.execute(["git", "apply", "--unidiff-zero", "--recount", "--ignore-whitespace", file_name]) except Exception as e: - print(e) os.remove(file_name) raise e os.remove(file_name) diff --git a/poetry.lock b/poetry.lock index a6ae90a..31b8de5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -924,6 +924,17 @@ files = [ graph = ["objgraph (>=1.7.2)"] profile = ["gprof2dot (>=2022.7.29)"] +[[package]] +name = "diskcache" +version = "5.6.3" +description = "Disk Cache -- Disk and file backed persistent cache." +optional = false +python-versions = ">=3" +files = [ + {file = "diskcache-5.6.3-py3-none-any.whl", hash = "sha256:5e31b2d5fbad117cc363ebaf6b689474db18a1f6438bc82358b024abd4c2ca19"}, + {file = "diskcache-5.6.3.tar.gz", hash = "sha256:2c3a3fa2743d8535d832ec61c2054a1641f41775aa7c556758a109941e33e4fc"}, +] + [[package]] name = "distro" version = "1.9.0" @@ -1394,6 +1405,21 @@ files = [ docs = ["Sphinx", "furo"] test = ["objgraph", "psutil"] +[[package]] +name = "grep-ast" +version = "0.3.2" +description = "A tool to grep through the AST of a source file" +optional = false +python-versions = "*" +files = [ + {file = "grep_ast-0.3.2-py3-none-any.whl", hash = "sha256:b7ceb84743983c3f4f5bca82f3374534cd9dbd759792d0dedf5648fedbb6f3fc"}, + {file = "grep_ast-0.3.2.tar.gz", hash = "sha256:d53bc7d25dfefafe77643fec189ab38e3cbd839d546c070a950ebedad82ee164"}, +] + +[package.dependencies] +pathspec = "*" +tree-sitter-languages = ">=1.8.0" + [[package]] name = "gymnasium" version = "0.29.1" @@ -5864,6 +5890,122 @@ torchhub = ["filelock", "huggingface-hub (>=0.23.0,<1.0)", "importlib-metadata", video = ["av (==9.2.0)", "decord (==0.6.0)"] vision = ["Pillow (>=10.0.1,<=15.0)"] +[[package]] +name = "tree-sitter" +version = "0.21.3" +description = "Python bindings for the Tree-Sitter parsing library" +optional = false +python-versions = ">=3.8" +files = [ + {file = "tree-sitter-0.21.3.tar.gz", hash = "sha256:b5de3028921522365aa864d95b3c41926e0ba6a85ee5bd000e10dc49b0766988"}, + {file = "tree_sitter-0.21.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:351f302b6615230c9dac9829f0ba20a94362cd658206ca9a7b2d58d73373dfb0"}, + {file = "tree_sitter-0.21.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:766e79ae1e61271e7fdfecf35b6401ad9b47fc07a0965ad78e7f97fddfdf47a6"}, + {file = "tree_sitter-0.21.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c4d3d4d4b44857e87de55302af7f2d051c912c466ef20e8f18158e64df3542a"}, + {file = "tree_sitter-0.21.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:84eedb06615461b9e2847be7c47b9c5f2195d7d66d31b33c0a227eff4e0a0199"}, + {file = "tree_sitter-0.21.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9d33ea425df8c3d6436926fe2991429d59c335431bf4e3c71e77c17eb508be5a"}, + {file = "tree_sitter-0.21.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fae1ee0ff6d85e2fd5cd8ceb9fe4af4012220ee1e4cbe813305a316caf7a6f63"}, + {file = "tree_sitter-0.21.3-cp310-cp310-win_amd64.whl", hash = "sha256:bb41be86a987391f9970571aebe005ccd10222f39c25efd15826583c761a37e5"}, + {file = "tree_sitter-0.21.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:54b22c3c2aab3e3639a4b255d9df8455da2921d050c4829b6a5663b057f10db5"}, + {file = "tree_sitter-0.21.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ab6e88c1e2d5e84ff0f9e5cd83f21b8e5074ad292a2cf19df3ba31d94fbcecd4"}, + {file = "tree_sitter-0.21.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc3fd34ed4cd5db445bc448361b5da46a2a781c648328dc5879d768f16a46771"}, + {file = "tree_sitter-0.21.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fabc7182f6083269ce3cfcad202fe01516aa80df64573b390af6cd853e8444a1"}, + {file = "tree_sitter-0.21.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:4f874c3f7d2a2faf5c91982dc7d88ff2a8f183a21fe475c29bee3009773b0558"}, + {file = "tree_sitter-0.21.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ee61ee3b7a4eedf9d8f1635c68ba4a6fa8c46929601fc48a907c6cfef0cfbcb2"}, + {file = "tree_sitter-0.21.3-cp311-cp311-win_amd64.whl", hash = "sha256:0b7256c723642de1c05fbb776b27742204a2382e337af22f4d9e279d77df7aa2"}, + {file = "tree_sitter-0.21.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:669b3e5a52cb1e37d60c7b16cc2221c76520445bb4f12dd17fd7220217f5abf3"}, + {file = "tree_sitter-0.21.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2aa2a5099a9f667730ff26d57533cc893d766667f4d8a9877e76a9e74f48f0d3"}, + {file = "tree_sitter-0.21.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a3e06ae2a517cf6f1abb682974f76fa760298e6d5a3ecf2cf140c70f898adf0"}, + {file = "tree_sitter-0.21.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af992dfe08b4fefcfcdb40548d0d26d5d2e0a0f2d833487372f3728cd0772b48"}, + {file = "tree_sitter-0.21.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:c7cbab1dd9765138505c4a55e2aa857575bac4f1f8a8b0457744a4fefa1288e6"}, + {file = "tree_sitter-0.21.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e1e66aeb457d1529370fcb0997ae5584c6879e0e662f1b11b2f295ea57e22f54"}, + {file = "tree_sitter-0.21.3-cp312-cp312-win_amd64.whl", hash = "sha256:013c750252dc3bd0e069d82e9658de35ed50eecf31c6586d0de7f942546824c5"}, + {file = "tree_sitter-0.21.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:4986a8cb4acebd168474ec2e5db440e59c7888819b3449a43ce8b17ed0331b07"}, + {file = "tree_sitter-0.21.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6e217fee2e7be7dbce4496caa3d1c466977d7e81277b677f954d3c90e3272ec2"}, + {file = "tree_sitter-0.21.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f32a88afff4f2bc0f20632b0a2aa35fa9ae7d518f083409eca253518e0950929"}, + {file = "tree_sitter-0.21.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3652ac9e47cdddf213c5d5d6854194469097e62f7181c0a9aa8435449a163a9"}, + {file = "tree_sitter-0.21.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:60b4df3298ff467bc01e2c0f6c2fb43aca088038202304bf8e41edd9fa348f45"}, + {file = "tree_sitter-0.21.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:00e4d0c99dff595398ef5e88a1b1ddd53adb13233fb677c1fd8e497fb2361629"}, + {file = "tree_sitter-0.21.3-cp38-cp38-win_amd64.whl", hash = "sha256:50c91353a26946e4dd6779837ecaf8aa123aafa2d3209f261ab5280daf0962f5"}, + {file = "tree_sitter-0.21.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b17b8648b296ccc21a88d72ca054b809ee82d4b14483e419474e7216240ea278"}, + {file = "tree_sitter-0.21.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f2f057fd01d3a95cbce6794c6e9f6db3d376cb3bb14e5b0528d77f0ec21d6478"}, + {file = "tree_sitter-0.21.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:839759de30230ffd60687edbb119b31521d5ac016749358e5285816798bb804a"}, + {file = "tree_sitter-0.21.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5df40aa29cb7e323898194246df7a03b9676955a0ac1f6bce06bc4903a70b5f7"}, + {file = "tree_sitter-0.21.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:1d9be27dde007b569fa78ff9af5fe40d2532c998add9997a9729e348bb78fa59"}, + {file = "tree_sitter-0.21.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c4ac87735e6f98fe085244c7c020f0177d13d4c117db72ba041faa980d25d69d"}, + {file = "tree_sitter-0.21.3-cp39-cp39-win_amd64.whl", hash = "sha256:fbbd137f7d9a5309fb4cb82e2c3250ba101b0dd08a8abdce815661e6cf2cbc19"}, +] + +[[package]] +name = "tree-sitter-languages" +version = "1.10.2" +description = "Binary Python wheels for all tree sitter languages." +optional = false +python-versions = "*" +files = [ + {file = "tree_sitter_languages-1.10.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5580348f0b20233b1d5431fa178ccd3d07423ca4a3275df02a44608fd72344b9"}, + {file = "tree_sitter_languages-1.10.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:103c7466644486b1e9e03850df46fc6aa12f13ca636c74f173270276220ac80b"}, + {file = "tree_sitter_languages-1.10.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d13db84511c6f1a7dc40383b66deafa74dabd8b877e3d65ab253f3719eccafd6"}, + {file = "tree_sitter_languages-1.10.2-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:57adfa32be7e465b54aa72f915f6c78a2b66b227df4f656b5d4fbd1ca7a92b3f"}, + {file = "tree_sitter_languages-1.10.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1c6385e033e460ceb8f33f3f940335f422ef2b763700a04f0089391a68b56153"}, + {file = "tree_sitter_languages-1.10.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:dfa3f38cc5381c5aba01dd7494f59b8a9050e82ff6e06e1233e3a0cbae297e3c"}, + {file = "tree_sitter_languages-1.10.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:9f195155acf47f8bc5de7cee46ecd07b2f5697f007ba89435b51ef4c0b953ea5"}, + {file = "tree_sitter_languages-1.10.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2de330e2ac6d7426ca025a3ec0f10d5640c3682c1d0c7702e812dcfb44b58120"}, + {file = "tree_sitter_languages-1.10.2-cp310-cp310-win32.whl", hash = "sha256:c9731cf745f135d9770eeba9bb4e2ff4dabc107b5ae9b8211e919f6b9100ea6d"}, + {file = "tree_sitter_languages-1.10.2-cp310-cp310-win_amd64.whl", hash = "sha256:6dd75851c41d0c3c4987a9b7692d90fa8848706c23115669d8224ffd6571e357"}, + {file = "tree_sitter_languages-1.10.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7eb7d7542b2091c875fe52719209631fca36f8c10fa66970d2c576ae6a1b8289"}, + {file = "tree_sitter_languages-1.10.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6b41bcb00974b1c8a1800c7f1bb476a1d15a0463e760ee24872f2d53b08ee424"}, + {file = "tree_sitter_languages-1.10.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6f370cd7845c6c81df05680d5bd96db8a99d32b56f4728c5d05978911130a853"}, + {file = "tree_sitter_languages-1.10.2-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a1dc195c88ef4c72607e112a809a69190e096a2e5ebc6201548b3e05fdd169ad"}, + {file = "tree_sitter_languages-1.10.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ae34ac314a7170be24998a0f994c1ac80761d8d4bd126af27ee53a023d3b849"}, + {file = "tree_sitter_languages-1.10.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:01b5742d5f5bd675489486b582bd482215880b26dde042c067f8265a6e925d9c"}, + {file = "tree_sitter_languages-1.10.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:ab1cbc46244d34fd16f21edaa20231b2a57f09f092a06ee3d469f3117e6eb954"}, + {file = "tree_sitter_languages-1.10.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0b1149e7467a4e92b8a70e6005fe762f880f493cf811fc003554b29f04f5e7c8"}, + {file = "tree_sitter_languages-1.10.2-cp311-cp311-win32.whl", hash = "sha256:049276343962f4696390ee555acc2c1a65873270c66a6cbe5cb0bca83bcdf3c6"}, + {file = "tree_sitter_languages-1.10.2-cp311-cp311-win_amd64.whl", hash = "sha256:7f3fdd468a577f04db3b63454d939e26e360229b53c80361920aa1ebf2cd7491"}, + {file = "tree_sitter_languages-1.10.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c0f4c8b2734c45859edc7fcaaeaab97a074114111b5ba51ab4ec7ed52104763c"}, + {file = "tree_sitter_languages-1.10.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:eecd3c1244ac3425b7a82ba9125b4ddb45d953bbe61de114c0334fd89b7fe782"}, + {file = "tree_sitter_languages-1.10.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:15db3c8510bc39a80147ee7421bf4782c15c09581c1dc2237ea89cefbd95b846"}, + {file = "tree_sitter_languages-1.10.2-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:92c6487a6feea683154d3e06e6db68c30e0ae749a7ce4ce90b9e4e46b78c85c7"}, + {file = "tree_sitter_languages-1.10.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d2f1cd1d1bdd65332f9c2b67d49dcf148cf1ded752851d159ac3e5ee4f4d260"}, + {file = "tree_sitter_languages-1.10.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:976c8039165b8e12f17a01ddee9f4e23ec6e352b165ad29b44d2bf04e2fbe77e"}, + {file = "tree_sitter_languages-1.10.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:dafbbdf16bf668a580902e1620f4baa1913e79438abcce721a50647564c687b9"}, + {file = "tree_sitter_languages-1.10.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1aeabd3d60d6d276b73cd8f3739d595b1299d123cc079a317f1a5b3c5461e2ca"}, + {file = "tree_sitter_languages-1.10.2-cp312-cp312-win32.whl", hash = "sha256:fab8ee641914098e8933b87ea3d657bea4dd00723c1ee7038b847b12eeeef4f5"}, + {file = "tree_sitter_languages-1.10.2-cp312-cp312-win_amd64.whl", hash = "sha256:5e606430d736367e5787fa5a7a0c5a1ec9b85eded0b3596bbc0d83532a40810b"}, + {file = "tree_sitter_languages-1.10.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:838d5b48a7ed7a17658721952c77fda4570d2a069f933502653b17e15a9c39c9"}, + {file = "tree_sitter_languages-1.10.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:987b3c71b1d278c2889e018ee77b8ee05c384e2e3334dec798f8b611c4ab2d1e"}, + {file = "tree_sitter_languages-1.10.2-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:faa00abcb2c819027df58472da055d22fa7dfcb77c77413d8500c32ebe24d38b"}, + {file = "tree_sitter_languages-1.10.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e102fbbf02322d9201a86a814e79a9734ac80679fdb9682144479044f401a73"}, + {file = "tree_sitter_languages-1.10.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:8f0b87cf1a7b03174ba18dfd81582be82bfed26803aebfe222bd20e444aba003"}, + {file = "tree_sitter_languages-1.10.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c0f1b9af9cb67f0b942b020da9fdd000aad5e92f2383ae0ba7a330b318d31912"}, + {file = "tree_sitter_languages-1.10.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:5a4076c921f7a4d31e643843de7dfe040b65b63a238a5aa8d31d93aabe6572aa"}, + {file = "tree_sitter_languages-1.10.2-cp37-cp37m-win32.whl", hash = "sha256:fa6391a3a5d83d32db80815161237b67d70576f090ce5f38339206e917a6f8bd"}, + {file = "tree_sitter_languages-1.10.2-cp37-cp37m-win_amd64.whl", hash = "sha256:55649d3f254585a064121513627cf9788c1cfdadbc5f097f33d5ba750685a4c0"}, + {file = "tree_sitter_languages-1.10.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6f85d1edaa2d22d80d4ea5b6d12b95cf3644017b6c227d0d42854439e02e8893"}, + {file = "tree_sitter_languages-1.10.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d78feed4a764ef3141cb54bf00fe94d514d8b6e26e09423e23b4c616fcb7938c"}, + {file = "tree_sitter_languages-1.10.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da1aca27531f9dd5308637d76643372856f0f65d0d28677d1bcf4211e8ed1ad0"}, + {file = "tree_sitter_languages-1.10.2-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1031ea440dafb72237437d754eff8940153a3b051e3d18932ac25e75ce060a15"}, + {file = "tree_sitter_languages-1.10.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:99d3249beaef2c9fe558ecc9a97853c260433a849dcc68266d9770d196c2e102"}, + {file = "tree_sitter_languages-1.10.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:59a4450f262a55148fb7e68681522f0c2a2f6b7d89666312a2b32708d8f416e1"}, + {file = "tree_sitter_languages-1.10.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:ce74eab0e430370d5e15a96b6c6205f93405c177a8b2e71e1526643b2fb9bab1"}, + {file = "tree_sitter_languages-1.10.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:9b4dd2b6b3d24c85dffe33d6c343448869eaf4f41c19ddba662eb5d65d8808f4"}, + {file = "tree_sitter_languages-1.10.2-cp38-cp38-win32.whl", hash = "sha256:92d734fb968fe3927a7596d9f0459f81a8fa7b07e16569476b28e27d0d753348"}, + {file = "tree_sitter_languages-1.10.2-cp38-cp38-win_amd64.whl", hash = "sha256:46a13f7d38f2eeb75f7cf127d1201346093748c270d686131f0cbc50e42870a1"}, + {file = "tree_sitter_languages-1.10.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f8c6a936ae99fdd8857e91f86c11c2f5e507ff30631d141d98132bb7ab2c8638"}, + {file = "tree_sitter_languages-1.10.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c283a61423f49cdfa7b5a5dfbb39221e3bd126fca33479cd80749d4d7a6b7349"}, + {file = "tree_sitter_languages-1.10.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76e60be6bdcff923386a54a5edcb6ff33fc38ab0118636a762024fa2bc98de55"}, + {file = "tree_sitter_languages-1.10.2-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c00069f9575bd831eabcce2cdfab158dde1ed151e7e5614c2d985ff7d78a7de1"}, + {file = "tree_sitter_languages-1.10.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:475ff53203d8a43ccb19bb322fa2fb200d764001cc037793f1fadd714bb343da"}, + {file = "tree_sitter_languages-1.10.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:26fe7c9c412e4141dea87ea4b3592fd12e385465b5bdab106b0d5125754d4f60"}, + {file = "tree_sitter_languages-1.10.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:8fed27319957458340f24fe14daad467cd45021da034eef583519f83113a8c5e"}, + {file = "tree_sitter_languages-1.10.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3657a491a7f96cc75a3568ddd062d25f3be82b6a942c68801a7b226ff7130181"}, + {file = "tree_sitter_languages-1.10.2-cp39-cp39-win32.whl", hash = "sha256:33f7d584d01a7a3c893072f34cfc64ec031f3cfe57eebc32da2f8ac046e101a7"}, + {file = "tree_sitter_languages-1.10.2-cp39-cp39-win_amd64.whl", hash = "sha256:1b944af3ee729fa70fc8ae82224a9ff597cdb63addea084e0ea2fa2b0ec39bb7"}, +] + +[package.dependencies] +tree-sitter = "*" + [[package]] name = "trio" version = "0.25.1" @@ -6422,4 +6564,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "~3.10" -content-hash = "3f9d46c82f1a6fa4d833b276a4e5b89aad38cb46a951828cfe0930a7237a6b72" +content-hash = "8eebf58ce46671c4c65143b48f4d1c84ba6e7b848e3dd4bf3523d2202ad82399" diff --git a/pyproject.toml b/pyproject.toml index 9072681..6e0b633 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,12 @@ scikit-learn = "^1.5.0" scipy = "^1.13.1" +[tool.poetry.group.aider.dependencies] +grep-ast = "0.3.2" +tree-sitter = "0.21.3" +tree-sitter-languages = "1.10.2" +diskcache = "5.6.3" + [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api"