Skip to content

Commit

Permalink
Fix Github.CountStargazers and add tests (#92)
Browse files Browse the repository at this point in the history
## Problem
I found a bug with `Github.CountStargazers` where a stargazer count of
`0` was interpreted as a null result. In other words, 0 wasn't passed
back to the Engine correctly.
Separately, the tool function was also not authorized correctly.

## Fix
- Don't use a falsy comparison when evaluating `result` inside the
`ToolOutputFactory`
- Add unit tests for `ToolOutputFactory` to give us confidence in the
business logic
- Added `ToolContext` to pass in the authorization token correctly.

Before
```
User (nate@arcade-ai.com): 
how many stars does the ArcadeAI/Docs repo have on github?

Assistant (gpt-4o): 
I successfully checked the repository, but unfortunately, I cannot provide the number of stars for the ArcadeAI/Docs repository. Please try checking directly on GitHub for the most accurate information.                                                                                        
Called tool 'Github_CountStargazers'
Parameters:{"owner":"ArcadeAI","name":"Docs"}
'Github_CountStargazers' tool returned:Github.CountStargazers called successfully
```

After
```
User (nate@arcade-ai.com): 
how many stars does the ArcadeAI/Docs repo have on github?

Assistant (gpt-4o): 
The ArcadeAI/Docs repository on GitHub has 0 stars.                                                                                                                                                                 
Called tool 'Github_CountStargazers'
Parameters:{"owner":"ArcadeAI","name":"Docs"}
'Github_CountStargazers' tool returned:0
  • Loading branch information
nbarbettini authored Oct 4, 2024
1 parent 8444039 commit 56fc83b
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 7 deletions.
12 changes: 9 additions & 3 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,23 @@
"type": "python",
"request": "launch",
"module": "uvicorn",
"args": ["main:app", "--app-dir", "${workspaceFolder}/examples/fastapi/arcade_example_fastapi", "--port", "8002"],
"args": [
"main:app",
"--app-dir",
"${workspaceFolder}/examples/fastapi/arcade_example_fastapi",
"--port",
"8002"
],
"jinja": true,
"justMyCode": true,
"cwd": "${workspaceFolder}/examples/fastapi/arcade_example_fastapi"
},
{
"name": "Debug `arcade dev --no-auth`",
"name": "Debug `arcade actorup --no-auth`",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/arcade/run_cli.py",
"args": ["dev", "--no-auth"],
"args": ["actorup", "--no-auth"],
"console": "integratedTerminal",
"jinja": true,
"justMyCode": true,
Expand Down
3 changes: 1 addition & 2 deletions arcade/arcade/core/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ def success(
*,
data: T | None = None,
) -> ToolCallOutput:
value = data.result if data and hasattr(data, "result") and data.result else ""

value = getattr(data, "result", "") if data else ""
return ToolCallOutput(value=value)

def fail(self, *, message: str, developer_message: str | None = None) -> ToolCallOutput:
Expand Down
75 changes: 75 additions & 0 deletions arcade/tests/core/test_output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from typing import Any

import pytest
from pydantic import BaseModel

from arcade.core.output import ToolOutputFactory


@pytest.fixture
def output_factory():
return ToolOutputFactory()


class SampleOutputModel(BaseModel):
result: Any


@pytest.mark.parametrize(
"data, expected_value",
[
(None, ""),
("success", "success"),
("", ""),
(None, ""),
(123, 123),
(0, 0),
(123.45, 123.45),
(True, True),
(False, False),
],
)
def test_success(output_factory, data, expected_value):
data_obj = SampleOutputModel(result=data) if data is not None else None
output = output_factory.success(data=data_obj)
assert output.value == expected_value
assert output.error is None


@pytest.mark.parametrize(
"message, developer_message",
[
("Error occurred", None),
("Error occurred", "Detailed error message"),
],
)
def test_fail(output_factory, message, developer_message):
output = output_factory.fail(message=message, developer_message=developer_message)
assert output.error is not None
assert output.error.message == message
assert output.error.developer_message == developer_message
assert output.error.can_retry is False


@pytest.mark.parametrize(
"message, developer_message, additional_prompt_content, retry_after_ms",
[
("Retry error", None, None, None),
("Retry error", "Retrying", "Please try again with this additional data: foobar", 1000),
],
)
def test_fail_retry(
output_factory, message, developer_message, additional_prompt_content, retry_after_ms
):
output = output_factory.fail_retry(
message=message,
developer_message=developer_message,
additional_prompt_content=additional_prompt_content,
retry_after_ms=retry_after_ms,
)
assert output.error is not None
assert output.error.message == message
assert output.error.developer_message == developer_message
assert output.error.can_retry is True
assert output.error.additional_prompt_content == additional_prompt_content
assert output.error.retry_after_ms == retry_after_ms
7 changes: 5 additions & 2 deletions toolkits/github/arcade_github/tools/repositories.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
# Example arcade chat usage: "How many stargazers does the <OWNER>/<REPO> repo have?"
@tool(requires_auth=GitHub())
async def count_stargazers(
context: ToolContext,
owner: Annotated[str, "The owner of the repository"],
name: Annotated[str, "The name of the repository"],
) -> Annotated[int, "The number of stargazers (stars) for the specified repository"]:
Expand All @@ -36,15 +37,17 @@ async def count_stargazers(
```
"""

headers = get_github_json_headers(context.authorization.token)

url = get_url("repo", owner=owner, repo=name)
async with httpx.AsyncClient() as client:
response = await client.get(url)
response = await client.get(url, headers=headers)

handle_github_response(response, url)

data = response.json()
stargazers_count = data.get("stargazers_count", 0)
return f"The repository {owner}/{name} has {stargazers_count} stargazers."
return stargazers_count


# Implements https://docs.github.com/en/rest/repos/repos?apiVersion=2022-11-28#list-organization-repositories
Expand Down

0 comments on commit 56fc83b

Please sign in to comment.