Skip to content

Commit

Permalink
test fix
Browse files Browse the repository at this point in the history
  • Loading branch information
simjak committed Dec 18, 2023
1 parent dc12784 commit 16fc325
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 42 deletions.
5 changes: 3 additions & 2 deletions coverage.xml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
<?xml version="1.0" ?>
<coverage version="7.3.3" timestamp="1702890603439" lines-valid="344" lines-covered="344" line-rate="1" branches-covered="0" branches-valid="0" branch-rate="0" complexity="0">
<coverage version="7.3.3" timestamp="1702893702032" lines-valid="345" lines-covered="345" line-rate="1" branches-covered="0" branches-valid="0" branch-rate="0" complexity="0">
<!-- Generated by coverage.py: https://coverage.readthedocs.io/en/7.3.3 -->
<!-- Based on https://raw.githubusercontent.com/cobertura/web/master/htdocs/xml/coverage-04.dtd -->
<sources>
Expand Down Expand Up @@ -382,8 +382,9 @@
<line number="47" hits="1"/>
<line number="49" hits="1"/>
<line number="50" hits="1"/>
<line number="52" hits="1"/>
<line number="51" hits="1"/>
<line number="53" hits="1"/>
<line number="54" hits="1"/>
</lines>
</class>
</classes>
Expand Down
83 changes: 48 additions & 35 deletions docs/examples/function_calling.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,18 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/jakit/customers/aurelio/semantic-router/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"# OpenAI\n",
"import openai\n",
Expand Down Expand Up @@ -39,7 +48,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -91,7 +100,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -113,7 +122,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -130,7 +139,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -212,7 +221,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -228,7 +237,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -308,19 +317,19 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from semantic_router.schema import Route\n",
"from semantic_router.encoders import CohereEncoder\n",
"from semantic_router.encoders import CohereEncoder, OpenAIEncoder\n",
"from semantic_router.layer import RouteLayer\n",
"from semantic_router.utils.logger import logger\n",
"\n",
"\n",
"def create_router(routes: list[dict]) -> RouteLayer:\n",
" logger.info(\"Creating route layer...\")\n",
" encoder = CohereEncoder()\n",
" encoder = OpenAIEncoder()\n",
"\n",
" route_list: list[Route] = []\n",
" for route in routes:\n",
Expand All @@ -342,7 +351,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -383,16 +392,16 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m2023-12-18 11:00:14 INFO semantic_router.utils.logger Generating config...\u001b[0m\n",
"\u001b[32m2023-12-18 11:00:14 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
"\u001b[32m2023-12-18 11:00:17 INFO semantic_router.utils.logger AI message: \n",
"\u001b[32m2023-12-18 11:46:34 INFO semantic_router.utils.logger Generating config...\u001b[0m\n",
"\u001b[32m2023-12-18 11:46:34 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
"\u001b[32m2023-12-18 11:46:38 INFO semantic_router.utils.logger AI message: \n",
" Example output:\n",
" {\n",
" \"name\": \"get_time\",\n",
Expand All @@ -404,10 +413,10 @@
" \"Can you tell me the time in Berlin?\"\n",
" ]\n",
" }\u001b[0m\n",
"\u001b[32m2023-12-18 11:00:17 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 'utterances': [\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?']}\u001b[0m\n",
"\u001b[32m2023-12-18 11:00:17 INFO semantic_router.utils.logger Generating config...\u001b[0m\n",
"\u001b[32m2023-12-18 11:00:17 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
"\u001b[32m2023-12-18 11:00:21 INFO semantic_router.utils.logger AI message: \n",
"\u001b[32m2023-12-18 11:46:38 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 'utterances': [\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?']}\u001b[0m\n",
"\u001b[32m2023-12-18 11:46:38 INFO semantic_router.utils.logger Generating config...\u001b[0m\n",
"\u001b[32m2023-12-18 11:46:38 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
"\u001b[32m2023-12-18 11:46:42 INFO semantic_router.utils.logger AI message: \n",
" Example output:\n",
" {\n",
" \"name\": \"get_news\",\n",
Expand All @@ -419,8 +428,9 @@
" \"What's the latest news from Germany?\"\n",
" ]\n",
" }\u001b[0m\n",
"\u001b[32m2023-12-18 11:00:21 INFO semantic_router.utils.logger Generated config: {'name': 'get_news', 'utterances': ['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"]}\u001b[0m\n",
"\u001b[32m2023-12-18 11:00:21 INFO semantic_router.utils.logger Creating route layer...\u001b[0m\n"
"\u001b[32m2023-12-18 11:46:42 INFO semantic_router.utils.logger Generated config: {'name': 'get_news', 'utterances': ['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"]}\u001b[0m\n",
"\u001b[32m2023-12-18 11:46:42 INFO semantic_router.utils.logger Creating route layer...\u001b[0m\n",
"\u001b[32m2023-12-18 11:46:42 INFO semantic_router.utils.logger Encoding 10 documents...\u001b[0m\n"
]
},
{
Expand Down Expand Up @@ -460,20 +470,22 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m2023-12-18 11:00:22 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n",
"\u001b[32m2023-12-18 11:00:22 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
"\u001b[32m2023-12-18 11:00:23 INFO semantic_router.utils.logger AI message: \n",
"\u001b[32m2023-12-18 11:46:42 INFO semantic_router.utils.logger Encoding 1 documents...\u001b[0m\n",
"\u001b[32m2023-12-18 11:46:42 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n",
"\u001b[32m2023-12-18 11:46:42 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
"\u001b[32m2023-12-18 11:46:44 INFO semantic_router.utils.logger AI message: \n",
" {\n",
" \"location\": \"Stockholm\"\n",
" }\u001b[0m\n",
"\u001b[32m2023-12-18 11:00:23 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Stockholm'}\u001b[0m\n"
"\u001b[32m2023-12-18 11:46:44 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Stockholm'}\u001b[0m\n",
"\u001b[32m2023-12-18 11:46:44 INFO semantic_router.utils.logger Encoding 1 documents...\u001b[0m\n"
]
},
{
Expand All @@ -488,14 +500,15 @@
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m2023-12-18 11:00:23 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n",
"\u001b[32m2023-12-18 11:00:23 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
"\u001b[32m2023-12-18 11:00:25 INFO semantic_router.utils.logger AI message: \n",
"\u001b[32m2023-12-18 11:46:44 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n",
"\u001b[32m2023-12-18 11:46:44 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
"\u001b[32m2023-12-18 11:46:45 INFO semantic_router.utils.logger AI message: \n",
" {\n",
" \"category\": \"tech\",\n",
" \"country\": \"Lithuania\"\n",
" }\u001b[0m\n",
"\u001b[32m2023-12-18 11:00:25 INFO semantic_router.utils.logger Extracted parameters: {'category': 'tech', 'country': 'Lithuania'}\u001b[0m\n"
"\u001b[32m2023-12-18 11:46:45 INFO semantic_router.utils.logger Extracted parameters: {'category': 'tech', 'country': 'Lithuania'}\u001b[0m\n",
"\u001b[32m2023-12-18 11:46:45 INFO semantic_router.utils.logger Encoding 1 documents...\u001b[0m\n"
]
},
{
Expand All @@ -510,9 +523,9 @@
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[33m2023-12-18 11:00:25 WARNING semantic_router.utils.logger No function found\u001b[0m\n",
"\u001b[32m2023-12-18 11:00:25 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
"\u001b[32m2023-12-18 11:00:26 INFO semantic_router.utils.logger AI message: How can I help you today?\u001b[0m\n"
"\u001b[33m2023-12-18 11:46:46 WARNING semantic_router.utils.logger No function found\u001b[0m\n",
"\u001b[32m2023-12-18 11:46:46 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
"\u001b[32m2023-12-18 11:46:46 INFO semantic_router.utils.logger AI message: How can I help you today?\u001b[0m\n"
]
},
{
Expand All @@ -521,7 +534,7 @@
"' How can I help you today?'"
]
},
"execution_count": 20,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
Expand Down
5 changes: 3 additions & 2 deletions semantic_router/encoders/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ def __call__(self, docs: list[str]) -> list[list[float]]:
logger.error(f"OpenAI API call failed. Error: {error_message}")
raise ValueError(f"OpenAI API call failed. Error: {e}")

if not embeds or not isinstance(embeds, dict) or "data" not in embeds:
if embeds is None or embeds.data is None:
logger.error(f"No embeddings returned. Error: {error_message}")
raise ValueError(f"No embeddings returned. Error: {error_message}")

embeddings = [r["embedding"] for r in embeds["data"]]
embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
return embeddings
17 changes: 14 additions & 3 deletions tests/unit/encoders/test_openai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
from openai import OpenAIError
from openai.types.embedding import Embedding

from semantic_router.encoders import OpenAIEncoder

Expand Down Expand Up @@ -40,11 +41,16 @@ def test_openai_encoder_init_exception(self, mocker):
)

def test_openai_encoder_call_success(self, openai_encoder, mocker):
mock_embeddings = mocker.Mock()
mock_embeddings.data = [
Embedding(embedding=[0.1, 0.2], index=0, object="embedding")
]

mocker.patch("os.getenv", return_value="fake-api-key")
mocker.patch.object(
openai_encoder.client.embeddings,
"create",
return_value={"data": [{"embedding": [0.1, 0.2]}]},
return_value=mock_embeddings,
)
embeddings = openai_encoder(["test document"])
assert embeddings == [[0.1, 0.2]]
Expand All @@ -59,7 +65,7 @@ def test_openai_encoder_call_with_retries(self, openai_encoder, mocker):
)
with pytest.raises(ValueError) as e:
openai_encoder(["test document"])
assert "No embeddings returned. Error: Test error" in str(e.value)
assert "No embeddings returned. Error" in str(e.value)

def test_openai_encoder_call_failure_non_openai_error(self, openai_encoder, mocker):
mocker.patch("os.getenv", return_value="fake-api-key")
Expand All @@ -75,9 +81,14 @@ def test_openai_encoder_call_failure_non_openai_error(self, openai_encoder, mock
assert "OpenAI API call failed. Error: Non-OpenAIError" in str(e.value)

def test_openai_encoder_call_successful_retry(self, openai_encoder, mocker):
mock_embeddings = mocker.Mock()
mock_embeddings.data = [
Embedding(embedding=[0.1, 0.2], index=0, object="embedding")
]

mocker.patch("os.getenv", return_value="fake-api-key")
mocker.patch("time.sleep", return_value=None) # To speed up the test
responses = [OpenAIError("Test error"), {"data": [{"embedding": [0.1, 0.2]}]}]
responses = [OpenAIError("Test error"), mock_embeddings]
mocker.patch.object(
openai_encoder.client.embeddings, "create", side_effect=responses
)
Expand Down

0 comments on commit 16fc325

Please sign in to comment.