From 9f1b66a186dfacf9a50565b5db49583fbecf9380 Mon Sep 17 00:00:00 2001 From: ks6088ts Date: Sun, 28 Jul 2024 07:27:10 +0900 Subject: [PATCH] add memory feature --- frontend/pages/tool_agent.py | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/frontend/pages/tool_agent.py b/frontend/pages/tool_agent.py index f10ec7f..5c72cbf 100644 --- a/frontend/pages/tool_agent.py +++ b/frontend/pages/tool_agent.py @@ -6,8 +6,11 @@ from langchain.agents import AgentExecutor, create_tool_calling_agent from langchain.memory import ConversationBufferWindowMemory from langchain_community.callbacks import StreamlitCallbackHandler +from langchain_community.chat_message_histories import ChatMessageHistory +from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.runnables import RunnableConfig +from langchain_core.runnables.history import RunnableWithMessageHistory from langchain_openai import AzureChatOpenAI from tools.fetch_contoso_rules import fetch_contoso_rules from tools.search_ddg import search_ddg @@ -48,6 +51,21 @@ ユーザーが日本語で質問した場合は、日本語で回答してください。ユーザーがスペイン語で質問した場合は、スペイン語で回答してください。 """ +store = {} + + +def wrap_with_history(runnable): + def get_session_history(session_id: str) -> BaseChatMessageHistory: + if session_id not in store: + store[session_id] = ChatMessageHistory() + return store[session_id] + + return RunnableWithMessageHistory( + runnable, + get_session_history, + history_messages_key="chat_history", + ) + def init_page(): st.set_page_config(page_title="Web Browsing Agent", page_icon="🤗") @@ -96,7 +114,9 @@ def create_agent(): def main(): init_page() init_messages() - web_browsing_agent = create_agent() + session_id = "example_session" + + web_browsing_agent = wrap_with_history(create_agent()) for msg in st.session_state["memory"].chat_memory.messages: st.chat_message(msg.type).write(msg.content) @@ -108,9 +128,19 @@ def main(): st_cb = StreamlitCallbackHandler(st.container(), expand_new_thoughts=True) response = web_browsing_agent.invoke( {"input": prompt}, - config=RunnableConfig({"callbacks": [st_cb]}), + config=RunnableConfig( + { + "callbacks": [ + st_cb, + ], + "configurable": { + "session_id": session_id, + }, + } + ), ) st.write(response["output"]) + print(f"\chat_history:\n{store[session_id]}") if __name__ == "__main__":