diff --git a/fern/pages/cookbooks/sql-agent-cohere-langchain.mdx b/fern/pages/cookbooks/sql-agent-cohere-langchain.mdx
index d1d13e42f..7d54879ea 100644
--- a/fern/pages/cookbooks/sql-agent-cohere-langchain.mdx
+++ b/fern/pages/cookbooks/sql-agent-cohere-langchain.mdx
@@ -6,13 +6,7 @@ description: "This page contains a tutorial on how to build a SQL agent with Coh
image: "../../assets/images/f1cc130-cohere_meta_image.jpg"
keywords: "Cohere, automatic SQL generation, code generation, AI agents"
---
-
-import { AuthorsContainer } from "../../components/authors-container";
-import { CookbookHeader } from "../../components/cookbook-header";
-
-
-
-*This notebook was created in collaboration with [i-5O](https://i-5o.ai/)*
+# SQL Agent Demo with Cohere and LangChain
This tutorial demonstrates how to create a SQL agent using Cohere and LangChain. The agent can translate natural language queries coming from users into SQL, and execute them against a database. This powerful combination allows for intuitive interaction with databases without requiring direct SQL knowledge.
@@ -22,6 +16,7 @@ Key topics covered:
3. Configuring the LangChain SQL Toolkit
4. Creating a custom prompt template with few-shot examples
5. Building and running the SQL agent
+6. Adding memory to the agent to keep track of historical messages
By the end of this tutorial, you'll have a functional SQL agent that can answer questions about your data using natural language.
@@ -41,56 +36,52 @@ The database contains two tables:
- [Create the agent](#toc1_5_)
- [Run the agent](#toc1_6_)
-
## [Import the required libraries](#toc0_)
First, let's import the necessary libraries for creating a SQL agent using Cohere and LangChain. These libraries enable natural language interaction with databases and provide tools for building AI-powered agents.
-```python PYTHON
+
+```python
import os
-os.environ["COHERE_API_KEY"] = ""
+os.environ["COHERE_API_KEY"] = ""
```
-```python PYTHON
-! pip install faiss-gpu -qq
+```python
+! pip install faiss-cpu -qq
```
-```python PYTHON
-! pip install langchain-core langchain-cohere langchain-community
+```python
+! pip install langchain-core langchain-cohere langchain-community -qq
```
-```python PYTHON
-from langchain.agents import AgentExecutor
+
+```python
from langchain_cohere import create_sql_agent
-from langchain_core.prompts import ChatPromptTemplate
from langchain_cohere.chat_models import ChatCohere
-from langchain_community.utilities.sql_database import SQLDatabase
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.vectorstores import FAISS
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_cohere import CohereEmbeddings
-from datetime import datetime, timedelta
-import os
-import json
+from datetime import datetime
```
-
## [Load the database](#toc0_)
Next, we load the database for our manufacturing data.
-We create an in-memory SQLite database using SQL scripts for the `product_tracking` and `status` tables. You can get the [SQL tables here](https://github.com/cohere-ai/cohere-developer-experience/tree/main/notebooks/agents/i-5O-sql-agent).
+### Download the sql files from the link below to create the database.
+
+We create an in-memory SQLite database using SQL scripts for the `product_tracking` and `status` tables. You can get the [SQL tables here](https://github.com/cohere-ai/notebooks/tree/main/notebooks/agents/i-5O-sql-agent).
We then create a SQLDatabase instance, which will be used by our LangChain tools and agents to interact with the data.
-```python PYTHON
+```python
import sqlite3
-import os
from langchain_community.utilities.sql_database import SQLDatabase
from sqlalchemy import create_engine
@@ -123,34 +114,24 @@ db = SQLDatabase(engine)
```
-```python PYTHON
+```python
# Test the connection
db.run("SELECT * FROM status LIMIT 5;")
```
-```
-"[('2024-05-09 19:28:00', 'Canada/Toronto', '2024-05-09', '19', '28', 'stn3', 'downtime'), ('2024-04-21 06:57:00', 'Canada/Toronto', '2024-04-21', '6', '57', 'stn3', 'productive'), ('2024-04-11 23:52:00', 'Canada/Toronto', '2024-04-11', '23', '52', 'stn4', 'productive'), ('2024-04-03 21:52:00', 'Canada/Toronto', '2024-04-03', '21', '52', 'stn2', 'downtime'), ('2024-04-30 05:01:00', 'Canada/Toronto', '2024-04-30', '5', '1', 'stn4', 'productive')]"
-```
-
-
-```python PYTHON
+```python
# Test the connection
db.run("SELECT * FROM product_tracking LIMIT 5;")
```
-```
-"[('2024-05-27 17:22:00', '2024-05-27 17:57:00', 'Canada/Toronto', '2024-05-27', '17', 'stn2', 'wip', '187', '35'), ('2024-04-26 15:56:00', '2024-04-26 17:56:00', 'Canada/Toronto', '2024-04-26', '15', 'stn4', 'wip', '299', '120'), ('2024-04-12 04:36:00', '2024-04-12 05:12:00', 'Canada/Toronto', '2024-04-12', '4', 'stn3', 'wip', '60', '36'), ('2024-04-19 15:15:00', '2024-04-19 15:22:00', 'Canada/Toronto', '2024-04-19', '15', 'stn4', 'wait', '227', '7'), ('2024-04-24 19:10:00', '2024-04-24 21:07:00', 'Canada/Toronto', '2024-04-24', '19', 'stn4', 'wait', '169', '117')]"
-```
-
-
## [Setup the LangChain SQL Toolkit](#toc0_)
-Next, we initialize the LangChain SQL Toolkit and sets up the language model to use Cohere's model. This prepares the necessary components for querying the SQL database using natural language.
+Next, we initialize the LangChain SQL Toolkit and initialize the language model to use Cohere's LLM. This prepares the necessary components for querying the SQL database using natural language.
-```python PYTHON
+```python
## Define model to use
import os
@@ -168,19 +149,84 @@ tools = toolkit.get_tools()
print('**List of pre-defined Langchain Tools**')
print([tool.name for tool in tools])
```
+
+## [Create a prompt template](#toc0_)
+
+Next, we create a prompt template. In this section, we will introduce a simple system message, and then also show how we can improve the prompt by introducing few shot prompting examples in the later sections.
+The system message is used to communicate instructions or provide context to the model at the beginning of a conversation.
+
+In this case, we provide the model with context on what SQL dialect it should use, how many samples to query among other instructions.
+
+
+```python
+from langchain_core.prompts import (
+ PromptTemplate,
+ ChatPromptTemplate,
+ SystemMessagePromptTemplate,
+ MessagesPlaceholder
+)
+
+system_message = """You are an agent designed to interact with a SQL database.
+You are an expert at answering questions about manufacturing data.
+Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
+Always start with checking the schema of the available tables.
+Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
+You can order the results by a relevant column to return the most interesting examples in the database.
+Never query for all the columns from a specific table, only ask for the relevant columns given the question.
+You have access to tools for interacting with the database.
+Only use the given tools. Only use the information returned by the tools to construct your final answer.
+You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
+
+DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
+
+The current date is {date}.
+
+For questions regarding productive time, downtime, productive or productivity, use minutes as units.
+
+For questions regarding productive time, downtime, productive or productivity use the status table.
+
+For questions regarding processing time and average processing time, use minutes as units.
+
+For questions regarding bottlenecks, processing time and average processing time use the product_tracking table.
+
+If the question does not seem related to the database, just return "I don't know" as the answer."""
+
+system_prompt = PromptTemplate.from_template(system_message)
+
+```
+
+
+```python
+full_prompt = ChatPromptTemplate.from_messages(
+ [
+ SystemMessagePromptTemplate(prompt=system_prompt),
+ MessagesPlaceholder(variable_name='chat_history', optional=True),
+ ("human", "{input}"),
+ MessagesPlaceholder("agent_scratchpad"),
+ ]
+)
```
-**List of pre-defined Langchain Tools**
-['sql_db_query', 'sql_db_schema', 'sql_db_list_tables', 'sql_db_query_checker']
+
+
+```python
+prompt_val = full_prompt.invoke({
+ "input": "What was the productive time for all stations today?",
+ "top_k": 5,
+ "dialect": "SQLite",
+ "date":datetime.now(),
+ "agent_scratchpad": [],
+ })
+print(prompt_val.to_string())
```
-## [Create a prompt template with few shot examples](#toc0_)
+## [Create a few-shot prompt template](#toc0_)
-Next, we create a prompt template with few-shot examples. Few-shot examples are used to provide the model with context and improve its performance on specific tasks. In this case, we'll prepare examples of natural language queries and their corresponding SQL queries to help the model generate accurate SQL statements for our database.
+In the above step, we've created a simple system prompt. Now, let us see how we can create a better few shot prompt template in this section. Few-shot examples are used to provide the model with context and improve its performance on specific tasks. In this case, we'll prepare examples of natural language queries and their corresponding SQL queries to help the model generate accurate SQL statements for our database.
In this example, we use `SemanticSimilarityExampleSelector` to select the top k examples that are most similar to an input query out of all the examples available.
-```python PYTHON
+```python
examples = [
{
"input": "What was the average processing time for all stations on April 3rd 2024?",
@@ -214,7 +260,7 @@ examples = [
```
-```python PYTHON
+```python
example_selector = SemanticSimilarityExampleSelector.from_examples(
examples,
CohereEmbeddings(cohere_api_key=os.getenv("COHERE_API_KEY"),
@@ -226,7 +272,7 @@ example_selector = SemanticSimilarityExampleSelector.from_examples(
```
-```python PYTHON
+```python
from langchain_core.prompts import (
ChatPromptTemplate,
FewShotPromptTemplate,
@@ -260,7 +306,8 @@ For questions regarding bottlenecks, processing time and average processing time
If the question does not seem related to the database, just return "I don't know" as the answer.
-Here are some examples of user inputs and their corresponding SQL queries:"""
+Here are some examples of user inputs and their corresponding SQL queries:
+"""
few_shot_prompt = FewShotPromptTemplate(
example_selector=example_selector,
@@ -274,9 +321,11 @@ few_shot_prompt = FewShotPromptTemplate(
```
-```python PYTHON
+```python
full_prompt = ChatPromptTemplate.from_messages(
[
+ # In the previous section, this was system_prompt instead without the few shot examples.
+ # We can use either prompting style as required
SystemMessagePromptTemplate(prompt=few_shot_prompt),
("human", "{input}"),
MessagesPlaceholder("agent_scratchpad"),
@@ -285,7 +334,7 @@ full_prompt = ChatPromptTemplate.from_messages(
```
-```python PYTHON
+```python
# Example formatted prompt
prompt_val = full_prompt.invoke(
{
@@ -298,50 +347,6 @@ prompt_val = full_prompt.invoke(
)
print(prompt_val.to_string())
```
-```
-System: You are an agent designed to interact with a SQL database.
-You are an expert at answering questions about manufacturing data.
-Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
-Always start with checking the schema of the available tables.
-Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
-You can order the results by a relevant column to return the most interesting examples in the database.
-Never query for all the columns from a specific table, only ask for the relevant columns given the question.
-You have access to tools for interacting with the database.
-Only use the given tools. Only use the information returned by the tools to construct your final answer.
-You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
-
-DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
-
-The current date is 2024-10-10 09:40:22.052973.
-
-For questions regarding productive time, downtime, productive or productivity, use minutes as units.
-
-For questions regarding productive time, downtime, productive or productivity use the status table.
-
-For questions regarding processing time and average processing time, use minutes as units.
-
-For questions regarding bottlenecks, processing time and average processing time use the product_tracking table.
-
-If the question does not seem related to the database, just return "I don't know" as the answer.
-
-Here are some examples of user inputs and their corresponding SQL queries:
-
-User input: What were the productive time and downtime numbers for all stations on April 3rd 2024?
-SQL query: SELECT station_name, station_status, COUNT(*) as total_time FROM status WHERE date = '2024-04-03' GROUP BY station_name, station_status;
-
-User input: What was the average processing time for all stations on April 3rd 2024?
-SQL query: SELECT station_name, AVG(CAST(duration AS INTEGER)) AS avg_processing_time FROM product_tracking WHERE date = '2024-04-03' AND zone = 'wip' GROUP BY station_name ORDER BY station_name;
-
-User input: What was the average processing time for all stations on April 3rd 2024 between 4pm and 6pm?
-SQL query: SELECT station_name, AVG(CAST(duration AS INTEGER)) AS avg_processing_time FROM product_tracking WHERE date = '2024-04-03' AND CAST(hour AS INTEGER) BETWEEN 16 AND 18 AND zone = 'wip' GROUP BY station_name ORDER BY station_name;
-
-User input: What was the bottleneck station on April 3rd 2024?
-SQL query: SELECT station_name, AVG(CAST(duration AS INTEGER)) AS avg_processing_time FROM product_tracking WHERE date = '2024-04-03' AND zone = 'wip' GROUP BY station_name ORDER BY avg_processing_time DESC LIMIT 1;
-
-User input: What was the average processing time for stn4 on April 3rd 2024?
-SQL query: SELECT AVG(CAST(duration AS INTEGER)) AS avg_processing_time FROM product_tracking WHERE date = '2024-04-03' AND station_name = 'stn4' AND zone = 'wip';
-Human: What was the productive time for all stations today?
-```
## [Create the agent](#toc0_)
@@ -350,7 +355,7 @@ Next, we create an instance of the SQL agent using the LangChain framework, spec
This agent will be capable of interpreting natural language queries, converting them into SQL queries, and executing them against our database. The agent uses the LLM we defined earlier, along with the SQL toolkit and the custom prompt we created.
-```python PYTHON
+```python
agent = create_sql_agent(
llm=llm,
toolkit=toolkit,
@@ -364,228 +369,97 @@ agent = create_sql_agent(
Now, we can run the agent and test it with a few different queries.
-```python PYTHON
+```python
# %%time
output=agent.invoke({
- "input": "What was the total downtime for each station during the month of May 2024?",
+ "input": "Which stations had some downtime in the month of May 2024?",
"date": datetime.now()
})
print(output['output'])
-# Answer: stn2: 1 minute(s), stn3: 2 minute(s)stn5: 4 minute(s)
+# Answer: stn2, stn3 and stn5 had some downtime in the month of May 2024.
```
-```
-[1m> Entering new Cohere SQL Agent Executor chain...[0m
-[32;1m[1;3m
-Invoking: `sql_db_list_tables` with `{}`
-responded: I will first check the schema of the available tables. Then, I will write and execute a query to find the total downtime for each station during the month of May 2024.
-
-[0m[38;5;200m[1;3mproduct_tracking, status[0m[32;1m[1;3m
-Invoking: `sql_db_schema` with `{'table_names': 'product_tracking, status'}`
-responded: I have found that the relevant tables are 'product_tracking' and 'status'. I will now check the schema of these tables.
-
-[0m[33;1m[1;3m
-CREATE TABLE product_tracking (
- timestamp_start TEXT,
- timestamp_end TEXT,
- timezone TEXT,
- date TEXT,
- hour TEXT,
- station_name TEXT,
- zone TEXT,
- product_id TEXT,
- duration TEXT
-)
-/*
-3 rows from product_tracking table:
-timestamp_start timestamp_end timezone date hour station_name zone product_id duration
-2024-05-27 17:22:00 2024-05-27 17:57:00 Canada/Toronto 2024-05-27 17 stn2 wip 187 35
-2024-04-26 15:56:00 2024-04-26 17:56:00 Canada/Toronto 2024-04-26 15 stn4 wip 299 120
-2024-04-12 04:36:00 2024-04-12 05:12:00 Canada/Toronto 2024-04-12 4 stn3 wip 60 36
-*/
-
-
-CREATE TABLE status (
- timestamp_event TEXT,
- timezone TEXT,
- date TEXT,
- hour TEXT,
- minute TEXT,
- station_name TEXT,
- station_status TEXT
-)
+```python
+output=agent.invoke({
+ "input": "What is the average processing duration at stn5 in the wip zone?",
+ "date": datetime.now()
+})
+print(output['output'])
-/*
-3 rows from status table:
-timestamp_event timezone date hour minute station_name station_status
-2024-05-09 19:28:00 Canada/Toronto 2024-05-09 19 28 stn3 downtime
-2024-04-21 06:57:00 Canada/Toronto 2024-04-21 6 57 stn3 productive
-2024-04-11 23:52:00 Canada/Toronto 2024-04-11 23 52 stn4 productive
-*/[0m[32;1m[1;3m
-Invoking: `sql_db_query_checker` with `{'query': "SELECT station_name, COUNT(*) as total_downtime FROM status WHERE date LIKE '2024-05-%' AND station_status = 'downtime' GROUP BY station_name;"}`
-responded: I have found that the 'status' table contains the 'station_name' and 'station_status' columns, which are relevant to the user's request. I will now write and execute a query to find the total downtime for each station during the month of May 2024.
-
-[0m[36;1m[1;3m```sql
-SELECT station_name, COUNT(*) as total_downtime FROM status WHERE date LIKE '2024-05-%' AND station_status = 'downtime' GROUP BY station_name;
-```[0m[32;1m[1;3m
-Invoking: `sql_db_query` with `{'query': "SELECT station_name, COUNT(*) as total_downtime FROM status WHERE date LIKE '2024-05-%' AND station_status = 'downtime' GROUP BY station_name;"}`
-responded: I have checked the query and it is correct. I will now execute it.
-
-[0m[36;1m[1;3m[('stn2', 1), ('stn3', 2), ('stn5', 4)][0m[32;1m[1;3mThe total downtime for each station during the month of May 2024 was:
-
-| station_name | total_downtime |
-|---|---|
-| stn2 | 1 |
-| stn3 | 2 |
-| stn5 | 4 |[0m
-
-[1m> Finished chain.[0m
-The total downtime for each station during the month of May 2024 was:
-
-| station_name | total_downtime |
-|---|---|
-| stn2 | 1 |
-| stn3 | 2 |
-| stn5 | 4 |
+# Answer: 39.17 minutes
```
-```python PYTHON
+```python
output=agent.invoke({
- "input": "What is the average processing duration at stn5 in the wip zone?",
+ "input": "Which station had the highest total duration in the wait zone?",
"date": datetime.now()
})
print(output['output'])
-# Answer: 39.17 minutes
+# Answer: stn4 - 251 minutes
+
```
+
+## Memory in the sql agent
+We may want the agent to hold memory of our previous messages so that we're able to coherently engage with the agent to answer our queries. In this section, let's take a look at how we can add memory to the agent so that we're able to achieve this outcome!
+
+
+```python
+from langchain_core.runnables.history import RunnableWithMessageHistory
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.messages import BaseMessage
+from pydantic import BaseModel, Field
+from typing import List
```
-[1m> Entering new Cohere SQL Agent Executor chain...[0m
-[32;1m[1;3m
-Invoking: `sql_db_list_tables` with `{}`
-responded: I will first check the schema of the available tables. Then, I will write and execute a query to find the average processing duration at stn5 in the wip zone.
-
-[0m[38;5;200m[1;3mproduct_tracking, status[0m[32;1m[1;3m
-Invoking: `sql_db_schema` with `{'table_names': 'product_tracking'}`
-responded: The available tables are product_tracking and status.
-
-I will now write and execute a query to find the average processing duration at stn5 in the wip zone.
-
-[0m[33;1m[1;3m
-CREATE TABLE product_tracking (
- timestamp_start TEXT,
- timestamp_end TEXT,
- timezone TEXT,
- date TEXT,
- hour TEXT,
- station_name TEXT,
- zone TEXT,
- product_id TEXT,
- duration TEXT
-)
-/*
-3 rows from product_tracking table:
-timestamp_start timestamp_end timezone date hour station_name zone product_id duration
-2024-05-27 17:22:00 2024-05-27 17:57:00 Canada/Toronto 2024-05-27 17 stn2 wip 187 35
-2024-04-26 15:56:00 2024-04-26 17:56:00 Canada/Toronto 2024-04-26 15 stn4 wip 299 120
-2024-04-12 04:36:00 2024-04-12 05:12:00 Canada/Toronto 2024-04-12 4 stn3 wip 60 36
-*/[0m[32;1m[1;3m
-Invoking: `sql_db_query_checker` with `{'query': "SELECT AVG(CAST(duration AS INTEGER)) AS avg_processing_duration FROM product_tracking WHERE station_name = 'stn5' AND zone = 'wip';"}`
-responded: The available tables are product_tracking and status.
+In the code snippets below, we create a class to store the chat history in memory. This can be customised to store the messages from a database or any other suitable data store.
-I will now write and execute a query to find the average processing duration at stn5 in the wip zone.
-[0m[36;1m[1;3mSELECT AVG(CAST(duration AS INTEGER)) AS avg_processing_duration FROM product_tracking WHERE station_name = 'stn5' AND zone = 'wip';[0m[32;1m[1;3m
-Invoking: `sql_db_query` with `{'query': "SELECT AVG(CAST(duration AS INTEGER)) AS avg_processing_duration FROM product_tracking WHERE station_name = 'stn5' AND zone = 'wip';"}`
-responded: The available tables are product_tracking and status.
+```python
+class InMemoryHistory(BaseChatMessageHistory, BaseModel):
+ """In memory implementation of chat message history."""
-I will now write and execute a query to find the average processing duration at stn5 in the wip zone.
+ messages: List[BaseMessage] = Field(default_factory=list)
-[0m[36;1m[1;3m[(39.166666666666664,)][0m[32;1m[1;3mThe average processing duration at stn5 in the wip zone is 39.17 minutes.[0m
+ def add_messages(self, messages: List[BaseMessage]) -> None:
+ """Add a list of messages to the store"""
+ self.messages.extend(messages)
+
+ def clear(self) -> None:
+ self.messages = []
-[1m> Finished chain.[0m
-The average processing duration at stn5 in the wip zone is 39.17 minutes.
```
+In the below code snippet, we make use of the [RunnableWithMessageHistory](https://python.langchain.com/api_reference/core/runnables/langchain_core.runnables.history.RunnableWithMessageHistory.html) abstraction to wrap around the agent we've created above to provide the message history to the agent that we can now utilize by chatting with the `agent_with_chat_history` as shown below.
-```python PYTHON
-output=agent.invoke({
- "input": "Which station had the highest total duration in the wait zone?",
- "date": datetime.now()
-})
-print(output['output'])
-# Answer: stn4 - 251 minutes
+```python
+store = {}
+def get_by_session_id(session_id: str):
+ if session_id not in store:
+ store[session_id] = InMemoryHistory()
+ return store[session_id]
-```
+agent_with_chat_history = RunnableWithMessageHistory(agent, get_by_session_id, history_messages_key="chat_history")
+output = agent_with_chat_history.invoke({"input": "What station had the longest duration on 27th May 2024?", "date": datetime.now()}, config={"configurable": {"session_id": "foo"}})
+print(output["output"])
+
+# Answer: sstn2, with duration of 35 mins.
```
-[1m> Entering new Cohere SQL Agent Executor chain...[0m
-[32;1m[1;3m
-Invoking: `sql_db_list_tables` with `{}`
-responded: I will first check the schema of the available tables. Then I will write and execute a query to find the station with the highest total duration in the wait zone.
-
-[0m[38;5;200m[1;3mproduct_tracking, status[0m[32;1m[1;3m
-Invoking: `sql_db_schema` with `{'table_names': 'product_tracking, status'}`
-responded: I have found that the relevant tables are product_tracking and status. I will now write and execute a query to find the station with the highest total duration in the wait zone.
-
-[0m[33;1m[1;3m
-CREATE TABLE product_tracking (
- timestamp_start TEXT,
- timestamp_end TEXT,
- timezone TEXT,
- date TEXT,
- hour TEXT,
- station_name TEXT,
- zone TEXT,
- product_id TEXT,
- duration TEXT
-)
-/*
-3 rows from product_tracking table:
-timestamp_start timestamp_end timezone date hour station_name zone product_id duration
-2024-05-27 17:22:00 2024-05-27 17:57:00 Canada/Toronto 2024-05-27 17 stn2 wip 187 35
-2024-04-26 15:56:00 2024-04-26 17:56:00 Canada/Toronto 2024-04-26 15 stn4 wip 299 120
-2024-04-12 04:36:00 2024-04-12 05:12:00 Canada/Toronto 2024-04-12 4 stn3 wip 60 36
-*/
-
-
-CREATE TABLE status (
- timestamp_event TEXT,
- timezone TEXT,
- date TEXT,
- hour TEXT,
- minute TEXT,
- station_name TEXT,
- station_status TEXT
-)
-/*
-3 rows from status table:
-timestamp_event timezone date hour minute station_name station_status
-2024-05-09 19:28:00 Canada/Toronto 2024-05-09 19 28 stn3 downtime
-2024-04-21 06:57:00 Canada/Toronto 2024-04-21 6 57 stn3 productive
-2024-04-11 23:52:00 Canada/Toronto 2024-04-11 23 52 stn4 productive
-*/[0m[32;1m[1;3m
-Invoking: `sql_db_query_checker` with `{'query': "SELECT station_name, SUM(CAST(duration AS INTEGER)) AS total_duration FROM product_tracking WHERE zone = 'wait' GROUP BY station_name ORDER BY total_duration DESC LIMIT 1;"}`
-responded: I have found that the relevant columns are zone and duration in the product_tracking table. I will now write and execute a query to find the station with the highest total duration in the wait zone.
-
-[0m[36;1m[1;3m```sql
-SELECT station_name, SUM(CAST(duration AS INTEGER)) AS total_duration FROM product_tracking WHERE zone = 'wait' GROUP BY station_name ORDER BY total_duration DESC LIMIT 1;
-```[0m[32;1m[1;3m
-Invoking: `sql_db_query` with `{'query': "SELECT station_name, SUM(CAST(duration AS INTEGER)) AS total_duration FROM product_tracking WHERE zone = 'wait' GROUP BY station_name ORDER BY total_duration DESC LIMIT 1;"}`
-responded: I have checked the query and it is correct. I will now execute it.
-
-[0m[36;1m[1;3m[('stn4', 251)][0m[32;1m[1;3mThe station with the highest total duration in the wait zone is stn4, with a total duration of 251 minutes.[0m
-
-[1m> Finished chain.[0m
-The station with the highest total duration in the wait zone is stn4, with a total duration of 251 minutes.
+```python
+output = agent_with_chat_history.invoke({"input": "Can you tell me when this station had downtime on 2024-04-03?", "date": datetime.now()}, config={"configurable": {"session_id": "foo"}})
+print(output["output"])
+
+# Answer: 21:52:00
```
-## Conclusion
-This tutorial demonstrated how to create a SQL agent using Cohere and LangChain. The agent can translate natural language queries coming from users into SQL, and execute them against a database. This powerful combination allows for intuitive interaction with databases without requiring direct SQL knowledge.
+We can see from the above code snippets that the agent is automatically able to infer and query with respect to 'stn2' in the above question without us having to specify it explicitly. This allows us to have more coherent conversations with the agent.
+## Conclusion
+This tutorial demonstrated how to create a SQL agent using Cohere and LangChain. The agent can translate natural language queries coming from users into SQL, and execute them against a database. This powerful combination allows for intuitive interaction with databases without requiring direct SQL knowledge.
\ No newline at end of file
diff --git a/notebooks/agents/i-5O-sql-agent/sql_agent_demo.ipynb b/notebooks/agents/i-5O-sql-agent/sql_agent_demo.ipynb
index ae4d11bb6..80c6a3f4f 100644
--- a/notebooks/agents/i-5O-sql-agent/sql_agent_demo.ipynb
+++ b/notebooks/agents/i-5O-sql-agent/sql_agent_demo.ipynb
@@ -16,6 +16,7 @@
"3. Configuring the LangChain SQL Toolkit\n",
"4. Creating a custom prompt template with few-shot examples\n",
"5. Building and running the SQL agent\n",
+ "6. Adding memory to the agent to keep track of historical messages\n",
"\n",
"By the end of this tutorial, you'll have a functional SQL agent that can answer questions about your data using natural language.\n",
"\n",
@@ -71,7 +72,7 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"metadata": {
"id": "zubD6E1Ty5X5"
},
@@ -79,391 +80,46 @@
"source": [
"import os\n",
"\n",
- "os.environ[\"COHERE_API_KEY\"] = \"\""
+ "os.environ[\"COHERE_API_KEY\"] = \"\""
]
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": null,
"metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "66IzK-H-zH8u",
- "outputId": "d9d1bb14-4a46-44d2-b1a6-f734c105f28f"
+ "id": "66IzK-H-zH8u"
},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m85.5/85.5 MB\u001b[0m \u001b[31m8.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25h"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "! pip install faiss-gpu -qq"
+ "! pip install faiss-cpu -qq"
]
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "r0IG_1Ecyaqv",
- "outputId": "51375ba8-bec3-4362-8399-b528e67f36e2"
+ "id": "r0IG_1Ecyaqv"
},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Collecting langchain-core\n",
- " Downloading langchain_core-0.3.10-py3-none-any.whl.metadata (6.3 kB)\n",
- "Collecting langchain-cohere\n",
- " Downloading langchain_cohere-0.3.0-py3-none-any.whl.metadata (6.7 kB)\n",
- "Collecting langchain-community\n",
- " Downloading langchain_community-0.3.2-py3-none-any.whl.metadata (2.8 kB)\n",
- "Requirement already satisfied: PyYAML>=5.3 in /usr/local/lib/python3.10/dist-packages (from langchain-core) (6.0.2)\n",
- "Collecting jsonpatch<2.0,>=1.33 (from langchain-core)\n",
- " Downloading jsonpatch-1.33-py2.py3-none-any.whl.metadata (3.0 kB)\n",
- "Collecting langsmith<0.2.0,>=0.1.125 (from langchain-core)\n",
- " Downloading langsmith-0.1.133-py3-none-any.whl.metadata (13 kB)\n",
- "Requirement already satisfied: packaging<25,>=23.2 in /usr/local/lib/python3.10/dist-packages (from langchain-core) (24.1)\n",
- "Requirement already satisfied: pydantic<3.0.0,>=2.5.2 in /usr/local/lib/python3.10/dist-packages (from langchain-core) (2.9.2)\n",
- "Collecting tenacity!=8.4.0,<9.0.0,>=8.1.0 (from langchain-core)\n",
- " Downloading tenacity-8.5.0-py3-none-any.whl.metadata (1.2 kB)\n",
- "Requirement already satisfied: typing-extensions>=4.7 in /usr/local/lib/python3.10/dist-packages (from langchain-core) (4.12.2)\n",
- "Collecting cohere<6.0,>=5.5.6 (from langchain-cohere)\n",
- " Downloading cohere-5.11.0-py3-none-any.whl.metadata (3.4 kB)\n",
- "Collecting langchain-experimental>=0.3.0 (from langchain-cohere)\n",
- " Downloading langchain_experimental-0.3.2-py3-none-any.whl.metadata (1.7 kB)\n",
- "Requirement already satisfied: pandas>=1.4.3 in /usr/local/lib/python3.10/dist-packages (from langchain-cohere) (2.2.2)\n",
- "Requirement already satisfied: tabulate<0.10.0,>=0.9.0 in /usr/local/lib/python3.10/dist-packages (from langchain-cohere) (0.9.0)\n",
- "Requirement already satisfied: SQLAlchemy<3,>=1.4 in /usr/local/lib/python3.10/dist-packages (from langchain-community) (2.0.35)\n",
- "Requirement already satisfied: aiohttp<4.0.0,>=3.8.3 in /usr/local/lib/python3.10/dist-packages (from langchain-community) (3.10.8)\n",
- "Collecting dataclasses-json<0.7,>=0.5.7 (from langchain-community)\n",
- " Downloading dataclasses_json-0.6.7-py3-none-any.whl.metadata (25 kB)\n",
- "Collecting langchain<0.4.0,>=0.3.3 (from langchain-community)\n",
- " Downloading langchain-0.3.3-py3-none-any.whl.metadata (7.1 kB)\n",
- "Requirement already satisfied: numpy<2,>=1 in /usr/local/lib/python3.10/dist-packages (from langchain-community) (1.26.4)\n",
- "Collecting pydantic-settings<3.0.0,>=2.4.0 (from langchain-community)\n",
- " Downloading pydantic_settings-2.5.2-py3-none-any.whl.metadata (3.5 kB)\n",
- "Requirement already satisfied: requests<3,>=2 in /usr/local/lib/python3.10/dist-packages (from langchain-community) (2.32.3)\n",
- "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain-community) (2.4.3)\n",
- "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain-community) (1.3.1)\n",
- "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain-community) (24.2.0)\n",
- "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain-community) (1.4.1)\n",
- "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain-community) (6.1.0)\n",
- "Requirement already satisfied: yarl<2.0,>=1.12.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain-community) (1.13.1)\n",
- "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain-community) (4.0.3)\n",
- "Collecting boto3<2.0.0,>=1.34.0 (from cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading boto3-1.35.37-py3-none-any.whl.metadata (6.7 kB)\n",
- "Collecting fastavro<2.0.0,>=1.9.4 (from cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading fastavro-1.9.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.5 kB)\n",
- "Collecting httpx>=0.21.2 (from cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading httpx-0.27.2-py3-none-any.whl.metadata (7.1 kB)\n",
- "Collecting httpx-sse==0.4.0 (from cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading httpx_sse-0.4.0-py3-none-any.whl.metadata (9.0 kB)\n",
- "Collecting parameterized<0.10.0,>=0.9.0 (from cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading parameterized-0.9.0-py2.py3-none-any.whl.metadata (18 kB)\n",
- "Requirement already satisfied: pydantic-core<3.0.0,>=2.18.2 in /usr/local/lib/python3.10/dist-packages (from cohere<6.0,>=5.5.6->langchain-cohere) (2.23.4)\n",
- "Collecting sagemaker<3.0.0,>=2.232.1 (from cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading sagemaker-2.232.2-py3-none-any.whl.metadata (16 kB)\n",
- "Requirement already satisfied: tokenizers<1,>=0.15 in /usr/local/lib/python3.10/dist-packages (from cohere<6.0,>=5.5.6->langchain-cohere) (0.19.1)\n",
- "Collecting types-requests<3.0.0,>=2.0.0 (from cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading types_requests-2.32.0.20240914-py3-none-any.whl.metadata (1.9 kB)\n",
- "Collecting marshmallow<4.0.0,>=3.18.0 (from dataclasses-json<0.7,>=0.5.7->langchain-community)\n",
- " Downloading marshmallow-3.22.0-py3-none-any.whl.metadata (7.2 kB)\n",
- "Collecting typing-inspect<1,>=0.4.0 (from dataclasses-json<0.7,>=0.5.7->langchain-community)\n",
- " Downloading typing_inspect-0.9.0-py3-none-any.whl.metadata (1.5 kB)\n",
- "Collecting jsonpointer>=1.9 (from jsonpatch<2.0,>=1.33->langchain-core)\n",
- " Downloading jsonpointer-3.0.0-py2.py3-none-any.whl.metadata (2.3 kB)\n",
- "Collecting langchain-text-splitters<0.4.0,>=0.3.0 (from langchain<0.4.0,>=0.3.3->langchain-community)\n",
- " Downloading langchain_text_splitters-0.3.0-py3-none-any.whl.metadata (2.3 kB)\n",
- "Collecting orjson<4.0.0,>=3.9.14 (from langsmith<0.2.0,>=0.1.125->langchain-core)\n",
- " Downloading orjson-3.10.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (50 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m50.4/50.4 kB\u001b[0m \u001b[31m3.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting requests-toolbelt<2.0.0,>=1.0.0 (from langsmith<0.2.0,>=0.1.125->langchain-core)\n",
- " Downloading requests_toolbelt-1.0.0-py2.py3-none-any.whl.metadata (14 kB)\n",
- "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.4.3->langchain-cohere) (2.8.2)\n",
- "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.4.3->langchain-cohere) (2024.2)\n",
- "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.4.3->langchain-cohere) (2024.2)\n",
- "Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<3.0.0,>=2.5.2->langchain-core) (0.7.0)\n",
- "Collecting python-dotenv>=0.21.0 (from pydantic-settings<3.0.0,>=2.4.0->langchain-community)\n",
- " Downloading python_dotenv-1.0.1-py3-none-any.whl.metadata (23 kB)\n",
- "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain-community) (3.3.2)\n",
- "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain-community) (3.10)\n",
- "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain-community) (2.2.3)\n",
- "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain-community) (2024.8.30)\n",
- "Requirement already satisfied: greenlet!=0.4.17 in /usr/local/lib/python3.10/dist-packages (from SQLAlchemy<3,>=1.4->langchain-community) (3.1.1)\n",
- "Collecting botocore<1.36.0,>=1.35.37 (from boto3<2.0.0,>=1.34.0->cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading botocore-1.35.37-py3-none-any.whl.metadata (5.7 kB)\n",
- "Collecting jmespath<2.0.0,>=0.7.1 (from boto3<2.0.0,>=1.34.0->cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading jmespath-1.0.1-py3-none-any.whl.metadata (7.6 kB)\n",
- "Collecting s3transfer<0.11.0,>=0.10.0 (from boto3<2.0.0,>=1.34.0->cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading s3transfer-0.10.3-py3-none-any.whl.metadata (1.7 kB)\n",
- "Requirement already satisfied: anyio in /usr/local/lib/python3.10/dist-packages (from httpx>=0.21.2->cohere<6.0,>=5.5.6->langchain-cohere) (3.7.1)\n",
- "Collecting httpcore==1.* (from httpx>=0.21.2->cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading httpcore-1.0.6-py3-none-any.whl.metadata (21 kB)\n",
- "Requirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from httpx>=0.21.2->cohere<6.0,>=5.5.6->langchain-cohere) (1.3.1)\n",
- "Collecting h11<0.15,>=0.13 (from httpcore==1.*->httpx>=0.21.2->cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading h11-0.14.0-py3-none-any.whl.metadata (8.2 kB)\n",
- "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas>=1.4.3->langchain-cohere) (1.16.0)\n",
- "Collecting attrs>=17.3.0 (from aiohttp<4.0.0,>=3.8.3->langchain-community)\n",
- " Downloading attrs-23.2.0-py3-none-any.whl.metadata (9.5 kB)\n",
- "Requirement already satisfied: cloudpickle==2.2.1 in /usr/local/lib/python3.10/dist-packages (from sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (2.2.1)\n",
- "Collecting docker (from sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading docker-7.1.0-py3-none-any.whl.metadata (3.8 kB)\n",
- "Requirement already satisfied: google-pasta in /usr/local/lib/python3.10/dist-packages (from sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (0.2.0)\n",
- "Collecting importlib-metadata<7.0,>=1.4.0 (from sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading importlib_metadata-6.11.0-py3-none-any.whl.metadata (4.9 kB)\n",
- "Requirement already satisfied: jsonschema in /usr/local/lib/python3.10/dist-packages (from sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (4.23.0)\n",
- "Collecting pathos (from sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading pathos-0.3.3-py3-none-any.whl.metadata (11 kB)\n",
- "Requirement already satisfied: platformdirs in /usr/local/lib/python3.10/dist-packages (from sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (4.3.6)\n",
- "Requirement already satisfied: protobuf<5.0,>=3.12 in /usr/local/lib/python3.10/dist-packages (from sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (3.20.3)\n",
- "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (5.9.5)\n",
- "Collecting sagemaker-core<2.0.0,>=1.0.0 (from sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading sagemaker_core-1.0.10-py3-none-any.whl.metadata (4.9 kB)\n",
- "Collecting sagemaker-mlflow (from sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading sagemaker_mlflow-0.1.0-py3-none-any.whl.metadata (3.3 kB)\n",
- "Collecting schema (from sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading schema-0.7.7-py2.py3-none-any.whl.metadata (34 kB)\n",
- "Collecting smdebug-rulesconfig==1.0.1 (from sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading smdebug_rulesconfig-1.0.1-py2.py3-none-any.whl.metadata (943 bytes)\n",
- "Requirement already satisfied: tblib<4,>=1.7.0 in /usr/local/lib/python3.10/dist-packages (from sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (3.0.0)\n",
- "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (4.66.5)\n",
- "Requirement already satisfied: huggingface-hub<1.0,>=0.16.4 in /usr/local/lib/python3.10/dist-packages (from tokenizers<1,>=0.15->cohere<6.0,>=5.5.6->langchain-cohere) (0.24.7)\n",
- "Collecting mypy-extensions>=0.3.0 (from typing-inspect<1,>=0.4.0->dataclasses-json<0.7,>=0.5.7->langchain-community)\n",
- " Downloading mypy_extensions-1.0.0-py3-none-any.whl.metadata (1.1 kB)\n",
- "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->tokenizers<1,>=0.15->cohere<6.0,>=5.5.6->langchain-cohere) (3.16.1)\n",
- "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->tokenizers<1,>=0.15->cohere<6.0,>=5.5.6->langchain-cohere) (2024.6.1)\n",
- "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.10/dist-packages (from importlib-metadata<7.0,>=1.4.0->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (3.20.2)\n",
- "Requirement already satisfied: rich<14.0.0,>=13.0.0 in /usr/local/lib/python3.10/dist-packages (from sagemaker-core<2.0.0,>=1.0.0->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (13.9.1)\n",
- "Collecting mock<5.0,>4.0 (from sagemaker-core<2.0.0,>=1.0.0->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading mock-4.0.3-py3-none-any.whl.metadata (2.8 kB)\n",
- "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (2023.12.1)\n",
- "Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (0.35.1)\n",
- "Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (0.20.0)\n",
- "Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio->httpx>=0.21.2->cohere<6.0,>=5.5.6->langchain-cohere) (1.2.2)\n",
- "Collecting ppft>=1.7.6.9 (from pathos->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading ppft-1.7.6.9-py3-none-any.whl.metadata (12 kB)\n",
- "Collecting dill>=0.3.9 (from pathos->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading dill-0.3.9-py3-none-any.whl.metadata (10 kB)\n",
- "Collecting pox>=0.3.5 (from pathos->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading pox-0.3.5-py3-none-any.whl.metadata (8.0 kB)\n",
- "Collecting multiprocess>=0.70.17 (from pathos->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading multiprocess-0.70.17-py310-none-any.whl.metadata (7.2 kB)\n",
- "Collecting mlflow>=2.8 (from sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading mlflow-2.16.2-py3-none-any.whl.metadata (29 kB)\n",
- "Collecting mlflow-skinny==2.16.2 (from mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading mlflow_skinny-2.16.2-py3-none-any.whl.metadata (30 kB)\n",
- "Requirement already satisfied: Flask<4 in /usr/local/lib/python3.10/dist-packages (from mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (2.2.5)\n",
- "Collecting alembic!=1.10.0,<2 (from mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading alembic-1.13.3-py3-none-any.whl.metadata (7.4 kB)\n",
- "Collecting graphene<4 (from mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading graphene-3.3-py2.py3-none-any.whl.metadata (7.7 kB)\n",
- "Requirement already satisfied: markdown<4,>=3.3 in /usr/local/lib/python3.10/dist-packages (from mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (3.7)\n",
- "Requirement already satisfied: matplotlib<4 in /usr/local/lib/python3.10/dist-packages (from mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (3.7.1)\n",
- "Requirement already satisfied: pyarrow<18,>=4.0.0 in /usr/local/lib/python3.10/dist-packages (from mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (16.1.0)\n",
- "Requirement already satisfied: scikit-learn<2 in /usr/local/lib/python3.10/dist-packages (from mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (1.5.2)\n",
- "Requirement already satisfied: scipy<2 in /usr/local/lib/python3.10/dist-packages (from mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (1.13.1)\n",
- "Requirement already satisfied: Jinja2<4,>=2.11 in /usr/local/lib/python3.10/dist-packages (from mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (3.1.4)\n",
- "Collecting gunicorn<24 (from mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading gunicorn-23.0.0-py3-none-any.whl.metadata (4.4 kB)\n",
- "Requirement already satisfied: cachetools<6,>=5.0.0 in /usr/local/lib/python3.10/dist-packages (from mlflow-skinny==2.16.2->mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (5.5.0)\n",
- "Requirement already satisfied: click<9,>=7.0 in /usr/local/lib/python3.10/dist-packages (from mlflow-skinny==2.16.2->mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (8.1.7)\n",
- "Collecting databricks-sdk<1,>=0.20.0 (from mlflow-skinny==2.16.2->mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading databricks_sdk-0.34.0-py3-none-any.whl.metadata (37 kB)\n",
- "Collecting gitpython<4,>=3.1.9 (from mlflow-skinny==2.16.2->mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading GitPython-3.1.43-py3-none-any.whl.metadata (13 kB)\n",
- "Requirement already satisfied: opentelemetry-api<3,>=1.9.0 in /usr/local/lib/python3.10/dist-packages (from mlflow-skinny==2.16.2->mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (1.27.0)\n",
- "Requirement already satisfied: opentelemetry-sdk<3,>=1.9.0 in /usr/local/lib/python3.10/dist-packages (from mlflow-skinny==2.16.2->mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (1.27.0)\n",
- "Requirement already satisfied: sqlparse<1,>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from mlflow-skinny==2.16.2->mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (0.5.1)\n",
- "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich<14.0.0,>=13.0.0->sagemaker-core<2.0.0,>=1.0.0->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (3.0.0)\n",
- "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich<14.0.0,>=13.0.0->sagemaker-core<2.0.0,>=1.0.0->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (2.18.0)\n",
- "Collecting Mako (from alembic!=1.10.0,<2->mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading Mako-1.3.5-py3-none-any.whl.metadata (2.9 kB)\n",
- "Requirement already satisfied: Werkzeug>=2.2.2 in /usr/local/lib/python3.10/dist-packages (from Flask<4->mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (3.0.4)\n",
- "Requirement already satisfied: itsdangerous>=2.0 in /usr/local/lib/python3.10/dist-packages (from Flask<4->mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (2.2.0)\n",
- "Collecting graphql-core<3.3,>=3.1 (from graphene<4->mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading graphql_core-3.2.4-py3-none-any.whl.metadata (10 kB)\n",
- "Collecting graphql-relay<3.3,>=3.1 (from graphene<4->mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading graphql_relay-3.2.0-py3-none-any.whl.metadata (12 kB)\n",
- "Collecting aniso8601<10,>=8 (from graphene<4->mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading aniso8601-9.0.1-py2.py3-none-any.whl.metadata (23 kB)\n",
- "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from Jinja2<4,>=2.11->mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (2.1.5)\n",
- "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich<14.0.0,>=13.0.0->sagemaker-core<2.0.0,>=1.0.0->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (0.1.2)\n",
- "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib<4->mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (1.3.0)\n",
- "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib<4->mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (0.12.1)\n",
- "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib<4->mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (4.54.1)\n",
- "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib<4->mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (1.4.7)\n",
- "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib<4->mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (10.4.0)\n",
- "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib<4->mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (3.1.4)\n",
- "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn<2->mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (1.4.2)\n",
- "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn<2->mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (3.5.0)\n",
- "Requirement already satisfied: google-auth~=2.0 in /usr/local/lib/python3.10/dist-packages (from databricks-sdk<1,>=0.20.0->mlflow-skinny==2.16.2->mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (2.27.0)\n",
- "Collecting gitdb<5,>=4.0.1 (from gitpython<4,>=3.1.9->mlflow-skinny==2.16.2->mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading gitdb-4.0.11-py3-none-any.whl.metadata (1.2 kB)\n",
- "Requirement already satisfied: deprecated>=1.2.6 in /usr/local/lib/python3.10/dist-packages (from opentelemetry-api<3,>=1.9.0->mlflow-skinny==2.16.2->mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (1.2.14)\n",
- "Requirement already satisfied: opentelemetry-semantic-conventions==0.48b0 in /usr/local/lib/python3.10/dist-packages (from opentelemetry-sdk<3,>=1.9.0->mlflow-skinny==2.16.2->mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (0.48b0)\n",
- "Requirement already satisfied: wrapt<2,>=1.10 in /usr/local/lib/python3.10/dist-packages (from deprecated>=1.2.6->opentelemetry-api<3,>=1.9.0->mlflow-skinny==2.16.2->mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (1.16.0)\n",
- "Collecting smmap<6,>=3.0.1 (from gitdb<5,>=4.0.1->gitpython<4,>=3.1.9->mlflow-skinny==2.16.2->mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere)\n",
- " Downloading smmap-5.0.1-py3-none-any.whl.metadata (4.3 kB)\n",
- "Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.10/dist-packages (from google-auth~=2.0->databricks-sdk<1,>=0.20.0->mlflow-skinny==2.16.2->mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (0.4.1)\n",
- "Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.10/dist-packages (from google-auth~=2.0->databricks-sdk<1,>=0.20.0->mlflow-skinny==2.16.2->mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (4.9)\n",
- "Requirement already satisfied: pyasn1<0.7.0,>=0.4.6 in /usr/local/lib/python3.10/dist-packages (from pyasn1-modules>=0.2.1->google-auth~=2.0->databricks-sdk<1,>=0.20.0->mlflow-skinny==2.16.2->mlflow>=2.8->sagemaker-mlflow->sagemaker<3.0.0,>=2.232.1->cohere<6.0,>=5.5.6->langchain-cohere) (0.6.1)\n",
- "Downloading langchain_core-0.3.10-py3-none-any.whl (404 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m404.4/404.4 kB\u001b[0m \u001b[31m10.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading langchain_cohere-0.3.0-py3-none-any.whl (43 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.0/44.0 kB\u001b[0m \u001b[31m2.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading langchain_community-0.3.2-py3-none-any.whl (2.4 MB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.4/2.4 MB\u001b[0m \u001b[31m37.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading cohere-5.11.0-py3-none-any.whl (249 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m249.2/249.2 kB\u001b[0m \u001b[31m16.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading httpx_sse-0.4.0-py3-none-any.whl (7.8 kB)\n",
- "Downloading dataclasses_json-0.6.7-py3-none-any.whl (28 kB)\n",
- "Downloading jsonpatch-1.33-py2.py3-none-any.whl (12 kB)\n",
- "Downloading langchain-0.3.3-py3-none-any.whl (1.0 MB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m38.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading langchain_experimental-0.3.2-py3-none-any.whl (208 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m208.1/208.1 kB\u001b[0m \u001b[31m13.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading langsmith-0.1.133-py3-none-any.whl (295 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m295.8/295.8 kB\u001b[0m \u001b[31m18.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading pydantic_settings-2.5.2-py3-none-any.whl (26 kB)\n",
- "Downloading tenacity-8.5.0-py3-none-any.whl (28 kB)\n",
- "Downloading boto3-1.35.37-py3-none-any.whl (139 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m139.1/139.1 kB\u001b[0m \u001b[31m9.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading fastavro-1.9.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.1/3.1 MB\u001b[0m \u001b[31m69.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading httpx-0.27.2-py3-none-any.whl (76 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m76.4/76.4 kB\u001b[0m \u001b[31m5.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading httpcore-1.0.6-py3-none-any.whl (78 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m78.0/78.0 kB\u001b[0m \u001b[31m5.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading jsonpointer-3.0.0-py2.py3-none-any.whl (7.6 kB)\n",
- "Downloading langchain_text_splitters-0.3.0-py3-none-any.whl (25 kB)\n",
- "Downloading marshmallow-3.22.0-py3-none-any.whl (49 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m49.3/49.3 kB\u001b[0m \u001b[31m3.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading orjson-3.10.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (141 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m141.9/141.9 kB\u001b[0m \u001b[31m9.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading parameterized-0.9.0-py2.py3-none-any.whl (20 kB)\n",
- "Downloading python_dotenv-1.0.1-py3-none-any.whl (19 kB)\n",
- "Downloading requests_toolbelt-1.0.0-py2.py3-none-any.whl (54 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m54.5/54.5 kB\u001b[0m \u001b[31m3.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading sagemaker-2.232.2-py3-none-any.whl (1.6 MB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m55.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading smdebug_rulesconfig-1.0.1-py2.py3-none-any.whl (20 kB)\n",
- "Downloading attrs-23.2.0-py3-none-any.whl (60 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m60.8/60.8 kB\u001b[0m \u001b[31m4.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading types_requests-2.32.0.20240914-py3-none-any.whl (15 kB)\n",
- "Downloading typing_inspect-0.9.0-py3-none-any.whl (8.8 kB)\n",
- "Downloading botocore-1.35.37-py3-none-any.whl (12.6 MB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m12.6/12.6 MB\u001b[0m \u001b[31m71.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading importlib_metadata-6.11.0-py3-none-any.whl (23 kB)\n",
- "Downloading jmespath-1.0.1-py3-none-any.whl (20 kB)\n",
- "Downloading mypy_extensions-1.0.0-py3-none-any.whl (4.7 kB)\n",
- "Downloading s3transfer-0.10.3-py3-none-any.whl (82 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m82.6/82.6 kB\u001b[0m \u001b[31m5.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading sagemaker_core-1.0.10-py3-none-any.whl (388 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m388.4/388.4 kB\u001b[0m \u001b[31m23.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading docker-7.1.0-py3-none-any.whl (147 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m147.8/147.8 kB\u001b[0m \u001b[31m9.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading pathos-0.3.3-py3-none-any.whl (82 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m82.1/82.1 kB\u001b[0m \u001b[31m5.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading sagemaker_mlflow-0.1.0-py3-none-any.whl (24 kB)\n",
- "Downloading schema-0.7.7-py2.py3-none-any.whl (18 kB)\n",
- "Downloading dill-0.3.9-py3-none-any.whl (119 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m119.4/119.4 kB\u001b[0m \u001b[31m7.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading h11-0.14.0-py3-none-any.whl (58 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.3/58.3 kB\u001b[0m \u001b[31m4.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading mlflow-2.16.2-py3-none-any.whl (26.7 MB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m26.7/26.7 MB\u001b[0m \u001b[31m35.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading mlflow_skinny-2.16.2-py3-none-any.whl (5.6 MB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.6/5.6 MB\u001b[0m \u001b[31m78.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading mock-4.0.3-py3-none-any.whl (28 kB)\n",
- "Downloading multiprocess-0.70.17-py310-none-any.whl (134 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m9.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading pox-0.3.5-py3-none-any.whl (29 kB)\n",
- "Downloading ppft-1.7.6.9-py3-none-any.whl (56 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.8/56.8 kB\u001b[0m \u001b[31m4.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading alembic-1.13.3-py3-none-any.whl (233 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m233.2/233.2 kB\u001b[0m \u001b[31m14.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading graphene-3.3-py2.py3-none-any.whl (128 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m128.2/128.2 kB\u001b[0m \u001b[31m8.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading gunicorn-23.0.0-py3-none-any.whl (85 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m85.0/85.0 kB\u001b[0m \u001b[31m6.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading aniso8601-9.0.1-py2.py3-none-any.whl (52 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m52.8/52.8 kB\u001b[0m \u001b[31m3.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading databricks_sdk-0.34.0-py3-none-any.whl (565 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m565.6/565.6 kB\u001b[0m \u001b[31m29.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading GitPython-3.1.43-py3-none-any.whl (207 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.3/207.3 kB\u001b[0m \u001b[31m14.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading graphql_core-3.2.4-py3-none-any.whl (203 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m203.2/203.2 kB\u001b[0m \u001b[31m14.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading graphql_relay-3.2.0-py3-none-any.whl (16 kB)\n",
- "Downloading Mako-1.3.5-py3-none-any.whl (78 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m78.6/78.6 kB\u001b[0m \u001b[31m5.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading gitdb-4.0.11-py3-none-any.whl (62 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.7/62.7 kB\u001b[0m \u001b[31m4.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hDownloading smmap-5.0.1-py3-none-any.whl (24 kB)\n",
- "Installing collected packages: schema, aniso8601, types-requests, tenacity, smmap, smdebug-rulesconfig, python-dotenv, ppft, pox, parameterized, orjson, mypy-extensions, mock, marshmallow, Mako, jsonpointer, jmespath, importlib-metadata, httpx-sse, h11, gunicorn, graphql-core, fastavro, dill, attrs, typing-inspect, requests-toolbelt, multiprocess, jsonpatch, httpcore, graphql-relay, gitdb, docker, botocore, alembic, s3transfer, pydantic-settings, pathos, httpx, graphene, gitpython, dataclasses-json, databricks-sdk, langsmith, boto3, sagemaker-core, mlflow-skinny, langchain-core, mlflow, langchain-text-splitters, sagemaker-mlflow, langchain, sagemaker, langchain-community, langchain-experimental, cohere, langchain-cohere\n",
- " Attempting uninstall: tenacity\n",
- " Found existing installation: tenacity 9.0.0\n",
- " Uninstalling tenacity-9.0.0:\n",
- " Successfully uninstalled tenacity-9.0.0\n",
- " Attempting uninstall: importlib-metadata\n",
- " Found existing installation: importlib_metadata 8.4.0\n",
- " Uninstalling importlib_metadata-8.4.0:\n",
- " Successfully uninstalled importlib_metadata-8.4.0\n",
- " Attempting uninstall: attrs\n",
- " Found existing installation: attrs 24.2.0\n",
- " Uninstalling attrs-24.2.0:\n",
- " Successfully uninstalled attrs-24.2.0\n",
- "Successfully installed Mako-1.3.5 alembic-1.13.3 aniso8601-9.0.1 attrs-23.2.0 boto3-1.35.37 botocore-1.35.37 cohere-5.11.0 databricks-sdk-0.34.0 dataclasses-json-0.6.7 dill-0.3.9 docker-7.1.0 fastavro-1.9.7 gitdb-4.0.11 gitpython-3.1.43 graphene-3.3 graphql-core-3.2.4 graphql-relay-3.2.0 gunicorn-23.0.0 h11-0.14.0 httpcore-1.0.6 httpx-0.27.2 httpx-sse-0.4.0 importlib-metadata-6.11.0 jmespath-1.0.1 jsonpatch-1.33 jsonpointer-3.0.0 langchain-0.3.3 langchain-cohere-0.3.0 langchain-community-0.3.2 langchain-core-0.3.10 langchain-experimental-0.3.2 langchain-text-splitters-0.3.0 langsmith-0.1.133 marshmallow-3.22.0 mlflow-2.16.2 mlflow-skinny-2.16.2 mock-4.0.3 multiprocess-0.70.17 mypy-extensions-1.0.0 orjson-3.10.7 parameterized-0.9.0 pathos-0.3.3 pox-0.3.5 ppft-1.7.6.9 pydantic-settings-2.5.2 python-dotenv-1.0.1 requests-toolbelt-1.0.0 s3transfer-0.10.3 sagemaker-2.232.2 sagemaker-core-1.0.10 sagemaker-mlflow-0.1.0 schema-0.7.7 smdebug-rulesconfig-1.0.1 smmap-5.0.1 tenacity-8.5.0 types-requests-2.32.0.20240914 typing-inspect-0.9.0\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "! pip install langchain-core langchain-cohere langchain-community"
+ "! pip install langchain-core langchain-cohere langchain-community -qq"
]
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": null,
"metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "2DXDmpwryTQk",
- "outputId": "91d8a8db-d554-473b-8bb2-d12c687f15b7"
+ "id": "2DXDmpwryTQk"
},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml\n",
- "sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "from langchain.agents import AgentExecutor\n",
"from langchain_cohere import create_sql_agent\n",
- "from langchain_core.prompts import ChatPromptTemplate\n",
"from langchain_cohere.chat_models import ChatCohere\n",
- "from langchain_community.utilities.sql_database import SQLDatabase\n",
"from langchain_community.agent_toolkits import SQLDatabaseToolkit\n",
"from langchain_community.vectorstores import FAISS\n",
"from langchain_core.example_selectors import SemanticSimilarityExampleSelector\n",
"from langchain_cohere import CohereEmbeddings\n",
- "from datetime import datetime, timedelta\n",
- "import os\n",
- "import json"
+ "from datetime import datetime"
]
},
{
@@ -483,21 +139,22 @@
"source": [
"Next, we load the database for our manufacturing data.\n",
"\n",
- "We create an in-memory SQLite database using SQL scripts for the `product_tracking` and `status` tables. You can get the [SQL tables here](https://github.com/cohere-ai/cohere-developer-experience/tree/main/notebooks/agents/i-5O-sql-agent).\n",
+ "### Download the sql files from the link below to create the database.\n",
+ "\n",
+ "We create an in-memory SQLite database using SQL scripts for the `product_tracking` and `status` tables. You can get the [SQL tables here](https://github.com/cohere-ai/notebooks/tree/main/notebooks/agents/i-5O-sql-agent).\n",
"\n",
"We then create a SQLDatabase instance, which will be used by our LangChain tools and agents to interact with the data."
]
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": null,
"metadata": {
"id": "_ft-Y7xFyTQl"
},
"outputs": [],
"source": [
"import sqlite3\n",
- "import os\n",
"\n",
"from langchain_community.utilities.sql_database import SQLDatabase\n",
"from sqlalchemy import create_engine\n",
@@ -531,30 +188,16 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
- "height": 87
+ "height": 71
},
"id": "r94_RgFdyTQl",
- "outputId": "2685d55a-6933-4a22-f1d4-ae687ed93006"
+ "outputId": "c9390193-1ad5-4435-9d7a-ab3ddfa0c9e7"
},
- "outputs": [
- {
- "data": {
- "application/vnd.google.colaboratory.intrinsic+json": {
- "type": "string"
- },
- "text/plain": [
- "\"[('2024-05-09 19:28:00', 'Canada/Toronto', '2024-05-09', '19', '28', 'stn3', 'downtime'), ('2024-04-21 06:57:00', 'Canada/Toronto', '2024-04-21', '6', '57', 'stn3', 'productive'), ('2024-04-11 23:52:00', 'Canada/Toronto', '2024-04-11', '23', '52', 'stn4', 'productive'), ('2024-04-03 21:52:00', 'Canada/Toronto', '2024-04-03', '21', '52', 'stn2', 'downtime'), ('2024-04-30 05:01:00', 'Canada/Toronto', '2024-04-30', '5', '1', 'stn4', 'productive')]\""
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"# Test the connection\n",
"db.run(\"SELECT * FROM status LIMIT 5;\")"
@@ -562,30 +205,16 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
- "height": 87
+ "height": 89
},
"id": "t9KO5OuuyTQm",
- "outputId": "e9cb3de9-7237-4068-b05c-73f532dcc86f"
+ "outputId": "69b26bf7-a8e1-41c8-8427-d09b1c89480b"
},
- "outputs": [
- {
- "data": {
- "application/vnd.google.colaboratory.intrinsic+json": {
- "type": "string"
- },
- "text/plain": [
- "\"[('2024-05-27 17:22:00', '2024-05-27 17:57:00', 'Canada/Toronto', '2024-05-27', '17', 'stn2', 'wip', '187', '35'), ('2024-04-26 15:56:00', '2024-04-26 17:56:00', 'Canada/Toronto', '2024-04-26', '15', 'stn4', 'wip', '299', '120'), ('2024-04-12 04:36:00', '2024-04-12 05:12:00', 'Canada/Toronto', '2024-04-12', '4', 'stn3', 'wip', '60', '36'), ('2024-04-19 15:15:00', '2024-04-19 15:22:00', 'Canada/Toronto', '2024-04-19', '15', 'stn4', 'wait', '227', '7'), ('2024-04-24 19:10:00', '2024-04-24 21:07:00', 'Canada/Toronto', '2024-04-24', '19', 'stn4', 'wait', '169', '117')]\""
- ]
- },
- "execution_count": 5,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"# Test the connection\n",
"db.run(\"SELECT * FROM product_tracking LIMIT 5;\")"
@@ -606,29 +235,20 @@
"id": "OklZuvtByTQm"
},
"source": [
- "Next, we initialize the LangChain SQL Toolkit and sets up the language model to use Cohere's model. This prepares the necessary components for querying the SQL database using natural language.\n"
+ "Next, we initialize the LangChain SQL Toolkit and initialize the language model to use Cohere's LLM. This prepares the necessary components for querying the SQL database using natural language.\n"
]
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "kq4t4BksyTQm",
- "outputId": "8a487f7b-3572-4530-b58b-2e0ee2a2a8fd"
+ "outputId": "1f63e3d2-128b-478b-88f0-61026b458d25"
},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "**List of pre-defined Langchain Tools**\n",
- "['sql_db_query', 'sql_db_schema', 'sql_db_list_tables', 'sql_db_query_checker']\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"## Define model to use\n",
"import os\n",
@@ -654,7 +274,7 @@
"id": "tOQLAkL_yTQn"
},
"source": [
- "## [Create a prompt template with few shot examples](#toc0_)"
+ "## [Create a prompt template](#toc0_)"
]
},
{
@@ -663,14 +283,118 @@
"id": "9sTrPZ4lyTQn"
},
"source": [
- "Next, we create a prompt template with few-shot examples. Few-shot examples are used to provide the model with context and improve its performance on specific tasks. In this case, we'll prepare examples of natural language queries and their corresponding SQL queries to help the model generate accurate SQL statements for our database.\n",
+ "Next, we create a prompt template. In this section, we will introduce a simple system message, and then also show how we can improve the prompt by introducing few shot prompting examples in the later sections.\n",
+ "The system message is used to communicate instructions or provide context to the model at the beginning of a conversation.\n",
+ "\n",
+ "In this case, we provide the model with context on what sql dialect it should use, how many samples to query among other instructions."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "fza3DILtyTQn"
+ },
+ "outputs": [],
+ "source": [
+ "from langchain_core.prompts import (\n",
+ " PromptTemplate,\n",
+ " ChatPromptTemplate,\n",
+ " SystemMessagePromptTemplate,\n",
+ " MessagesPlaceholder\n",
+ ")\n",
+ "\n",
+ "system_message = \"\"\"You are an agent designed to interact with a SQL database.\n",
+ "You are an expert at answering questions about manufacturing data.\n",
+ "Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.\n",
+ "Always start with checking the schema of the available tables.\n",
+ "Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.\n",
+ "You can order the results by a relevant column to return the most interesting examples in the database.\n",
+ "Never query for all the columns from a specific table, only ask for the relevant columns given the question.\n",
+ "You have access to tools for interacting with the database.\n",
+ "Only use the given tools. Only use the information returned by the tools to construct your final answer.\n",
+ "You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.\n",
+ "\n",
+ "DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.\n",
+ "\n",
+ "The current date is {date}.\n",
+ "\n",
+ "For questions regarding productive time, downtime, productive or productivity, use minutes as units.\n",
+ "\n",
+ "For questions regarding productive time, downtime, productive or productivity use the status table.\n",
+ "\n",
+ "For questions regarding processing time and average processing time, use minutes as units.\n",
+ "\n",
+ "For questions regarding bottlenecks, processing time and average processing time use the product_tracking table.\n",
+ "\n",
+ "If the question does not seem related to the database, just return \"I don't know\" as the answer.\"\"\"\n",
+ "\n",
+ "system_prompt = PromptTemplate.from_template(system_message)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "727SGeIEyTQn"
+ },
+ "outputs": [],
+ "source": [
+ "full_prompt = ChatPromptTemplate.from_messages(\n",
+ " [\n",
+ " SystemMessagePromptTemplate(prompt=system_prompt),\n",
+ " MessagesPlaceholder(variable_name='chat_history', optional=True),\n",
+ " (\"human\", \"{input}\"),\n",
+ " MessagesPlaceholder(\"agent_scratchpad\"),\n",
+ " ]\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "NUVjPXqKCEhi",
+ "outputId": "1b5e9641-b24d-4a36-ec20-c6ba3290a850"
+ },
+ "outputs": [],
+ "source": [
+ "prompt_val = full_prompt.invoke({\n",
+ " \"input\": \"What was the productive time for all stations today?\",\n",
+ " \"top_k\": 5,\n",
+ " \"dialect\": \"SQLite\",\n",
+ " \"date\":datetime.now(),\n",
+ " \"agent_scratchpad\": [],\n",
+ " })\n",
+ "print(prompt_val.to_string())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ThKE_KNu_zBi"
+ },
+ "source": [
+ "## [Create a few-shot prompt template](#toc0_)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "bEfO_ADF_4-O"
+ },
+ "source": [
+ "In the above step, we've created a simple system prompt. Now, let us see how we can create a better few shot prompt template in this section. Few-shot examples are used to provide the model with context and improve its performance on specific tasks. In this case, we'll prepare examples of natural language queries and their corresponding SQL queries to help the model generate accurate SQL statements for our database.\n",
"\n",
"In this example, we use `SemanticSimilarityExampleSelector` to select the top k examples that are most similar to an input query out of all the examples available."
]
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": null,
"metadata": {
"id": "vM3CE7MqyTQn"
},
@@ -710,7 +434,7 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": null,
"metadata": {
"id": "cjWTXsEgyTQn"
},
@@ -728,9 +452,9 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": null,
"metadata": {
- "id": "fza3DILtyTQn"
+ "id": "O3r_brDBAKSl"
},
"outputs": [],
"source": [
@@ -767,7 +491,8 @@
"\n",
"If the question does not seem related to the database, just return \"I don't know\" as the answer.\n",
"\n",
- "Here are some examples of user inputs and their corresponding SQL queries:\"\"\"\n",
+ "Here are some examples of user inputs and their corresponding SQL queries:\n",
+ "\"\"\"\n",
"\n",
"few_shot_prompt = FewShotPromptTemplate(\n",
" example_selector=example_selector,\n",
@@ -782,14 +507,16 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": null,
"metadata": {
- "id": "727SGeIEyTQn"
+ "id": "cgcEZ5PwAZvo"
},
"outputs": [],
"source": [
"full_prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
+ " # In the previous section, this was system_prompt instead without the few shot examples.\n",
+ " # We can use either prompting style as required\n",
" SystemMessagePromptTemplate(prompt=few_shot_prompt),\n",
" (\"human\", \"{input}\"),\n",
" MessagesPlaceholder(\"agent_scratchpad\"),\n",
@@ -799,64 +526,15 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "sMAfAKBOyTQn",
- "outputId": "1af879eb-664d-441e-e946-0f504876f9ed"
+ "outputId": "874d6d19-51e9-4fce-dc92-969943ba8a5c"
},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "System: You are an agent designed to interact with a SQL database.\n",
- "You are an expert at answering questions about manufacturing data.\n",
- "Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.\n",
- "Always start with checking the schema of the available tables.\n",
- "Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.\n",
- "You can order the results by a relevant column to return the most interesting examples in the database.\n",
- "Never query for all the columns from a specific table, only ask for the relevant columns given the question.\n",
- "You have access to tools for interacting with the database.\n",
- "Only use the given tools. Only use the information returned by the tools to construct your final answer.\n",
- "You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.\n",
- "\n",
- "DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.\n",
- "\n",
- "The current date is 2024-10-10 09:40:22.052973.\n",
- "\n",
- "For questions regarding productive time, downtime, productive or productivity, use minutes as units.\n",
- "\n",
- "For questions regarding productive time, downtime, productive or productivity use the status table. \n",
- "\n",
- "For questions regarding processing time and average processing time, use minutes as units. \n",
- "\n",
- "For questions regarding bottlenecks, processing time and average processing time use the product_tracking table. \n",
- "\n",
- "If the question does not seem related to the database, just return \"I don't know\" as the answer.\n",
- "\n",
- "Here are some examples of user inputs and their corresponding SQL queries:\n",
- "\n",
- "User input: What were the productive time and downtime numbers for all stations on April 3rd 2024?\n",
- "SQL query: SELECT station_name, station_status, COUNT(*) as total_time FROM status WHERE date = '2024-04-03' GROUP BY station_name, station_status;\n",
- "\n",
- "User input: What was the average processing time for all stations on April 3rd 2024?\n",
- "SQL query: SELECT station_name, AVG(CAST(duration AS INTEGER)) AS avg_processing_time FROM product_tracking WHERE date = '2024-04-03' AND zone = 'wip' GROUP BY station_name ORDER BY station_name;\n",
- "\n",
- "User input: What was the average processing time for all stations on April 3rd 2024 between 4pm and 6pm?\n",
- "SQL query: SELECT station_name, AVG(CAST(duration AS INTEGER)) AS avg_processing_time FROM product_tracking WHERE date = '2024-04-03' AND CAST(hour AS INTEGER) BETWEEN 16 AND 18 AND zone = 'wip' GROUP BY station_name ORDER BY station_name;\n",
- "\n",
- "User input: What was the bottleneck station on April 3rd 2024?\n",
- "SQL query: SELECT station_name, AVG(CAST(duration AS INTEGER)) AS avg_processing_time FROM product_tracking WHERE date = '2024-04-03' AND zone = 'wip' GROUP BY station_name ORDER BY avg_processing_time DESC LIMIT 1;\n",
- "\n",
- "User input: What was the average processing time for stn4 on April 3rd 2024?\n",
- "SQL query: SELECT AVG(CAST(duration AS INTEGER)) AS avg_processing_time FROM product_tracking WHERE date = '2024-04-03' AND station_name = 'stn4' AND zone = 'wip';\n",
- "Human: What was the productive time for all stations today?\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"# Example formatted prompt\n",
"prompt_val = full_prompt.invoke(\n",
@@ -893,7 +571,7 @@
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": null,
"metadata": {
"id": "pDRp_Mi-yTQo"
},
@@ -927,174 +605,37 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "hNyuwg_0yTQo",
- "outputId": "acba82f4-b424-4cd5-fb56-c463b302f057"
+ "outputId": "a0092b7d-1da1-4c5f-c718-163bcd1e99db"
},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\n",
- "\n",
- "\u001b[1m> Entering new Cohere SQL Agent Executor chain...\u001b[0m\n",
- "\u001b[32;1m\u001b[1;3m\n",
- "Invoking: `sql_db_list_tables` with `{}`\n",
- "responded: I will first check the schema of the available tables. Then, I will write and execute a query to find the total downtime for each station during the month of May 2024.\n",
- "\n",
- "\u001b[0m\u001b[38;5;200m\u001b[1;3mproduct_tracking, status\u001b[0m\u001b[32;1m\u001b[1;3m\n",
- "Invoking: `sql_db_schema` with `{'table_names': 'product_tracking, status'}`\n",
- "responded: I have found that the relevant tables are 'product_tracking' and 'status'. I will now check the schema of these tables.\n",
- "\n",
- "\u001b[0m\u001b[33;1m\u001b[1;3m\n",
- "CREATE TABLE product_tracking (\n",
- "\ttimestamp_start TEXT, \n",
- "\ttimestamp_end TEXT, \n",
- "\ttimezone TEXT, \n",
- "\tdate TEXT, \n",
- "\thour TEXT, \n",
- "\tstation_name TEXT, \n",
- "\tzone TEXT, \n",
- "\tproduct_id TEXT, \n",
- "\tduration TEXT\n",
- ")\n",
- "\n",
- "/*\n",
- "3 rows from product_tracking table:\n",
- "timestamp_start\ttimestamp_end\ttimezone\tdate\thour\tstation_name\tzone\tproduct_id\tduration\n",
- "2024-05-27 17:22:00\t2024-05-27 17:57:00\tCanada/Toronto\t2024-05-27\t17\tstn2\twip\t187\t35\n",
- "2024-04-26 15:56:00\t2024-04-26 17:56:00\tCanada/Toronto\t2024-04-26\t15\tstn4\twip\t299\t120\n",
- "2024-04-12 04:36:00\t2024-04-12 05:12:00\tCanada/Toronto\t2024-04-12\t4\tstn3\twip\t60\t36\n",
- "*/\n",
- "\n",
- "\n",
- "CREATE TABLE status (\n",
- "\ttimestamp_event TEXT, \n",
- "\ttimezone TEXT, \n",
- "\tdate TEXT, \n",
- "\thour TEXT, \n",
- "\tminute TEXT, \n",
- "\tstation_name TEXT, \n",
- "\tstation_status TEXT\n",
- ")\n",
- "\n",
- "/*\n",
- "3 rows from status table:\n",
- "timestamp_event\ttimezone\tdate\thour\tminute\tstation_name\tstation_status\n",
- "2024-05-09 19:28:00\tCanada/Toronto\t2024-05-09\t19\t28\tstn3\tdowntime\n",
- "2024-04-21 06:57:00\tCanada/Toronto\t2024-04-21\t6\t57\tstn3\tproductive\n",
- "2024-04-11 23:52:00\tCanada/Toronto\t2024-04-11\t23\t52\tstn4\tproductive\n",
- "*/\u001b[0m\u001b[32;1m\u001b[1;3m\n",
- "Invoking: `sql_db_query_checker` with `{'query': \"SELECT station_name, COUNT(*) as total_downtime FROM status WHERE date LIKE '2024-05-%' AND station_status = 'downtime' GROUP BY station_name;\"}`\n",
- "responded: I have found that the 'status' table contains the 'station_name' and 'station_status' columns, which are relevant to the user's request. I will now write and execute a query to find the total downtime for each station during the month of May 2024.\n",
- "\n",
- "\u001b[0m\u001b[36;1m\u001b[1;3m```sql\n",
- "SELECT station_name, COUNT(*) as total_downtime FROM status WHERE date LIKE '2024-05-%' AND station_status = 'downtime' GROUP BY station_name;\n",
- "```\u001b[0m\u001b[32;1m\u001b[1;3m\n",
- "Invoking: `sql_db_query` with `{'query': \"SELECT station_name, COUNT(*) as total_downtime FROM status WHERE date LIKE '2024-05-%' AND station_status = 'downtime' GROUP BY station_name;\"}`\n",
- "responded: I have checked the query and it is correct. I will now execute it.\n",
- "\n",
- "\u001b[0m\u001b[36;1m\u001b[1;3m[('stn2', 1), ('stn3', 2), ('stn5', 4)]\u001b[0m\u001b[32;1m\u001b[1;3mThe total downtime for each station during the month of May 2024 was:\n",
- "\n",
- "| station_name | total_downtime |\n",
- "|---|---|\n",
- "| stn2 | 1 |\n",
- "| stn3 | 2 |\n",
- "| stn5 | 4 |\u001b[0m\n",
- "\n",
- "\u001b[1m> Finished chain.\u001b[0m\n",
- "The total downtime for each station during the month of May 2024 was:\n",
- "\n",
- "| station_name | total_downtime |\n",
- "|---|---|\n",
- "| stn2 | 1 |\n",
- "| stn3 | 2 |\n",
- "| stn5 | 4 |\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"# %%time\n",
"output=agent.invoke({\n",
- " \"input\": \"What was the total downtime for each station during the month of May 2024?\",\n",
+ " \"input\": \"Which stations had some downtime in the month of May 2024?\",\n",
" \"date\": datetime.now()\n",
"})\n",
"print(output['output'])\n",
"\n",
- "# Answer: stn2: 1 minute(s), stn3: 2 minute(s)stn5: 4 minute(s)"
+ "# Answer: stn2, stn3 and stn5 had some downtime in the month of May 2024."
]
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ptIkKBfoyTQo",
- "outputId": "74aa4dcf-9f1f-4816-fa04-d8504023317f"
+ "outputId": "596eda34-c56d-4f16-e140-b3a8c20a81fd"
},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\n",
- "\n",
- "\u001b[1m> Entering new Cohere SQL Agent Executor chain...\u001b[0m\n",
- "\u001b[32;1m\u001b[1;3m\n",
- "Invoking: `sql_db_list_tables` with `{}`\n",
- "responded: I will first check the schema of the available tables. Then, I will write and execute a query to find the average processing duration at stn5 in the wip zone.\n",
- "\n",
- "\u001b[0m\u001b[38;5;200m\u001b[1;3mproduct_tracking, status\u001b[0m\u001b[32;1m\u001b[1;3m\n",
- "Invoking: `sql_db_schema` with `{'table_names': 'product_tracking'}`\n",
- "responded: The available tables are product_tracking and status.\n",
- "\n",
- "I will now write and execute a query to find the average processing duration at stn5 in the wip zone.\n",
- "\n",
- "\u001b[0m\u001b[33;1m\u001b[1;3m\n",
- "CREATE TABLE product_tracking (\n",
- "\ttimestamp_start TEXT, \n",
- "\ttimestamp_end TEXT, \n",
- "\ttimezone TEXT, \n",
- "\tdate TEXT, \n",
- "\thour TEXT, \n",
- "\tstation_name TEXT, \n",
- "\tzone TEXT, \n",
- "\tproduct_id TEXT, \n",
- "\tduration TEXT\n",
- ")\n",
- "\n",
- "/*\n",
- "3 rows from product_tracking table:\n",
- "timestamp_start\ttimestamp_end\ttimezone\tdate\thour\tstation_name\tzone\tproduct_id\tduration\n",
- "2024-05-27 17:22:00\t2024-05-27 17:57:00\tCanada/Toronto\t2024-05-27\t17\tstn2\twip\t187\t35\n",
- "2024-04-26 15:56:00\t2024-04-26 17:56:00\tCanada/Toronto\t2024-04-26\t15\tstn4\twip\t299\t120\n",
- "2024-04-12 04:36:00\t2024-04-12 05:12:00\tCanada/Toronto\t2024-04-12\t4\tstn3\twip\t60\t36\n",
- "*/\u001b[0m\u001b[32;1m\u001b[1;3m\n",
- "Invoking: `sql_db_query_checker` with `{'query': \"SELECT AVG(CAST(duration AS INTEGER)) AS avg_processing_duration FROM product_tracking WHERE station_name = 'stn5' AND zone = 'wip';\"}`\n",
- "responded: The available tables are product_tracking and status.\n",
- "\n",
- "I will now write and execute a query to find the average processing duration at stn5 in the wip zone.\n",
- "\n",
- "\u001b[0m\u001b[36;1m\u001b[1;3mSELECT AVG(CAST(duration AS INTEGER)) AS avg_processing_duration FROM product_tracking WHERE station_name = 'stn5' AND zone = 'wip';\u001b[0m\u001b[32;1m\u001b[1;3m\n",
- "Invoking: `sql_db_query` with `{'query': \"SELECT AVG(CAST(duration AS INTEGER)) AS avg_processing_duration FROM product_tracking WHERE station_name = 'stn5' AND zone = 'wip';\"}`\n",
- "responded: The available tables are product_tracking and status.\n",
- "\n",
- "I will now write and execute a query to find the average processing duration at stn5 in the wip zone.\n",
- "\n",
- "\u001b[0m\u001b[36;1m\u001b[1;3m[(39.166666666666664,)]\u001b[0m\u001b[32;1m\u001b[1;3mThe average processing duration at stn5 in the wip zone is 39.17 minutes.\u001b[0m\n",
- "\n",
- "\u001b[1m> Finished chain.\u001b[0m\n",
- "The average processing duration at stn5 in the wip zone is 39.17 minutes.\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"output=agent.invoke({\n",
" \"input\": \"What is the average processing duration at stn5 in the wip zone?\",\n",
@@ -1107,85 +648,15 @@
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6EMkPBJnyTQo",
- "outputId": "183a9be5-2869-46e1-97e0-b0d6e0ab3581"
+ "outputId": "98effb52-ff6d-40e5-afa3-443bbff52472"
},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\n",
- "\n",
- "\u001b[1m> Entering new Cohere SQL Agent Executor chain...\u001b[0m\n",
- "\u001b[32;1m\u001b[1;3m\n",
- "Invoking: `sql_db_list_tables` with `{}`\n",
- "responded: I will first check the schema of the available tables. Then I will write and execute a query to find the station with the highest total duration in the wait zone.\n",
- "\n",
- "\u001b[0m\u001b[38;5;200m\u001b[1;3mproduct_tracking, status\u001b[0m\u001b[32;1m\u001b[1;3m\n",
- "Invoking: `sql_db_schema` with `{'table_names': 'product_tracking, status'}`\n",
- "responded: I have found that the relevant tables are product_tracking and status. I will now write and execute a query to find the station with the highest total duration in the wait zone.\n",
- "\n",
- "\u001b[0m\u001b[33;1m\u001b[1;3m\n",
- "CREATE TABLE product_tracking (\n",
- "\ttimestamp_start TEXT, \n",
- "\ttimestamp_end TEXT, \n",
- "\ttimezone TEXT, \n",
- "\tdate TEXT, \n",
- "\thour TEXT, \n",
- "\tstation_name TEXT, \n",
- "\tzone TEXT, \n",
- "\tproduct_id TEXT, \n",
- "\tduration TEXT\n",
- ")\n",
- "\n",
- "/*\n",
- "3 rows from product_tracking table:\n",
- "timestamp_start\ttimestamp_end\ttimezone\tdate\thour\tstation_name\tzone\tproduct_id\tduration\n",
- "2024-05-27 17:22:00\t2024-05-27 17:57:00\tCanada/Toronto\t2024-05-27\t17\tstn2\twip\t187\t35\n",
- "2024-04-26 15:56:00\t2024-04-26 17:56:00\tCanada/Toronto\t2024-04-26\t15\tstn4\twip\t299\t120\n",
- "2024-04-12 04:36:00\t2024-04-12 05:12:00\tCanada/Toronto\t2024-04-12\t4\tstn3\twip\t60\t36\n",
- "*/\n",
- "\n",
- "\n",
- "CREATE TABLE status (\n",
- "\ttimestamp_event TEXT, \n",
- "\ttimezone TEXT, \n",
- "\tdate TEXT, \n",
- "\thour TEXT, \n",
- "\tminute TEXT, \n",
- "\tstation_name TEXT, \n",
- "\tstation_status TEXT\n",
- ")\n",
- "\n",
- "/*\n",
- "3 rows from status table:\n",
- "timestamp_event\ttimezone\tdate\thour\tminute\tstation_name\tstation_status\n",
- "2024-05-09 19:28:00\tCanada/Toronto\t2024-05-09\t19\t28\tstn3\tdowntime\n",
- "2024-04-21 06:57:00\tCanada/Toronto\t2024-04-21\t6\t57\tstn3\tproductive\n",
- "2024-04-11 23:52:00\tCanada/Toronto\t2024-04-11\t23\t52\tstn4\tproductive\n",
- "*/\u001b[0m\u001b[32;1m\u001b[1;3m\n",
- "Invoking: `sql_db_query_checker` with `{'query': \"SELECT station_name, SUM(CAST(duration AS INTEGER)) AS total_duration FROM product_tracking WHERE zone = 'wait' GROUP BY station_name ORDER BY total_duration DESC LIMIT 1;\"}`\n",
- "responded: I have found that the relevant columns are zone and duration in the product_tracking table. I will now write and execute a query to find the station with the highest total duration in the wait zone.\n",
- "\n",
- "\u001b[0m\u001b[36;1m\u001b[1;3m```sql\n",
- "SELECT station_name, SUM(CAST(duration AS INTEGER)) AS total_duration FROM product_tracking WHERE zone = 'wait' GROUP BY station_name ORDER BY total_duration DESC LIMIT 1;\n",
- "```\u001b[0m\u001b[32;1m\u001b[1;3m\n",
- "Invoking: `sql_db_query` with `{'query': \"SELECT station_name, SUM(CAST(duration AS INTEGER)) AS total_duration FROM product_tracking WHERE zone = 'wait' GROUP BY station_name ORDER BY total_duration DESC LIMIT 1;\"}`\n",
- "responded: I have checked the query and it is correct. I will now execute it.\n",
- "\n",
- "\u001b[0m\u001b[36;1m\u001b[1;3m[('stn4', 251)]\u001b[0m\u001b[32;1m\u001b[1;3mThe station with the highest total duration in the wait zone is stn4, with a total duration of 251 minutes.\u001b[0m\n",
- "\n",
- "\u001b[1m> Finished chain.\u001b[0m\n",
- "The station with the highest total duration in the wait zone is stn4, with a total duration of 251 minutes.\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"output=agent.invoke({\n",
" \"input\": \"Which station had the highest total duration in the wait zone?\",\n",
@@ -1196,6 +667,123 @@
"# Answer: stn4 - 251 minutes\n"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "fX4NbrAr9QFg"
+ },
+ "source": [
+ "## Memory in the sql agent\n",
+ "We may want the agent to hold memory of our previous messages so that we're able to coherently engage with the agent to answer our queries. In this section, let's take a look at how we can add memory to the agent so that we're able to achieve this outcome!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "exzKbTS99QFg"
+ },
+ "outputs": [],
+ "source": [
+ "from langchain_core.runnables.history import RunnableWithMessageHistory\n",
+ "from langchain_core.chat_history import BaseChatMessageHistory\n",
+ "from langchain_core.messages import BaseMessage\n",
+ "from pydantic import BaseModel, Field\n",
+ "from typing import List"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "mRevYNNm9QFg"
+ },
+ "source": [
+ "In the code snippets below, we create a class to store the chat history in memory. This can be customised to store the messages from a database or any other suitable data store."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "6R7zoxd89QFg"
+ },
+ "outputs": [],
+ "source": [
+ "class InMemoryHistory(BaseChatMessageHistory, BaseModel):\n",
+ " \"\"\"In memory implementation of chat message history.\"\"\"\n",
+ "\n",
+ " messages: List[BaseMessage] = Field(default_factory=list)\n",
+ "\n",
+ " def add_messages(self, messages: List[BaseMessage]) -> None:\n",
+ " \"\"\"Add a list of messages to the store\"\"\"\n",
+ " self.messages.extend(messages)\n",
+ "\n",
+ " def clear(self) -> None:\n",
+ " self.messages = []\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "_1GYoAwLC_c7"
+ },
+ "source": [
+ "In the below code snippet, we make use of the [RunnableWithMessageHistory](https://python.langchain.com/api_reference/core/runnables/langchain_core.runnables.history.RunnableWithMessageHistory.html) abstraction to wrap around the agent we've created above to provide the message history to the agent that we can now utilize by chatting with the `agent_with_chat_history` as shown below."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "3-5w4j_H9QFg",
+ "outputId": "24dc77ef-5fb5-4de2-fa18-d20df79b9a12"
+ },
+ "outputs": [],
+ "source": [
+ "store = {}\n",
+ "def get_by_session_id(session_id: str):\n",
+ " if session_id not in store:\n",
+ " store[session_id] = InMemoryHistory()\n",
+ " return store[session_id]\n",
+ "\n",
+ "agent_with_chat_history = RunnableWithMessageHistory(agent, get_by_session_id, history_messages_key=\"chat_history\")\n",
+ "\n",
+ "output = agent_with_chat_history.invoke({\"input\": \"What station had the longest duration on 27th May 2024?\", \"date\": datetime.now()}, config={\"configurable\": {\"session_id\": \"foo\"}})\n",
+ "print(output[\"output\"])\n",
+ "\n",
+ "# Answer: sstn2, with duration of 35 mins."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "LM_ybIgx9QFg",
+ "outputId": "7a5dae65-8348-42aa-bb4c-a3ddb11ee7d9"
+ },
+ "outputs": [],
+ "source": [
+ "output = agent_with_chat_history.invoke({\"input\": \"Can you tell me when this station had downtime on 2024-04-03?\", \"date\": datetime.now()}, config={\"configurable\": {\"session_id\": \"foo\"}})\n",
+ "print(output[\"output\"])\n",
+ "\n",
+ "# Answer: 21:52:00"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "KMHm5rux9QFg"
+ },
+ "source": [
+ "We can see from the above code snippets that the agent is automatically able to infer and query with respect to 'stn2' in the above question without us having to specify it explicitly. This allows us to have more coherent conversations with the agent."
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {
@@ -1226,7 +814,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.12.4"
+ "version": "3.11.9"
}
},
"nbformat": 4,