Skip to content

Commit bfbf10b

Browse files
authored
Merge pull request #33 from aurelio-labs/luca/fix-on-embeddings-check
Fix for embeddings
2 parents 45ce599 + 8011844 commit bfbf10b

File tree

5 files changed

+70
-74
lines changed

5 files changed

+70
-74
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@ mac.env
1717
.coverage
1818
.coverage.*
1919
.pytest_cache
20+
test.py

docs/examples/function_calling.ipynb

Lines changed: 66 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,22 @@
99
},
1010
{
1111
"cell_type": "code",
12-
"execution_count": 213,
12+
"execution_count": null,
13+
"metadata": {},
14+
"outputs": [],
15+
"source": [
16+
"%reload_ext dotenv\n",
17+
"%dotenv"
18+
]
19+
},
20+
{
21+
"cell_type": "code",
22+
"execution_count": null,
1323
"metadata": {},
1424
"outputs": [],
1525
"source": [
1626
"# OpenAI\n",
27+
"import os\n",
1728
"import openai\n",
1829
"from semantic_router.utils.logger import logger\n",
1930
"\n",
@@ -39,7 +50,7 @@
3950
},
4051
{
4152
"cell_type": "code",
42-
"execution_count": 214,
53+
"execution_count": null,
4354
"metadata": {},
4455
"outputs": [],
4556
"source": [
@@ -48,7 +59,7 @@
4859
"import requests\n",
4960
"\n",
5061
"# Docs https://huggingface.co/docs/transformers/main_classes/text_generation\n",
51-
"HF_API_TOKEN = os.environ[\"HF_API_TOKEN\"]\n",
62+
"HF_API_TOKEN = os.getenv(\"HF_API_TOKEN\")\n",
5263
"\n",
5364
"\n",
5465
"def llm_mistral(prompt: str) -> str:\n",
@@ -180,7 +191,7 @@
180191
},
181192
{
182193
"cell_type": "code",
183-
"execution_count": 217,
194+
"execution_count": null,
184195
"metadata": {},
185196
"outputs": [],
186197
"source": [
@@ -242,6 +253,23 @@
242253
"Set up the routing layer"
243254
]
244255
},
256+
{
257+
"cell_type": "code",
258+
"execution_count": null,
259+
"metadata": {},
260+
"outputs": [],
261+
"source": [
262+
"from semantic_router.schema import Route\n",
263+
"from semantic_router.encoders import CohereEncoder, OpenAIEncoder\n",
264+
"from semantic_router.layer import RouteLayer\n",
265+
"from semantic_router.utils.logger import logger\n",
266+
"\n",
267+
"\n",
268+
"def create_router(routes: list[dict]) -> RouteLayer:\n",
269+
" logger.info(\"Creating route layer...\")\n",
270+
" encoder = OpenAIEncoder"
271+
]
272+
},
245273
{
246274
"cell_type": "code",
247275
"execution_count": null,
@@ -256,7 +284,7 @@
256284
"\n",
257285
"def create_router(routes: list[dict]) -> RouteLayer:\n",
258286
" logger.info(\"Creating route layer...\")\n",
259-
" encoder = CohereEncoder()\n",
287+
" encoder = OpenAIEncoder()\n",
260288
"\n",
261289
" route_list: list[Route] = []\n",
262290
" for route in routes:\n",
@@ -278,7 +306,7 @@
278306
},
279307
{
280308
"cell_type": "code",
281-
"execution_count": 219,
309+
"execution_count": null,
282310
"metadata": {},
283311
"outputs": [],
284312
"source": [
@@ -349,72 +377,38 @@
349377
},
350378
{
351379
"cell_type": "code",
352-
"execution_count": 220,
380+
"execution_count": null,
381+
"metadata": {},
382+
"outputs": [],
383+
"source": [
384+
"def get_time(location: str) -> str:\n",
385+
" \"\"\"Useful to get the time in a specific location\"\"\"\n",
386+
" print(f\"Calling `get_time` function with location: {location}\")\n",
387+
" return \"get_time\"\n",
388+
"\n",
389+
"\n",
390+
"def get_news(category: str, country: str) -> str:\n",
391+
" \"\"\"Useful to get the news in a specific country\"\"\"\n",
392+
" print(\n",
393+
" f\"Calling `get_news` function with category: {category} and country: {country}\"\n",
394+
" )\n",
395+
" return \"get_news\"\n",
396+
"\n",
397+
"\n",
398+
"# Registering functions to the router\n",
399+
"route_get_time = generate_route(get_time)\n",
400+
"route_get_news = generate_route(get_news)\n",
401+
"\n",
402+
"routes = [route_get_time, route_get_news]\n",
403+
"router = create_router(routes)\n",
404+
"\n",
405+
"# Tools\n",
406+
"tools = [get_time, get_news]"
407+
]
408+
},
409+
{
410+
"cell_type": "markdown",
353411
"metadata": {},
354-
"outputs": [
355-
{
356-
"name": "stderr",
357-
"output_type": "stream",
358-
"text": [
359-
"\u001b[32m2023-12-15 11:41:54 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n",
360-
"\u001b[32m2023-12-15 11:41:54 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
361-
"\u001b[32m2023-12-15 11:41:55 INFO semantic_router.utils.logger AI message: \n",
362-
" {\n",
363-
" 'location': 'Stockholm'\n",
364-
" }\u001b[0m\n",
365-
"\u001b[32m2023-12-15 11:41:55 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Stockholm'}\u001b[0m\n"
366-
]
367-
},
368-
{
369-
"name": "stdout",
370-
"output_type": "stream",
371-
"text": [
372-
"parameters: {'location': 'Stockholm'}\n",
373-
"Calling `get_time` function with location: Stockholm\n"
374-
]
375-
},
376-
{
377-
"name": "stderr",
378-
"output_type": "stream",
379-
"text": [
380-
"\u001b[32m2023-12-15 11:41:55 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n",
381-
"\u001b[32m2023-12-15 11:41:55 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
382-
"\u001b[32m2023-12-15 11:41:56 INFO semantic_router.utils.logger AI message: \n",
383-
" {\n",
384-
" 'category': 'tech',\n",
385-
" 'country': 'Lithuania'\n",
386-
" }\u001b[0m\n",
387-
"\u001b[32m2023-12-15 11:41:56 INFO semantic_router.utils.logger Extracted parameters: {'category': 'tech', 'country': 'Lithuania'}\u001b[0m\n"
388-
]
389-
},
390-
{
391-
"name": "stdout",
392-
"output_type": "stream",
393-
"text": [
394-
"parameters: {'category': 'tech', 'country': 'Lithuania'}\n",
395-
"Calling `get_news` function with category: tech and country: Lithuania\n"
396-
]
397-
},
398-
{
399-
"name": "stderr",
400-
"output_type": "stream",
401-
"text": [
402-
"\u001b[33m2023-12-15 11:41:57 WARNING semantic_router.utils.logger No function found\u001b[0m\n",
403-
"\u001b[32m2023-12-15 11:41:57 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
404-
"\u001b[32m2023-12-15 11:41:57 INFO semantic_router.utils.logger AI message: How can I help you today?\u001b[0m\n"
405-
]
406-
},
407-
{
408-
"data": {
409-
"text/plain": [
410-
"' How can I help you today?'"
411-
]
412-
},
413-
"execution_count": 220,
414-
"metadata": {},
415-
"output_type": "execute_result"
416-
}
417-
],
418412
"source": [
419413
"call(query=\"What is the time in Stockholm?\", functions=tools, router=router)\n",
420414
"call(query=\"What is the tech news in the Lithuania?\", functions=tools, router=router)\n",
@@ -438,7 +432,7 @@
438432
"name": "python",
439433
"nbconvert_exporter": "python",
440434
"pygments_lexer": "ipython3",
441-
"version": "3.11.3"
435+
"version": "3.11.5"
442436
}
443437
},
444438
"nbformat": 4,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "semantic-router"
3-
version = "0.0.9"
3+
version = "0.0.10"
44
description = "Super fast semantic router for AI decision making"
55
authors = [
66
"James Briggs <james@aurelio.ai>",

semantic_router/encoders/openai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __call__(self, docs: list[str]) -> list[list[float]]:
3636
try:
3737
logger.info(f"Encoding {len(docs)} documents...")
3838
embeds = self.client.embeddings.create(input=docs, model=self.name)
39-
if isinstance(embeds, dict) and "data" in embeds:
39+
if "data" in embeds:
4040
break
4141
except OpenAIError as e:
4242
sleep(2**j)

tests/unit/encoders/test_openai.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def test_openai_encoder_call_failure_non_openai_error(self, openai_encoder, mock
7171
)
7272
with pytest.raises(ValueError) as e:
7373
openai_encoder(["test document"])
74+
7475
assert "OpenAI API call failed. Error: Non-OpenAIError" in str(e.value)
7576

7677
def test_openai_encoder_call_successful_retry(self, openai_encoder, mocker):

0 commit comments

Comments
 (0)