1
- import os
2
-
3
1
import pytest
4
2
from fastapi .testclient import TestClient
5
3
from sqlalchemy .orm import Session
10
8
from backend .schemas .user import User
11
9
from backend .tests .unit .factories import get_factory
12
10
11
+ _IS_GOOGLE_CLOUD_API_KEY_SET = bool (Settings ().get ('google_cloud.api_key' ))
12
+
13
13
14
14
def test_search_conversations (
15
15
session_client : TestClient ,
16
16
session : Session ,
17
17
user : User ,
18
+ mock_available_model_deployments ,
18
19
) -> None :
19
20
conversation = get_factory ("Conversation" , session ).create (
20
21
title = "test title" , user_id = user .id
@@ -24,23 +25,18 @@ def test_search_conversations(
24
25
headers = {"User-Id" : user .id },
25
26
params = {"query" : "test" },
26
27
)
27
- print ("here" )
28
- print (response .json )
29
28
results = response .json ()
30
29
31
30
assert response .status_code == 200
32
31
assert len (results ) == 1
33
32
assert results [0 ]["id" ] == conversation .id
34
33
35
34
36
- @pytest .mark .skipif (
37
- os .environ .get ("COHERE_API_KEY" ) is None ,
38
- reason = "Cohere API key not set, skipping test" ,
39
- )
40
35
def test_search_conversations_with_reranking (
41
36
session_client : TestClient ,
42
37
session : Session ,
43
38
user : User ,
39
+ mock_available_model_deployments ,
44
40
) -> None :
45
41
_ = get_factory ("Conversation" , session ).create (
46
42
title = "Hello, how are you?" , text_messages = [], user_id = user .id
@@ -83,19 +79,16 @@ def test_search_conversations_no_conversations(
83
79
assert response .json () == []
84
80
85
81
86
- # MISC
87
-
88
-
89
- @pytest .mark .skip (reason = "Restore this test when we get access to run models on Huggingface" )
90
82
def test_generate_title (
91
83
session_client : TestClient ,
92
84
session : Session ,
93
85
user : User ,
86
+ mock_available_model_deployments ,
94
87
) -> None :
95
- conversation = get_factory ("Conversation" , session ).create (user_id = user .id )
88
+ conversation_initial = get_factory ("Conversation" , session ).create (user_id = user .id )
96
89
response = session_client .post (
97
- f"/v1/conversations/{ conversation .id } /generate-title" ,
98
- headers = {"User-Id" : conversation .user_id },
90
+ f"/v1/conversations/{ conversation_initial .id } /generate-title" ,
91
+ headers = {"User-Id" : conversation_initial .user_id },
99
92
)
100
93
response_json = response .json ()
101
94
@@ -105,7 +98,7 @@ def test_generate_title(
105
98
# Check if the conversation was updated
106
99
conversation = (
107
100
session .query (Conversation )
108
- .filter_by (id = conversation .id , user_id = conversation .user_id )
101
+ .filter_by (id = conversation_initial .id , user_id = conversation_initial .user_id )
109
102
.first ()
110
103
)
111
104
assert conversation is not None
@@ -165,10 +158,7 @@ def test_generate_title_error_invalid_model(
165
158
# SYNTHESIZE
166
159
167
160
168
- is_google_cloud_api_key_set = bool (Settings ().get ('google_cloud.api_key' ))
169
-
170
-
171
- @pytest .mark .skipif (not is_google_cloud_api_key_set , reason = "Google Cloud API key not set, skipping test" )
161
+ @pytest .mark .skipif (not _IS_GOOGLE_CLOUD_API_KEY_SET , reason = "Google Cloud API key not set, skipping test" )
172
162
def test_synthesize_english_message (
173
163
session_client : TestClient ,
174
164
session : Session ,
@@ -186,7 +176,7 @@ def test_synthesize_english_message(
186
176
assert response .headers ["Content-Type" ] == "audio/mp3"
187
177
188
178
189
- @pytest .mark .skipif (not is_google_cloud_api_key_set , reason = "Google Cloud API key not set, skipping test" )
179
+ @pytest .mark .skipif (not _IS_GOOGLE_CLOUD_API_KEY_SET , reason = "Google Cloud API key not set, skipping test" )
190
180
def test_synthesize_non_english_message (
191
181
session_client : TestClient ,
192
182
session : Session ,
0 commit comments