Skip to content

Commit

Permalink
add memory feature
Browse files Browse the repository at this point in the history
  • Loading branch information
ks6088ts committed Jul 28, 2024
1 parent 2a902e9 commit 9f1b66a
Showing 1 changed file with 32 additions and 2 deletions.
34 changes: 32 additions & 2 deletions frontend/pages/tool_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain.memory import ConversationBufferWindowMemory
from langchain_community.callbacks import StreamlitCallbackHandler
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_openai import AzureChatOpenAI
from tools.fetch_contoso_rules import fetch_contoso_rules
from tools.search_ddg import search_ddg
Expand Down Expand Up @@ -48,6 +51,21 @@
ユーザーが日本語で質問した場合は、日本語で回答してください。ユーザーがスペイン語で質問した場合は、スペイン語で回答してください。
"""

store = {}


def wrap_with_history(runnable):
def get_session_history(session_id: str) -> BaseChatMessageHistory:
if session_id not in store:
store[session_id] = ChatMessageHistory()
return store[session_id]

return RunnableWithMessageHistory(
runnable,
get_session_history,
history_messages_key="chat_history",
)


def init_page():
st.set_page_config(page_title="Web Browsing Agent", page_icon="🤗")
Expand Down Expand Up @@ -96,7 +114,9 @@ def create_agent():
def main():
init_page()
init_messages()
web_browsing_agent = create_agent()
session_id = "example_session"

web_browsing_agent = wrap_with_history(create_agent())

for msg in st.session_state["memory"].chat_memory.messages:
st.chat_message(msg.type).write(msg.content)
Expand All @@ -108,9 +128,19 @@ def main():
st_cb = StreamlitCallbackHandler(st.container(), expand_new_thoughts=True)
response = web_browsing_agent.invoke(
{"input": prompt},
config=RunnableConfig({"callbacks": [st_cb]}),
config=RunnableConfig(
{
"callbacks": [
st_cb,
],
"configurable": {
"session_id": session_id,
},
}
),
)
st.write(response["output"])
print(f"\chat_history:\n{store[session_id]}")


if __name__ == "__main__":
Expand Down

0 comments on commit 9f1b66a

Please sign in to comment.