Skip to content

Commit 1de7f97

Browse files
authored
backend: Filter Tool Auth client access token (#859)
* WIP * Add tests * Update lock file * Mock Google Drive for CI * testing * Resolve test for remote env
1 parent 5502d8d commit 1de7f97

File tree

7 files changed

+148
-22
lines changed

7 files changed

+148
-22
lines changed

poetry.lock

Lines changed: 40 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ pytest = "^7.1.2"
6767
pytest-env = "^1.1.3"
6868
pytest-cov = "^5.0.0"
6969
factory-boy = "^3.3.0"
70+
fakeredis = "^2.26.1"
7071
freezegun = "^1.5.1"
7172
pre-commit = "^2.20.0"
7273
ruff = "^0.6.0"

src/backend/routers/tool.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,11 @@ def list_tools(
5555
session, user_id
5656
)
5757
tool.auth_url = tool_auth_service.get_auth_url(user_id)
58-
tool.token = tool_auth_service.get_token(session, user_id)
58+
59+
# Return access token to client when required by frontend
60+
# e.g: to enable Google Drive picker in client
61+
if tool.should_return_token:
62+
tool.token = tool_auth_service.get_token(session, user_id)
5963
except Exception as e:
6064
logger.error(event=f"Error while fetching Tool Auth: {str(e)}")
6165
tool.is_auth_required = True

src/backend/schemas/tool.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class ToolDefinition(Tool):
2727
is_auth_required: bool = False # Per user
2828
auth_url: Optional[str] = "" # Per user
2929
token: Optional[str] = "" # Per user
30+
should_return_token: bool = False
3031

3132
implementation: Any = Field(exclude=True)
3233
auth_implementation: Any = Field(default=None, exclude=True)

src/backend/tests/unit/conftest.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
from typing import Any, Generator
33
from unittest.mock import patch
44

5+
import fakeredis
56
import pytest
67
from alembic.command import upgrade
78
from alembic.config import Config
89
from fastapi.testclient import TestClient
10+
from redis import Redis
911
from sqlalchemy import create_engine
1012
from sqlalchemy.orm import Session
1113

@@ -146,6 +148,18 @@ def override_get_session() -> Generator[Session, Any, None]:
146148
app.dependency_overrides = {}
147149

148150

151+
152+
@pytest.fixture(autouse=True)
153+
def mock_redis_client():
154+
"""
155+
A pytest fixture that globally replaces `Redis.from_url` with `fakeredis`.
156+
"""
157+
fake_redis = fakeredis.FakeStrictRedis(decode_responses=True)
158+
159+
# Patch Redis.from_url to always return the fake Redis instance
160+
with patch.object(Redis, 'from_url', return_value=fake_redis):
161+
yield fake_redis
162+
149163
@pytest.fixture
150164
def user(session: Session) -> User:
151165
return get_factory("User", session).create(id="1")

src/backend/tests/unit/routers/test_tool.py

Lines changed: 86 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,104 @@
1+
from unittest.mock import patch
2+
3+
import pytest
14
from fastapi.testclient import TestClient
25
from sqlalchemy.orm import Session
36

47
from backend.config.tools import Tool, get_available_tools
8+
from backend.database_models.database import DBSessionDep
9+
from backend.schemas.tool import ToolCategory, ToolDefinition
510
from backend.schemas.user import User
611
from backend.tests.unit.factories import get_factory
12+
from backend.tools.base import BaseTool
13+
14+
TOOL_DEFINITION_KEYS = [
15+
"name",
16+
"display_name",
17+
"parameter_definitions",
18+
"is_visible",
19+
"is_available",
20+
"should_return_token",
21+
"category",
22+
"description"
23+
]
724

25+
@pytest.fixture
26+
def mock_get_available_tools():
27+
with patch("backend.routers.tool.get_available_tools") as mock:
28+
yield mock
829

9-
def test_list_tools(session_client: TestClient, session: Session) -> None:
30+
def test_list_tools(session_client: TestClient) -> None:
1031
response = session_client.get("/v1/tools")
1132
assert response.status_code == 200
1233
available_tools = get_available_tools()
1334
for tool in response.json():
14-
assert tool["name"] in available_tools.keys()
15-
assert tool["kwargs"] is not None
16-
assert tool["is_visible"] is not None
17-
assert tool["is_available"] is not None
18-
assert tool["category"] is not None
19-
assert tool["description"] is not None
35+
tool_definition = available_tools.get(tool["name"])
36+
assert tool_definition is not None
37+
38+
for key in TOOL_DEFINITION_KEYS:
39+
assert tool[key] == getattr(tool_definition, key)
40+
41+
def test_list_authed_tool_should_return_token(session_client: TestClient, mock_get_available_tools) -> None:
42+
class MockGoogleDriveAuth():
43+
def is_auth_required(self, session: DBSessionDep, user_id: str) -> bool:
44+
return False
45+
46+
def get_auth_url(self, user_id: str) -> str:
47+
return ""
48+
49+
def get_token(self, session: DBSessionDep, user_id: str) -> str:
50+
return "mock"
51+
class MockGoogleDrive(BaseTool):
52+
ID = "google_drive"
53+
@classmethod
54+
def get_tool_definition(cls) -> ToolDefinition:
55+
return ToolDefinition(
56+
name=cls.ID,
57+
display_name="Google Drive",
58+
implementation=cls,
59+
parameter_definitions={
60+
"query": {
61+
"description": "Query to search Google Drive documents with.",
62+
"type": "str",
63+
"required": True,
64+
}
65+
},
66+
is_visible=True,
67+
is_available=True,
68+
auth_implementation=MockGoogleDriveAuth,
69+
should_return_token=True,
70+
error_message=cls.generate_error_message(),
71+
category=ToolCategory.DataLoader,
72+
description="Returns a list of relevant document snippets from the user's Google drive.",
73+
)
74+
75+
# Patch Google Drive tool
76+
mock_get_available_tools.return_value = {Tool.Google_Drive.value.ID: MockGoogleDrive.get_tool_definition()}
77+
78+
response = session_client.get("/v1/tools")
79+
assert response.status_code == 200
80+
81+
for tool in response.json():
82+
print(tool)
83+
if tool["should_return_token"]:
84+
assert tool["token"] == "mock"
85+
86+
def test_list_authed_tool_should_not_return_token(session_client: TestClient) -> None:
87+
response = session_client.get("/v1/tools")
88+
89+
assert response.status_code == 200
2090

91+
for tool in response.json():
92+
if not tool["should_return_token"]:
93+
assert tool["token"] == ""
2194

22-
def test_list_tools_error_message_none_if_available(client: TestClient) -> None:
23-
response = client.get("/v1/tools")
95+
def test_list_tools_error_message_none_if_available(session_client: TestClient) -> None:
96+
response = session_client.get("/v1/tools")
2497
assert response.status_code == 200
2598
for tool in response.json():
2699
if tool["is_available"]:
27100
assert tool["error_message"] is None
28101

29-
30102
def test_list_tools_with_agent(
31103
session_client: TestClient, session: Session, user: User
32104
) -> None:
@@ -42,18 +114,13 @@ def test_list_tools_with_agent(
42114
assert tool["name"] == Tool.Wiki_Retriever_LangChain.value.ID
43115

44116
# get tool that has the same name as the tool in the response
45-
tool_definition = get_available_tools()[tool["name"]]
46-
47-
assert tool["kwargs"] == tool_definition.kwargs
48-
assert tool["is_visible"] == tool_definition.is_visible
49-
assert tool["is_available"] == tool_definition.is_available
50-
assert tool["error_message"] == tool_definition.error_message
51-
assert tool["category"] == tool_definition.category
52-
assert tool["description"] == tool_definition.description
117+
tool_definition = get_available_tools().get(tool["name"])
53118

119+
for key in TOOL_DEFINITION_KEYS:
120+
assert tool[key] == getattr(tool_definition, key)
54121

55122
def test_list_tools_with_agent_that_doesnt_exist(
56-
session_client: TestClient, session: Session
123+
session_client: TestClient
57124
) -> None:
58125
response = session_client.get("/v1/tools", params={"agent_id": "fake_id"})
59126
assert response.status_code == 404

src/backend/tools/google_drive/tool.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def get_tool_definition(cls) -> ToolDefinition:
4949
is_visible=True,
5050
is_available=GoogleDrive.is_available(),
5151
auth_implementation=GoogleDriveAuth,
52+
should_return_token=True,
5253
error_message=cls.generate_error_message(),
5354
category=ToolCategory.DataLoader,
5455
description="Returns a list of relevant document snippets from the user's Google drive.",

0 commit comments

Comments
 (0)