Skip to content

Commit

Permalink
Thread Id and unit tests added
Browse files Browse the repository at this point in the history
  • Loading branch information
sajedjalil committed Sep 29, 2024
1 parent 4577b06 commit a6618d8
Show file tree
Hide file tree
Showing 13 changed files with 299 additions and 259 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,8 @@ An chat application that uses Langchain, Langgraph and knowledge graph.
- Then run ```./create_db_tables.sh```
- Lastly, run ```./insert_data.sh```
- Configure the database connection parameters in settings.py

## Testing
- Run tests by ```python manage.py test```


2 changes: 1 addition & 1 deletion db_scripts/create_db_tables.sh
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ psql -d $DB_NAME -c "
CREATE TABLE IF NOT EXISTS chat_history (
id SERIAL PRIMARY KEY,
patient_id INTEGER NOT NULL,
chat_id INTEGER NOT NULL,
thread_id VARCHAR(50) NOT NULL,
is_user BOOLEAN NOT NULL,
text TEXT NOT NULL,
summary TEXT,
Expand Down
16 changes: 8 additions & 8 deletions db_scripts/insert_data.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ INSERT INTO patient (
next_appointment,
doctor_name
) VALUES (
'John',
'Doe',
'1980-01-15',
'Nikola',
'Tesla',
'07-10-1856',
'(555) 123-4567',
'john.doe@email.com',
'nicktesla@email.com',
'Hypertension, Type 2 Diabetes',
'Metformin 500mg twice daily, Lisinopril 10mg once daily',
'2023-09-15 10:30:00',
Expand All @@ -31,11 +31,11 @@ INSERT INTO patient (
) RETURNING id;
-- Insert sample chat history entries
INSERT INTO chat_history (patient_id, chat_id, is_user, text, summary)
INSERT INTO chat_history (patient_id, thread_id, is_user, text, summary)
VALUES
(currval('patient_id_seq'), 1, FALSE, 'Hello, I''ve been experiencing increased thirst lately. Is this related to my diabetes?', 'Patient reports increased thirst'),
(currval('patient_id_seq'), 1, TRUE, 'Hi John, increased thirst can indeed be a symptom of diabetes. Let''s schedule a check-up to monitor your blood sugar levels. How does next week look for you?', 'System suggests check-up for diabetes symptoms'),
(currval('patient_id_seq'), 2, FALSE, 'Next week works for me. Should I prepare anything specific for the appointment?', 'Patient agrees to check-up, asks for preparation instructions');
(currval('patient_id_seq'), 'ebc92830-6110-45d8-bea9-fead9e51cd17', FALSE, 'Hello, I''ve been experiencing increased thirst lately. Is this related to my diabetes?', 'Patient reports increased thirst'),
(currval('patient_id_seq'), 'ebc92830-6110-45d8-bea9-fead9e51cd17', TRUE, 'Hi John, increased thirst can indeed be a symptom of diabetes. Let''s schedule a check-up to monitor your blood sugar levels. How does next week look for you?', 'System suggests check-up for diabetes symptoms'),
(currval('patient_id_seq'), 'nks72743-4565-87s6-idw4-tour3e23we52', FALSE, 'Next week works for me. Should I prepare anything specific for the appointment?', 'Patient agrees to check-up, asks for preparation instructions');
"

# Execute the SQL commands
Expand Down
2 changes: 1 addition & 1 deletion home/db_schema/chat_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

class ChatHistory(models.Model):
patient_id = models.IntegerField()
chat_id = models.IntegerField()
thread_id = models.CharField(max_length=50)
is_user = models.BooleanField()
text = models.TextField()
summary = models.TextField(null=True, blank=True)
Expand Down
4 changes: 2 additions & 2 deletions home/db_schema/patient.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@


class Patient(models.Model):
first_name = models.CharField(max_length=100)
last_name = models.CharField(max_length=100)
first_name = models.CharField(max_length=50)
last_name = models.CharField(max_length=50)
date_of_birth = models.DateField()
phone_number = models.CharField(max_length=20)
email = models.EmailField()
Expand Down
102 changes: 61 additions & 41 deletions home/llm/llm_graph.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from langchain_anthropic import ChatAnthropic
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.graph.state import CompiledStateGraph
from langgraph.graph import START, StateGraph, MessagesState, END
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, RemoveMessage

from typing import Literal, List
from django.utils import timezone
import logging

from home.db_schema.chat_history import ChatHistory
from home.llm import prompt
from home.llm.tools import Tools
from django.utils import timezone

logger = logging.getLogger(__name__)

Expand All @@ -21,57 +20,78 @@ class State(MessagesState):
class LLMGraph:
def __init__(self):
self.model = ChatAnthropic(model="claude-3-haiku-20240307")
self.tool_list = [Tools.request_medication_change, Tools.make_appointment, Tools.request_appointment_change]
self.graph = self.build_graph()
self.tool_list = [
Tools.request_medication_change,
Tools.make_appointment,
Tools.request_appointment_change
]
self.graph = self.build_graph().compile()

def assistant(self, state: State):
sys_msg = SystemMessage(content=prompt.prompt_text)
model_with_tools = self.model.bind_tools(self.tool_list)
return {"messages": [model_with_tools.invoke([sys_msg] + state["messages"])]}

def build_summarize_subgraph(self) -> StateGraph:
builder = StateGraph(State)
builder.add_node("summarize_conversation", self.summarize_conversation)
builder.add_conditional_edges(START, self.if_need_summarization)
builder.add_edge("summarize_conversation", END)
return builder

def build_graph(self) -> CompiledStateGraph:
def build_graph(self) -> StateGraph:
builder = StateGraph(State)
builder.add_node("assistant", self.assistant)
builder.add_node("tools", ToolNode(self.tool_list))
builder.add_node("summarization_subgraph", self.build_summarize_subgraph().compile())

builder.add_edge(START, "assistant")
builder.add_edge(START, "summarization_subgraph")
builder.add_edge("summarization_subgraph", "assistant")
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")

return builder.compile()

def assistant(self, state: State):
# Prompt message
sys_msg = SystemMessage(content=prompt.prompt_text)
model_with_tools = self.model.bind_tools(self.tool_list)
return {"messages": [model_with_tools.invoke([sys_msg] + state["messages"])]}

def inference(self, user_message, history) -> str:
messages = []
for msg in history:
if msg['role'] == 'user':
messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant':
messages.append(AIMessage(content=msg['content']))
return builder

def inference(self, user_message: str, history: List[dict]) -> str:
messages = self.convert_history_to_messages(history)
messages.append(HumanMessage(content=user_message))

result = self.graph.invoke({"messages": messages})
logger.debug(result)

assistant_response = result['messages'][-1].content

# Create user message entry
ChatHistory.objects.create(
patient_id=1,
chat_id=1,
is_user=True,
text=user_message,
timestamp=timezone.now()
)

# Create assistant message entry
ChatHistory.objects.create(
patient_id=1,
chat_id=1,
is_user=False,
text=assistant_response,
timestamp=timezone.now()
)

return assistant_response

def summarize_conversation(self, state: State):
summary = state.get("summary", "")
summary_message = self.get_summary_message(summary)

messages = state["messages"] + [HumanMessage(content=summary_message)]
response = self.model.invoke(messages)

delete_messages = [RemoveMessage(id=m.id) for m in state["messages"][:-2]]
return {"summary": response.content, "messages": delete_messages}

@staticmethod
def if_need_summarization(state: State) -> Literal["summarize_conversation", "__end__"]:
if len(state["messages"]) >= 6:
return "summarize_conversation"
else:
return "__end__"

@staticmethod
def convert_history_to_messages(history: List[dict]) -> List[HumanMessage | AIMessage]:
return [
HumanMessage(content=msg['content']) if msg['role'] == 'user'
else AIMessage(content=msg['content'])
for msg in history
]

@staticmethod
def get_summary_message(summary: str) -> str:
if summary:
return f"This is summary of the conversation to date: {summary}\n\n" \
"Extend the summary by taking into account the new messages above:"
return "Create a summary of the conversation above:"

Loading

0 comments on commit a6618d8

Please sign in to comment.