1
+ from unittest .mock import patch
2
+
3
+ import pytest
1
4
from fastapi .testclient import TestClient
2
5
from sqlalchemy .orm import Session
3
6
4
7
from backend .config .tools import Tool , get_available_tools
8
+ from backend .database_models .database import DBSessionDep
9
+ from backend .schemas .tool import ToolCategory , ToolDefinition
5
10
from backend .schemas .user import User
6
11
from backend .tests .unit .factories import get_factory
12
+ from backend .tools .base import BaseTool
13
+
14
+ TOOL_DEFINITION_KEYS = [
15
+ "name" ,
16
+ "display_name" ,
17
+ "parameter_definitions" ,
18
+ "is_visible" ,
19
+ "is_available" ,
20
+ "should_return_token" ,
21
+ "category" ,
22
+ "description"
23
+ ]
7
24
25
+ @pytest .fixture
26
+ def mock_get_available_tools ():
27
+ with patch ("backend.routers.tool.get_available_tools" ) as mock :
28
+ yield mock
8
29
9
- def test_list_tools (session_client : TestClient , session : Session ) -> None :
30
+ def test_list_tools (session_client : TestClient ) -> None :
10
31
response = session_client .get ("/v1/tools" )
11
32
assert response .status_code == 200
12
33
available_tools = get_available_tools ()
13
34
for tool in response .json ():
14
- assert tool ["name" ] in available_tools .keys ()
15
- assert tool ["kwargs" ] is not None
16
- assert tool ["is_visible" ] is not None
17
- assert tool ["is_available" ] is not None
18
- assert tool ["category" ] is not None
19
- assert tool ["description" ] is not None
35
+ tool_definition = available_tools .get (tool ["name" ])
36
+ assert tool_definition is not None
37
+
38
+ for key in TOOL_DEFINITION_KEYS :
39
+ assert tool [key ] == getattr (tool_definition , key )
40
+
41
+ def test_list_authed_tool_should_return_token (session_client : TestClient , mock_get_available_tools ) -> None :
42
+ class MockGoogleDriveAuth ():
43
+ def is_auth_required (self , session : DBSessionDep , user_id : str ) -> bool :
44
+ return False
45
+
46
+ def get_auth_url (self , user_id : str ) -> str :
47
+ return ""
48
+
49
+ def get_token (self , session : DBSessionDep , user_id : str ) -> str :
50
+ return "mock"
51
+ class MockGoogleDrive (BaseTool ):
52
+ ID = "google_drive"
53
+ @classmethod
54
+ def get_tool_definition (cls ) -> ToolDefinition :
55
+ return ToolDefinition (
56
+ name = cls .ID ,
57
+ display_name = "Google Drive" ,
58
+ implementation = cls ,
59
+ parameter_definitions = {
60
+ "query" : {
61
+ "description" : "Query to search Google Drive documents with." ,
62
+ "type" : "str" ,
63
+ "required" : True ,
64
+ }
65
+ },
66
+ is_visible = True ,
67
+ is_available = True ,
68
+ auth_implementation = MockGoogleDriveAuth ,
69
+ should_return_token = True ,
70
+ error_message = cls .generate_error_message (),
71
+ category = ToolCategory .DataLoader ,
72
+ description = "Returns a list of relevant document snippets from the user's Google drive." ,
73
+ )
74
+
75
+ # Patch Google Drive tool
76
+ mock_get_available_tools .return_value = {Tool .Google_Drive .value .ID : MockGoogleDrive .get_tool_definition ()}
77
+
78
+ response = session_client .get ("/v1/tools" )
79
+ assert response .status_code == 200
80
+
81
+ for tool in response .json ():
82
+ print (tool )
83
+ if tool ["should_return_token" ]:
84
+ assert tool ["token" ] == "mock"
85
+
86
+ def test_list_authed_tool_should_not_return_token (session_client : TestClient ) -> None :
87
+ response = session_client .get ("/v1/tools" )
88
+
89
+ assert response .status_code == 200
20
90
91
+ for tool in response .json ():
92
+ if not tool ["should_return_token" ]:
93
+ assert tool ["token" ] == ""
21
94
22
- def test_list_tools_error_message_none_if_available (client : TestClient ) -> None :
23
- response = client .get ("/v1/tools" )
95
+ def test_list_tools_error_message_none_if_available (session_client : TestClient ) -> None :
96
+ response = session_client .get ("/v1/tools" )
24
97
assert response .status_code == 200
25
98
for tool in response .json ():
26
99
if tool ["is_available" ]:
27
100
assert tool ["error_message" ] is None
28
101
29
-
30
102
def test_list_tools_with_agent (
31
103
session_client : TestClient , session : Session , user : User
32
104
) -> None :
@@ -42,18 +114,13 @@ def test_list_tools_with_agent(
42
114
assert tool ["name" ] == Tool .Wiki_Retriever_LangChain .value .ID
43
115
44
116
# get tool that has the same name as the tool in the response
45
- tool_definition = get_available_tools ()[tool ["name" ]]
46
-
47
- assert tool ["kwargs" ] == tool_definition .kwargs
48
- assert tool ["is_visible" ] == tool_definition .is_visible
49
- assert tool ["is_available" ] == tool_definition .is_available
50
- assert tool ["error_message" ] == tool_definition .error_message
51
- assert tool ["category" ] == tool_definition .category
52
- assert tool ["description" ] == tool_definition .description
117
+ tool_definition = get_available_tools ().get (tool ["name" ])
53
118
119
+ for key in TOOL_DEFINITION_KEYS :
120
+ assert tool [key ] == getattr (tool_definition , key )
54
121
55
122
def test_list_tools_with_agent_that_doesnt_exist (
56
- session_client : TestClient , session : Session
123
+ session_client : TestClient
57
124
) -> None :
58
125
response = session_client .get ("/v1/tools" , params = {"agent_id" : "fake_id" })
59
126
assert response .status_code == 404
0 commit comments