Skip to content

Commit

Permalink
Memory with summarization added
Browse files Browse the repository at this point in the history
  • Loading branch information
sajedjalil committed Sep 30, 2024
1 parent 47538d0 commit ada2448
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 23 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Patient-Chat
An chat application that uses Langchain, Langgraph and knowledge graph.
A chat application that uses Langchain, Langgraph and knowledge graph.

## Long chat optimizations
The front end sends the full chat history to the backend. The backend filters and summarizes it, and stores the summary in the database for future use under a random `thread_id` in LangSmith. This `thread_id` is unique across all database users.

## Database
- Install postgresql from https://www.postgresql.org/download/
Expand Down
Empty file added home/controllers/__init__.py
Empty file.
35 changes: 35 additions & 0 deletions home/llm/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,38 @@

summary_prompt = "Create a summary of the conversation above in less than 300 words: "
summarize_trigger_count = 4

# Two-message sample transcript (one user/assistant exchange) used as a test
# fixture. Fix: the original used an accidental chained assignment
# (`history_one_turn = history = [...]`), which also bound a generic module-level
# name `history`; the alias is kept below for backward compatibility.
history_one_turn = [
    {
        "role": "user",
        "content": "Hi, my name is Sajed. What's your name?"
    },
    {
        "role": "assistant",
        "content": "Hello Sajed, how can I assist you today? I'm an AI medical assistant and I'm happy to "
                   "help with any health-related inquiries or requests you may have. Please let me know how I "
                   "can be of service."
    },
]

# Deprecated alias — prefer `history_one_turn`; kept so existing importers still work.
history = history_one_turn


# Four-message sample transcript (two user/assistant exchanges) used as a test
# fixture for conversation-memory and summarization tests.
_ASSISTANT_GREETING = (
    "Hello Sajed, how can I assist you today? I'm an AI medical assistant and I'm happy to "
    "help with any health-related inquiries or requests you may have. Please let me know how I "
    "can be of service."
)

history_two_turns = [
    {"role": "user", "content": "Hi, my name is Sajed. What's your name?"},
    {"role": "assistant", "content": _ASSISTANT_GREETING},
    {"role": "user", "content": "Hi, how are you? Answer in one line."},
    {"role": "assistant", "content": "I am doing great."},
]
Empty file.
2 changes: 1 addition & 1 deletion home/llm/tools.py → home/llm/function_tools/tools.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
class Tools:
def request_medication_change(previous_medication: str) -> str:
"""Puts a request to the doctor for medication change.
"""Puts a request to the doctor for medication change when the previous medication name is mentioned
Returns a string with the name of current medication and Change request submitted.
Expand Down
21 changes: 15 additions & 6 deletions home/llm/llm_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from home.llm import constants
from home.llm.constants import summary_prompt, summarize_trigger_count
from home.llm.tools import Tools
from home.llm.function_tools.tools import Tools

logger = logging.getLogger(__name__)

Expand All @@ -25,13 +25,13 @@ def __init__(self):
Tools.make_appointment,
Tools.request_appointment_change
]
self.model = self.model.bind_tools(self.tool_list)
memory = MemorySaver()
self.graph = self.build_graph().compile(checkpointer=memory)

def ai_agent(self, state: State):
sys_msg = SystemMessage(content=constants.prompt_text)
model_with_tools = self.model.bind_tools(self.tool_list)
return {"messages": [model_with_tools.invoke([sys_msg] + state["messages"])]}
return {"messages": [self.model.invoke([sys_msg] + state["messages"])]}

def build_summarize_subgraph(self) -> StateGraph:
builder = StateGraph(State)
Expand All @@ -40,20 +40,29 @@ def build_summarize_subgraph(self) -> StateGraph:
builder.add_edge("summarize_conversation", END)
return builder

def build_graph(self) -> StateGraph:
def build_tool_call_subgraph(self) -> StateGraph:
builder = StateGraph(State)
builder.add_node("ai_agent", self.ai_agent)
builder.add_node("tools", ToolNode(self.tool_list))
builder.add_node("summarization_subgraph", self.build_summarize_subgraph().compile())

builder.add_edge(START, "ai_agent")
builder.add_edge("ai_agent", "summarization_subgraph")
builder.add_conditional_edges("ai_agent", tools_condition)
builder.add_edge("tools", "ai_agent")

return builder

def build_graph(self) -> StateGraph:
    """Assemble the top-level graph from the two compiled subgraphs.

    Pipeline: START -> tool_call_subgraph -> summarization_subgraph -> END,
    i.e. the agent answers (with tool calls as needed) first, and the
    conversation is summarized afterwards.
    """
    graph = StateGraph(State)

    # Each subgraph is compiled and mounted as a single opaque node.
    graph.add_node("summarization_subgraph", self.build_summarize_subgraph().compile())
    graph.add_node("tool_call_subgraph", self.build_tool_call_subgraph().compile())

    # Wire the nodes into a strictly linear flow.
    graph.add_edge(START, "tool_call_subgraph")
    graph.add_edge("tool_call_subgraph", "summarization_subgraph")
    graph.add_edge("summarization_subgraph", END)

    return graph


def inference(self, user_message: str, history: List[dict], thread_id: str):
config = {"configurable": {"thread_id": thread_id}}
messages = self.convert_history_to_messages(history)
Expand Down
Empty file added home/services/__init__.py
Empty file.
37 changes: 22 additions & 15 deletions home/tests.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import uuid

from django.test import TestCase

from home.llm.constants import history_two_turns
from home.llm.llm_graph import LLMGraph


Expand All @@ -9,26 +12,30 @@ def setUp(self):

def test_llm_name(self):
user_message = "Hi, my name is Sajed. What's your name?"
ai_response = self.llm_graph.inference(user_message, [])
ai_response, summary = self.llm_graph.inference(user_message, [], "test_"+str(uuid.uuid4()))

self.assertTrue(ai_response.__contains__("Patient Chat"))

def test_llm_remembers_context_history(self):

history = [
{
"role": "user",
"content": "Hi, my name is Sajed. What's your name?"
},
{
"role": "assistant",
"content": "Hello Sajed, how can I assist you today? I'm an AI medical assistant and I'm happy to "
"help with any health-related inquiries or requests you may have. Please let me know how I "
"can be of service."
},
]
history = history_two_turns
user_message = "Now tell me what's my name?"
ai_response = self.llm_graph.inference(user_message, history)
ai_response, summary = self.llm_graph.inference(user_message, history, "test_"+str(uuid.uuid4()) )

print(ai_response)
self.assertTrue(ai_response.__contains__("Sajed"))

def test_llm_tool_call(self):
    """A medication-change request should go through the tool path and yield a substantive reply."""
    thread_id = "test_" + str(uuid.uuid4())
    ai_response, summary = self.llm_graph.inference(
        "Change my medicine lorazepam.", [], thread_id
    )

    self.assertGreater(len(ai_response), 50)

def test_llm_tool_call_with_summary(self):
    """With prior conversation turns, inference should also return a non-empty summary."""
    thread_id = "test_" + str(uuid.uuid4())
    ai_response, summary = self.llm_graph.inference(
        "Change my medicine lorazepam.", history_two_turns, thread_id
    )

    self.assertGreater(len(summary), 0)

0 comments on commit ada2448

Please sign in to comment.