From 56fc83bf3e9ab8bf15c0673fb0a3e503e58c5282 Mon Sep 17 00:00:00 2001 From: Nate Barbettini Date: Fri, 4 Oct 2024 16:09:08 -0700 Subject: [PATCH] Fix Github.CountStargazers and add tests (#92) ## Problem I found a bug with `Github.CountStargazers` where a stargazer count of `0` was interpreted as a null result. In other words, 0 wasn't passed back to the Engine correctly. Separately, the tool function was also not authorized correctly. ## Fix - Don't use a falsy comparison when evaluating `result` inside the `ToolOutputFactory` - Add unit tests for `ToolOutputFactory` to give us confidence in the business logic - Added `ToolContext` to pass in the authorization token correctly. Before ``` User (nate@arcade-ai.com): how many stars does the ArcadeAI/Docs repo have on github? Assistant (gpt-4o): I successfully checked the repository, but unfortunately, I cannot provide the number of stars for the ArcadeAI/Docs repository. Please try checking directly on GitHub for the most accurate information. Called tool 'Github_CountStargazers' Parameters:{"owner":"ArcadeAI","name":"Docs"} 'Github_CountStargazers' tool returned:Github.CountStargazers called successfully ``` After ``` User (nate@arcade-ai.com): how many stars does the ArcadeAI/Docs repo have on github? Assistant (gpt-4o): The ArcadeAI/Docs repository on GitHub has 0 stars. Called tool 'Github_CountStargazers' Parameters:{"owner":"ArcadeAI","name":"Docs"} 'Github_CountStargazers' tool returned:0 --- .vscode/launch.json | 12 ++- arcade/arcade/core/output.py | 3 +- arcade/tests/core/test_output.py | 75 +++++++++++++++++++ .../arcade_github/tools/repositories.py | 7 +- 4 files changed, 90 insertions(+), 7 deletions(-) create mode 100644 arcade/tests/core/test_output.py diff --git a/.vscode/launch.json b/.vscode/launch.json index 0e597ca4..6a6c309e 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -6,17 +6,23 @@ "type": "python", "request": "launch", "module": "uvicorn", - "args": ["main:app", "--app-dir", "${workspaceFolder}/examples/fastapi/arcade_example_fastapi", "--port", "8002"], + "args": [ + "main:app", + "--app-dir", + "${workspaceFolder}/examples/fastapi/arcade_example_fastapi", + "--port", + "8002" + ], "jinja": true, "justMyCode": true, "cwd": "${workspaceFolder}/examples/fastapi/arcade_example_fastapi" }, { - "name": "Debug `arcade dev --no-auth`", + "name": "Debug `arcade actorup --no-auth`", "type": "python", "request": "launch", "program": "${workspaceFolder}/arcade/run_cli.py", - "args": ["dev", "--no-auth"], + "args": ["actorup", "--no-auth"], "console": "integratedTerminal", "jinja": true, "justMyCode": true, diff --git a/arcade/arcade/core/output.py b/arcade/arcade/core/output.py index f4c7d61c..42da8e69 100644 --- a/arcade/arcade/core/output.py +++ b/arcade/arcade/core/output.py @@ -15,8 +15,7 @@ def success( *, data: T | None = None, ) -> ToolCallOutput: - value = data.result if data and hasattr(data, "result") and data.result else "" - + value = getattr(data, "result", "") if data else "" return ToolCallOutput(value=value) def fail(self, *, message: str, developer_message: str | None = None) -> ToolCallOutput: diff --git a/arcade/tests/core/test_output.py b/arcade/tests/core/test_output.py new file mode 100644 index 00000000..fb62b413 --- /dev/null +++ b/arcade/tests/core/test_output.py @@ -0,0 +1,75 @@ +from typing import Any + +import pytest +from pydantic import BaseModel + +from arcade.core.output import ToolOutputFactory + + +@pytest.fixture +def output_factory(): + return ToolOutputFactory() + + +class SampleOutputModel(BaseModel): + result: Any + + +@pytest.mark.parametrize( + "data, expected_value", + [ + (None, ""), + ("success", "success"), + ("", ""), + (None, ""), + (123, 123), + (0, 0), + (123.45, 123.45), + (True, True), + (False, False), + ], +) +def test_success(output_factory, data, expected_value): + data_obj = SampleOutputModel(result=data) if data is not None else None + output = output_factory.success(data=data_obj) + assert output.value == expected_value + assert output.error is None + + +@pytest.mark.parametrize( + "message, developer_message", + [ + ("Error occurred", None), + ("Error occurred", "Detailed error message"), + ], +) +def test_fail(output_factory, message, developer_message): + output = output_factory.fail(message=message, developer_message=developer_message) + assert output.error is not None + assert output.error.message == message + assert output.error.developer_message == developer_message + assert output.error.can_retry is False + + +@pytest.mark.parametrize( + "message, developer_message, additional_prompt_content, retry_after_ms", + [ + ("Retry error", None, None, None), + ("Retry error", "Retrying", "Please try again with this additional data: foobar", 1000), + ], +) +def test_fail_retry( + output_factory, message, developer_message, additional_prompt_content, retry_after_ms +): + output = output_factory.fail_retry( + message=message, + developer_message=developer_message, + additional_prompt_content=additional_prompt_content, + retry_after_ms=retry_after_ms, + ) + assert output.error is not None + assert output.error.message == message + assert output.error.developer_message == developer_message + assert output.error.can_retry is True + assert output.error.additional_prompt_content == additional_prompt_content + assert output.error.retry_after_ms == retry_after_ms diff --git a/toolkits/github/arcade_github/tools/repositories.py b/toolkits/github/arcade_github/tools/repositories.py index 3d7b18a8..31352633 100644 --- a/toolkits/github/arcade_github/tools/repositories.py +++ b/toolkits/github/arcade_github/tools/repositories.py @@ -26,6 +26,7 @@ # Example arcade chat usage: "How many stargazers does the / repo have?" @tool(requires_auth=GitHub()) async def count_stargazers( + context: ToolContext, owner: Annotated[str, "The owner of the repository"], name: Annotated[str, "The name of the repository"], ) -> Annotated[int, "The number of stargazers (stars) for the specified repository"]: @@ -36,15 +37,17 @@ async def count_stargazers( ``` """ + headers = get_github_json_headers(context.authorization.token) + url = get_url("repo", owner=owner, repo=name) async with httpx.AsyncClient() as client: - response = await client.get(url) + response = await client.get(url, headers=headers) handle_github_response(response, url) data = response.json() stargazers_count = data.get("stargazers_count", 0) - return f"The repository {owner}/{name} has {stargazers_count} stargazers." + return stargazers_count # Implements https://docs.github.com/en/rest/repos/repos?apiVersion=2022-11-28#list-organization-repositories