From ada2448ec6aba6dbf16adca411bf0a353301fa33 Mon Sep 17 00:00:00 2001
From: Sajed Jalil
Date: Sun, 29 Sep 2024 23:26:59 -0400
Subject: [PATCH] Memory with summarization added

---
 README.md                              |  4 ++-
 home/controllers/__init__.py           |  0
 home/llm/constants.py                  | 35 ++++++++++++++++++++++++
 home/llm/function_tools/__init__.py    |  0
 home/llm/{ => function_tools}/tools.py |  2 +-
 home/llm/llm_graph.py                  | 21 ++++++++++-----
 home/services/__init__.py              |  0
 home/tests.py                          | 37 +++++++++++++++-----------
 8 files changed, 76 insertions(+), 23 deletions(-)
 create mode 100644 home/controllers/__init__.py
 create mode 100644 home/llm/function_tools/__init__.py
 rename home/llm/{ => function_tools}/tools.py (90%)
 create mode 100644 home/services/__init__.py

diff --git a/README.md b/README.md
index 02b9719..b124db0 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,8 @@
 # Patient-Chat
-An chat application that uses Langchain, Langgraph and knowledge graph.
+A chat application that uses Langchain, Langgraph and a knowledge graph.
+## Long chat optimizations
+The front end sends the full chat history to the backend. The backend filters and summarizes it, then stores the summary in the database for future use under a random ```thread_id``` in LangSmith. This ```thread_id``` is unique across all database users.
 
 ## database
 - Install postgresql from https://www.postgresql.org/download/
diff --git a/home/controllers/__init__.py b/home/controllers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/home/llm/constants.py b/home/llm/constants.py
index 34a6314..922d1ce 100644
--- a/home/llm/constants.py
+++ b/home/llm/constants.py
@@ -9,3 +9,38 @@
 
 summary_prompt = "Create a summary of the conversation above in less than 300 words: "
 summarize_trigger_count = 4
+
+history_one_turn = [
+    {
+        "role": "user",
+        "content": "Hi, my name is Sajed. What's your name?"
+    },
+    {
+        "role": "assistant",
+        "content": "Hello Sajed, how can I assist you today? I'm an AI medical assistant and I'm happy to "
+                   "help with any health-related inquiries or requests you may have. Please let me know how I "
+                   "can be of service."
+    },
+]
+
+
+history_two_turns = [
+    {
+        "role": "user",
+        "content": "Hi, my name is Sajed. What's your name?"
+    },
+    {
+        "role": "assistant",
+        "content": "Hello Sajed, how can I assist you today? I'm an AI medical assistant and I'm happy to "
+                   "help with any health-related inquiries or requests you may have. Please let me know how I "
+                   "can be of service."
+    },
+    {
+        "role": "user",
+        "content": "Hi, how are you? Answer in one line."
+    },
+    {
+        "role": "assistant",
+        "content": "I am doing great."
+    }
+]
\ No newline at end of file
diff --git a/home/llm/function_tools/__init__.py b/home/llm/function_tools/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/home/llm/tools.py b/home/llm/function_tools/tools.py
similarity index 90%
rename from home/llm/tools.py
rename to home/llm/function_tools/tools.py
index 0c7848f..0be9c20 100644
--- a/home/llm/tools.py
+++ b/home/llm/function_tools/tools.py
@@ -1,6 +1,6 @@
 class Tools:
     def request_medication_change(previous_medication: str) -> str:
-        """Puts a request to the doctor for medication change.
+        """Puts a request to the doctor for a medication change when the previous medication name is mentioned.
         Returns a string with the name of current medication and Change request submitted.
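Note: `summary_prompt` and `summarize_trigger_count` in `home/llm/constants.py` drive the summarization step this patch wires into the graph, but the `summarize_conversation` node itself is not shown in the diff. A minimal sketch of how such a node might use these constants; the state fields (`messages`, `summary`) and the model call are assumptions, not the repository's actual code:

```python
from langchain_core.messages import HumanMessage

from home.llm.constants import summary_prompt, summarize_trigger_count


class SummarizeNode:
    """Hypothetical summarization node; the real one is outside this patch."""

    def __init__(self, model):
        self.model = model  # assumed: the same chat model LLMGraph binds its tools to

    def __call__(self, state: dict) -> dict:
        # Skip summarization until the conversation is long enough.
        if len(state["messages"]) < summarize_trigger_count:
            return {}
        # Append the summary instruction and let the model compress the history.
        prompt = state["messages"] + [HumanMessage(content=summary_prompt)]
        response = self.model.invoke(prompt)
        return {"summary": response.content}
```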
diff --git a/home/llm/llm_graph.py b/home/llm/llm_graph.py
index bb1e289..774324a 100644
--- a/home/llm/llm_graph.py
+++ b/home/llm/llm_graph.py
@@ -8,7 +8,7 @@
 from home.llm import constants
 from home.llm.constants import summary_prompt, summarize_trigger_count
-from home.llm.tools import Tools
+from home.llm.function_tools.tools import Tools
 
 logger = logging.getLogger(__name__)
 
@@ -25,13 +25,13 @@ def __init__(self):
             Tools.make_appointment,
             Tools.request_appointment_change
         ]
+        self.model = self.model.bind_tools(self.tool_list)
         memory = MemorySaver()
         self.graph = self.build_graph().compile(checkpointer=memory)
 
     def ai_agent(self, state: State):
         sys_msg = SystemMessage(content=constants.prompt_text)
-        model_with_tools = self.model.bind_tools(self.tool_list)
-        return {"messages": [model_with_tools.invoke([sys_msg] + state["messages"])]}
+        return {"messages": [self.model.invoke([sys_msg] + state["messages"])]}
 
     def build_summarize_subgraph(self) -> StateGraph:
         builder = StateGraph(State)
@@ -40,20 +40,29 @@ def build_summarize_subgraph(self) -> StateGraph:
         builder.add_edge("summarize_conversation", END)
         return builder
 
-    def build_graph(self) -> StateGraph:
+    def build_tool_call_subgraph(self) -> StateGraph:
         builder = StateGraph(State)
         builder.add_node("ai_agent", self.ai_agent)
         builder.add_node("tools", ToolNode(self.tool_list))
-        builder.add_node("summarization_subgraph", self.build_summarize_subgraph().compile())
 
         builder.add_edge(START, "ai_agent")
-        builder.add_edge("ai_agent", "summarization_subgraph")
         builder.add_conditional_edges("ai_agent", tools_condition)
         builder.add_edge("tools", "ai_agent")
+
+        return builder
+
+    def build_graph(self) -> StateGraph:
+        builder = StateGraph(State)
+        builder.add_node("summarization_subgraph", self.build_summarize_subgraph().compile())
+        builder.add_node("tool_call_subgraph", self.build_tool_call_subgraph().compile())
+
+        builder.add_edge(START, "tool_call_subgraph")
+        builder.add_edge("tool_call_subgraph", "summarization_subgraph")
         builder.add_edge("summarization_subgraph", END)
 
         return builder
 
+
     def inference(self, user_message: str, history: List[dict], thread_id: str):
         config = {"configurable": {"thread_id": thread_id}}
         messages = self.convert_history_to_messages(history)
diff --git a/home/services/__init__.py b/home/services/__init__.py
new file mode 100644
index 0000000..e69de29
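The hunk above cuts off inside `inference`, while the updated tests below unpack its result as `(ai_response, summary)`. A plausible completion of the method, assuming the graph state exposes a final AI message and an optional `summary` field; everything past the `messages` line is an assumption based on how the tests consume the result:

```python
from langchain_core.messages import HumanMessage


def inference(self, user_message: str, history: list, thread_id: str):
    # MemorySaver keys checkpoints by thread_id, so each conversation
    # (and each test run) gets an isolated memory slot.
    config = {"configurable": {"thread_id": thread_id}}
    messages = self.convert_history_to_messages(history)
    messages.append(HumanMessage(content=user_message))

    # Run the parent graph: tool-call subgraph first, then summarization.
    result = self.graph.invoke({"messages": messages}, config)

    ai_response = result["messages"][-1].content
    summary = result.get("summary", "")  # assumed state field; may be empty
    return ai_response, summary
```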
diff --git a/home/tests.py b/home/tests.py
index 685a4ec..abd8bf9 100644
--- a/home/tests.py
+++ b/home/tests.py
@@ -1,5 +1,8 @@
+import uuid
+
 from django.test import TestCase
 
+from home.llm.constants import history_two_turns
 from home.llm.llm_graph import LLMGraph
 
 
@@ -9,26 +12,30 @@ def setUp(self):
 
     def test_llm_name(self):
         user_message = "Hi, my name is Sajed. What's your name?"
-        ai_response = self.llm_graph.inference(user_message, [])
+        ai_response, summary = self.llm_graph.inference(user_message, [], "test_" + str(uuid.uuid4()))
         self.assertTrue(ai_response.__contains__("Patient Chat"))
 
     def test_llm_remembers_context_history(self):
-        history = [
-            {
-                "role": "user",
-                "content": "Hi, my name is Sajed. What's your name?"
-            },
-            {
-                "role": "assistant",
-                "content": "Hello Sajed, how can I assist you today? I'm an AI medical assistant and I'm happy to "
-                           "help with any health-related inquiries or requests you may have. Please let me know how I "
-                           "can be of service."
-            },
-        ]
+        history = history_two_turns
         user_message = "Now tell me what's my name?"
-        ai_response = self.llm_graph.inference(user_message, history)
+        ai_response, summary = self.llm_graph.inference(user_message, history, "test_" + str(uuid.uuid4()))
 
-        print(ai_response)
         self.assertTrue(ai_response.__contains__("Sajed"))
+
+    def test_llm_tool_call(self):
+        history = []
+        user_message = "Change my medicine lorazepam."
+        ai_response, summary = self.llm_graph.inference(user_message, history, "test_" + str(uuid.uuid4()))
+
+        self.assertGreater(len(ai_response), 50)
+
+    def test_llm_tool_call_with_summary(self):
+        history = history_two_turns
+        user_message = "Change my medicine lorazepam."
+        ai_response, summary = self.llm_graph.inference(user_message, history, "test_" + str(uuid.uuid4()))
+
+        self.assertGreater(len(summary), 0)
\ No newline at end of file
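`convert_history_to_messages` is called in `inference` but never shown in this patch. Given the role/content dict shape of `history_two_turns` above, it presumably maps the front end's dicts to LangChain message objects, roughly like this sketch (not the repository's actual helper):

```python
from typing import List

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage


def convert_history_to_messages(history: List[dict]) -> List[BaseMessage]:
    # Translate the front end's {"role", "content"} dicts into the
    # message types LangGraph state expects.
    messages: List[BaseMessage] = []
    for turn in history:
        if turn["role"] == "user":
            messages.append(HumanMessage(content=turn["content"]))
        elif turn["role"] == "assistant":
            messages.append(AIMessage(content=turn["content"]))
    return messages
```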