Skip to content

Commit 4bc40e1

Browse files
committed
feat(backend): remove dependency on Cohere API key
1 parent 0a42a0d commit 4bc40e1

File tree

2 files changed

+41
-32
lines changed

2 files changed

+41
-32
lines changed

src/backend/tests/integration/routers/test_conversation.py

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import os
2-
31
import pytest
42
from fastapi.testclient import TestClient
53
from sqlalchemy.orm import Session
@@ -10,11 +8,14 @@
108
from backend.schemas.user import User
119
from backend.tests.unit.factories import get_factory
1210

11+
_IS_GOOGLE_CLOUD_API_KEY_SET = bool(Settings().get('google_cloud.api_key'))
12+
1313

1414
def test_search_conversations(
1515
session_client: TestClient,
1616
session: Session,
1717
user: User,
18+
mock_available_model_deployments,
1819
) -> None:
1920
conversation = get_factory("Conversation", session).create(
2021
title="test title", user_id=user.id
@@ -24,23 +25,18 @@ def test_search_conversations(
2425
headers={"User-Id": user.id},
2526
params={"query": "test"},
2627
)
27-
print("here")
28-
print(response.json)
2928
results = response.json()
3029

3130
assert response.status_code == 200
3231
assert len(results) == 1
3332
assert results[0]["id"] == conversation.id
3433

3534

36-
@pytest.mark.skipif(
37-
os.environ.get("COHERE_API_KEY") is None,
38-
reason="Cohere API key not set, skipping test",
39-
)
4035
def test_search_conversations_with_reranking(
4136
session_client: TestClient,
4237
session: Session,
4338
user: User,
39+
mock_available_model_deployments,
4440
) -> None:
4541
_ = get_factory("Conversation", session).create(
4642
title="Hello, how are you?", text_messages=[], user_id=user.id
@@ -83,19 +79,16 @@ def test_search_conversations_no_conversations(
8379
assert response.json() == []
8480

8581

86-
# MISC
87-
88-
89-
@pytest.mark.skip(reason="Restore this test when we get access to run models on Huggingface")
9082
def test_generate_title(
9183
session_client: TestClient,
9284
session: Session,
9385
user: User,
86+
mock_available_model_deployments,
9487
) -> None:
95-
conversation = get_factory("Conversation", session).create(user_id=user.id)
88+
conversation_initial = get_factory("Conversation", session).create(user_id=user.id)
9689
response = session_client.post(
97-
f"/v1/conversations/{conversation.id}/generate-title",
98-
headers={"User-Id": conversation.user_id},
90+
f"/v1/conversations/{conversation_initial.id}/generate-title",
91+
headers={"User-Id": conversation_initial.user_id},
9992
)
10093
response_json = response.json()
10194

@@ -105,7 +98,7 @@ def test_generate_title(
10598
# Check if the conversation was updated
10699
conversation = (
107100
session.query(Conversation)
108-
.filter_by(id=conversation.id, user_id=conversation.user_id)
101+
.filter_by(id=conversation_initial.id, user_id=conversation_initial.user_id)
109102
.first()
110103
)
111104
assert conversation is not None
@@ -165,10 +158,7 @@ def test_generate_title_error_invalid_model(
165158
# SYNTHESIZE
166159

167160

168-
is_google_cloud_api_key_set = bool(Settings().get('google_cloud.api_key'))
169-
170-
171-
@pytest.mark.skipif(not is_google_cloud_api_key_set, reason="Google Cloud API key not set, skipping test")
161+
@pytest.mark.skipif(not _IS_GOOGLE_CLOUD_API_KEY_SET, reason="Google Cloud API key not set, skipping test")
172162
def test_synthesize_english_message(
173163
session_client: TestClient,
174164
session: Session,
@@ -186,7 +176,7 @@ def test_synthesize_english_message(
186176
assert response.headers["Content-Type"] == "audio/mp3"
187177

188178

189-
@pytest.mark.skipif(not is_google_cloud_api_key_set, reason="Google Cloud API key not set, skipping test")
179+
@pytest.mark.skipif(not _IS_GOOGLE_CLOUD_API_KEY_SET, reason="Google Cloud API key not set, skipping test")
190180
def test_synthesize_non_english_message(
191181
session_client: TestClient,
192182
session: Session,

src/backend/tests/unit/model_deployments/mock_deployments/mock_cohere_platform.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,37 @@
1-
from typing import Any, Dict, Generator, List
1+
import random
2+
from typing import Any, Generator
23

34
from cohere.types import StreamedChatResponse
45

56
from backend.chat.enums import StreamEvent
67
from backend.model_deployments.base import BaseDeployment
78
from backend.schemas.cohere_chat import CohereChatRequest
89
from backend.schemas.context import Context
10+
from backend.services.conversation import SEARCH_RELEVANCE_THRESHOLD
911

1012

1113
class MockCohereDeployment(BaseDeployment):
1214
"""Mocked Cohere Platform Deployment."""
1315

1416
DEFAULT_MODELS = ["command", "command-r"]
1517

18+
def __init__(self, **kwargs: Any):
19+
pass
20+
1621
@property
1722
def rerank_enabled(self) -> bool:
1823
return True
1924

2025
@classmethod
21-
def list_models(cls) -> List[str]:
26+
def list_models(cls) -> list[str]:
2227
return cls.DEFAULT_MODELS
2328

24-
@classmethod
25-
def is_available(cls) -> bool:
29+
@staticmethod
30+
def is_available() -> bool:
2631
return True
2732

28-
def invoke_chat(
29-
self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any
33+
async def invoke_chat(
34+
self, chat_request: CohereChatRequest, **kwargs: Any
3035
) -> Generator[StreamedChatResponse, None, None]:
3136
event = {
3237
"text": "Hi! Hello there! How's it going?",
@@ -51,7 +56,7 @@ def invoke_chat(
5156
}
5257
yield event
5358

54-
def invoke_chat_stream(
59+
async def invoke_chat_stream(
5560
self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any
5661
) -> Generator[StreamedChatResponse, None, None]:
5762
events = [
@@ -79,8 +84,22 @@ def invoke_chat_stream(
7984
for event in events:
8085
yield event
8186

82-
def invoke_rerank(
83-
self, query: str, documents: List[Dict[str, Any]], ctx: Context, **kwargs: Any
87+
async def invoke_rerank(
88+
self, query: str, documents: list[str], ctx: Context, **kwargs: Any
8489
) -> Any:
85-
# TODO: Add
86-
pass
90+
results = []
91+
for idx, doc in enumerate(documents):
92+
if query in doc:
93+
results.append({
94+
"index": idx,
95+
"relevance_score": random.uniform(SEARCH_RELEVANCE_THRESHOLD, 1),
96+
})
97+
event = {
98+
"id": "eae2b023-bf49-4139-bf15-9825022762f4",
99+
"results": results,
100+
"meta": {
101+
"api_version":{"version":"1"},
102+
"billed_units":{"search_units":1}
103+
}
104+
}
105+
return event

0 commit comments

Comments
 (0)