|
9 | 9 | },
|
10 | 10 | {
|
11 | 11 | "cell_type": "code",
|
12 |
| - "execution_count": 213, |
| 12 | + "execution_count": null, |
| 13 | + "metadata": {}, |
| 14 | + "outputs": [], |
| 15 | + "source": [ |
| 16 | + "%reload_ext dotenv\n", |
| 17 | + "%dotenv" |
| 18 | + ] |
| 19 | + }, |
| 20 | + { |
| 21 | + "cell_type": "code", |
| 22 | + "execution_count": null, |
13 | 23 | "metadata": {},
|
14 | 24 | "outputs": [],
|
15 | 25 | "source": [
|
16 | 26 | "# OpenAI\n",
|
| 27 | + "import os\n", |
17 | 28 | "import openai\n",
|
18 | 29 | "from semantic_router.utils.logger import logger\n",
|
19 | 30 | "\n",
|
|
39 | 50 | },
|
40 | 51 | {
|
41 | 52 | "cell_type": "code",
|
42 |
| - "execution_count": 214, |
| 53 | + "execution_count": null, |
43 | 54 | "metadata": {},
|
44 | 55 | "outputs": [],
|
45 | 56 | "source": [
|
|
48 | 59 | "import requests\n",
|
49 | 60 | "\n",
|
50 | 61 | "# Docs https://huggingface.co/docs/transformers/main_classes/text_generation\n",
|
51 |
| - "HF_API_TOKEN = os.environ[\"HF_API_TOKEN\"]\n", |
| 62 | + "HF_API_TOKEN = os.getenv(\"HF_API_TOKEN\")\n", |
52 | 63 | "\n",
|
53 | 64 | "\n",
|
54 | 65 | "def llm_mistral(prompt: str) -> str:\n",
|
|
180 | 191 | },
|
181 | 192 | {
|
182 | 193 | "cell_type": "code",
|
183 |
| - "execution_count": 217, |
| 194 | + "execution_count": null, |
184 | 195 | "metadata": {},
|
185 | 196 | "outputs": [],
|
186 | 197 | "source": [
|
|
242 | 253 | "Set up the routing layer"
|
243 | 254 | ]
|
244 | 255 | },
|
| 256 | + { |
| 257 | + "cell_type": "code", |
| 258 | + "execution_count": null, |
| 259 | + "metadata": {}, |
| 260 | + "outputs": [], |
| 261 | + "source": [ |
| 262 | + "from semantic_router.schema import Route\n", |
| 263 | + "from semantic_router.encoders import CohereEncoder, OpenAIEncoder\n", |
| 264 | + "from semantic_router.layer import RouteLayer\n", |
| 265 | + "from semantic_router.utils.logger import logger\n", |
| 266 | + "\n", |
| 267 | + "\n", |
| 268 | + "def create_router(routes: list[dict]) -> RouteLayer:\n", |
| 269 | + " logger.info(\"Creating route layer...\")\n", |
| 270 | + " encoder = OpenAIEncoder" |
| 271 | + ] |
| 272 | + }, |
245 | 273 | {
|
246 | 274 | "cell_type": "code",
|
247 | 275 | "execution_count": null,
|
|
256 | 284 | "\n",
|
257 | 285 | "def create_router(routes: list[dict]) -> RouteLayer:\n",
|
258 | 286 | " logger.info(\"Creating route layer...\")\n",
|
259 |
| - " encoder = CohereEncoder()\n", |
| 287 | + " encoder = OpenAIEncoder()\n", |
260 | 288 | "\n",
|
261 | 289 | " route_list: list[Route] = []\n",
|
262 | 290 | " for route in routes:\n",
|
|
278 | 306 | },
|
279 | 307 | {
|
280 | 308 | "cell_type": "code",
|
281 |
| - "execution_count": 219, |
| 309 | + "execution_count": null, |
282 | 310 | "metadata": {},
|
283 | 311 | "outputs": [],
|
284 | 312 | "source": [
|
|
349 | 377 | },
|
350 | 378 | {
|
351 | 379 | "cell_type": "code",
|
352 |
| - "execution_count": 220, |
| 380 | + "execution_count": null, |
| 381 | + "metadata": {}, |
| 382 | + "outputs": [], |
| 383 | + "source": [ |
| 384 | + "def get_time(location: str) -> str:\n", |
| 385 | + " \"\"\"Useful to get the time in a specific location\"\"\"\n", |
| 386 | + " print(f\"Calling `get_time` function with location: {location}\")\n", |
| 387 | + " return \"get_time\"\n", |
| 388 | + "\n", |
| 389 | + "\n", |
| 390 | + "def get_news(category: str, country: str) -> str:\n", |
| 391 | + " \"\"\"Useful to get the news in a specific country\"\"\"\n", |
| 392 | + " print(\n", |
| 393 | + " f\"Calling `get_news` function with category: {category} and country: {country}\"\n", |
| 394 | + " )\n", |
| 395 | + " return \"get_news\"\n", |
| 396 | + "\n", |
| 397 | + "\n", |
| 398 | + "# Registering functions to the router\n", |
| 399 | + "route_get_time = generate_route(get_time)\n", |
| 400 | + "route_get_news = generate_route(get_news)\n", |
| 401 | + "\n", |
| 402 | + "routes = [route_get_time, route_get_news]\n", |
| 403 | + "router = create_router(routes)\n", |
| 404 | + "\n", |
| 405 | + "# Tools\n", |
| 406 | + "tools = [get_time, get_news]" |
| 407 | + ] |
| 408 | + }, |
| 409 | + { |
| 410 | + "cell_type": "markdown", |
353 | 411 | "metadata": {},
|
354 |
| - "outputs": [ |
355 |
| - { |
356 |
| - "name": "stderr", |
357 |
| - "output_type": "stream", |
358 |
| - "text": [ |
359 |
| - "\u001b[32m2023-12-15 11:41:54 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", |
360 |
| - "\u001b[32m2023-12-15 11:41:54 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", |
361 |
| - "\u001b[32m2023-12-15 11:41:55 INFO semantic_router.utils.logger AI message: \n", |
362 |
| - " {\n", |
363 |
| - " 'location': 'Stockholm'\n", |
364 |
| - " }\u001b[0m\n", |
365 |
| - "\u001b[32m2023-12-15 11:41:55 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Stockholm'}\u001b[0m\n" |
366 |
| - ] |
367 |
| - }, |
368 |
| - { |
369 |
| - "name": "stdout", |
370 |
| - "output_type": "stream", |
371 |
| - "text": [ |
372 |
| - "parameters: {'location': 'Stockholm'}\n", |
373 |
| - "Calling `get_time` function with location: Stockholm\n" |
374 |
| - ] |
375 |
| - }, |
376 |
| - { |
377 |
| - "name": "stderr", |
378 |
| - "output_type": "stream", |
379 |
| - "text": [ |
380 |
| - "\u001b[32m2023-12-15 11:41:55 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", |
381 |
| - "\u001b[32m2023-12-15 11:41:55 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", |
382 |
| - "\u001b[32m2023-12-15 11:41:56 INFO semantic_router.utils.logger AI message: \n", |
383 |
| - " {\n", |
384 |
| - " 'category': 'tech',\n", |
385 |
| - " 'country': 'Lithuania'\n", |
386 |
| - " }\u001b[0m\n", |
387 |
| - "\u001b[32m2023-12-15 11:41:56 INFO semantic_router.utils.logger Extracted parameters: {'category': 'tech', 'country': 'Lithuania'}\u001b[0m\n" |
388 |
| - ] |
389 |
| - }, |
390 |
| - { |
391 |
| - "name": "stdout", |
392 |
| - "output_type": "stream", |
393 |
| - "text": [ |
394 |
| - "parameters: {'category': 'tech', 'country': 'Lithuania'}\n", |
395 |
| - "Calling `get_news` function with category: tech and country: Lithuania\n" |
396 |
| - ] |
397 |
| - }, |
398 |
| - { |
399 |
| - "name": "stderr", |
400 |
| - "output_type": "stream", |
401 |
| - "text": [ |
402 |
| - "\u001b[33m2023-12-15 11:41:57 WARNING semantic_router.utils.logger No function found\u001b[0m\n", |
403 |
| - "\u001b[32m2023-12-15 11:41:57 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", |
404 |
| - "\u001b[32m2023-12-15 11:41:57 INFO semantic_router.utils.logger AI message: How can I help you today?\u001b[0m\n" |
405 |
| - ] |
406 |
| - }, |
407 |
| - { |
408 |
| - "data": { |
409 |
| - "text/plain": [ |
410 |
| - "' How can I help you today?'" |
411 |
| - ] |
412 |
| - }, |
413 |
| - "execution_count": 220, |
414 |
| - "metadata": {}, |
415 |
| - "output_type": "execute_result" |
416 |
| - } |
417 |
| - ], |
418 | 412 | "source": [
|
419 | 413 | "call(query=\"What is the time in Stockholm?\", functions=tools, router=router)\n",
|
420 | 414 | "call(query=\"What is the tech news in the Lithuania?\", functions=tools, router=router)\n",
|
|
438 | 432 | "name": "python",
|
439 | 433 | "nbconvert_exporter": "python",
|
440 | 434 | "pygments_lexer": "ipython3",
|
441 |
| - "version": "3.11.3" |
| 435 | + "version": "3.11.5" |
442 | 436 | }
|
443 | 437 | },
|
444 | 438 | "nbformat": 4,
|
|
0 commit comments