Commit
Medical Insights Added
sajedjalil committed Oct 2, 2024
1 parent cf5e4b7 commit a0a2133
Showing 9 changed files with 219 additions and 279 deletions.
8 changes: 4 additions & 4 deletions home/constants/constants.py
@@ -13,10 +13,10 @@
 modify appointments and medication changes. You should filter out and ignore any unrelated, overly sensitive,
 or controversial topics.'''

-rag_prompt_text = '''You are a helpful AI medical assistant specialized in making medical insights. You also take
-account patient name, date of birth, medical conditions and medications that are taken by a patient. You will build
-your responses based on these information and other provided medical knowledge. Now use your medical expertise to
-provide a comprehensive medical insight for the texts below:'''
+rag_prompt_text = '''You are a helpful AI medical assistant specialized in summarizing personal medical data,
+health issues and their connection to various given symptoms. You also take into account the patient's name, age,
+medical conditions and medications. You will build your responses based on this information and other provided
+medical knowledge.'''

 history_one_turn = history = [
     {
11 changes: 9 additions & 2 deletions home/llm/llm_graph.py
@@ -1,11 +1,14 @@
+from langchain_community.graphs import Neo4jGraph
+from langchain_core.documents import Document
+from langchain_experimental.graph_transformers import LLMGraphTransformer
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import START, StateGraph, MessagesState, END
 from langgraph.prebuilt import tools_condition, ToolNode
 from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, RemoveMessage, ToolMessage
 from typing import Literal, List

 from home.constants import constants
-from home.constants.chat_models import model_claude_3_haiku
+from home.constants.chat_models import model_claude_3_haiku, model_claude_3_5_sonnet
 from home.constants.constants import summary_prompt, summarize_trigger_count
 from home.llm.function_tools.tools import Tools
 from home.models.patient import Patient
@@ -22,13 +25,16 @@ class State(MessagesState):

 class LLMGraph:
     def __init__(self):
+        self.patient = Patient.objects.first()
+
         self.model = model_claude_3_haiku
         self.model = self.model.bind_tools(tool_list)
         memory = MemorySaver()
         self.graph = self.build_graph().compile(checkpointer=memory)

     def ai_agent(self, state: State):
-        sys_msg = SystemMessage(content=constants.llm_prompt_text)
+        sys_msg = SystemMessage(content=constants.llm_prompt_text + " Currently, you are chatting with a patient "
+                                        "with the following information: " + str(self.patient))
         return {"messages": [self.model.invoke([sys_msg] + state["messages"])]}

     def build_summarize_subgraph(self) -> StateGraph:
@@ -61,6 +67,7 @@ def build_graph(self) -> StateGraph:
         return builder

     def chat_inference(self, user_message: str, history: List[dict], thread_id: str):
+
         config = {"configurable": {"thread_id": thread_id}}
         messages = self.convert_history_to_messages(history)
         messages.append(HumanMessage(content=user_message))
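
For reference, a minimal usage sketch of the updated class (a sketch, assuming Django is configured and a Patient row exists; the three-value return matches how chat_inference is consumed in home/views.py below):

import uuid

from home.llm.llm_graph import LLMGraph

llm_graph = LLMGraph()
thread_id = str(uuid.uuid4())
response, summary, tools_called = llm_graph.chat_inference(
    "Can you reschedule my next appointment?",  # user message
    [],                                         # prior history as [{"role": ..., "content": ...}]
    thread_id,
)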
66 changes: 31 additions & 35 deletions home/llm/rag_graph.py
@@ -1,11 +1,12 @@
-from langgraph.graph import START, StateGraph, MessagesState, END
-from langchain_core.messages import HumanMessage, AIMessage
-from langgraph.checkpoint.memory import MemorySaver
-from langchain_community.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT
 from langchain.chains import GraphCypherQAChain
+from langchain_community.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT
+from langchain_community.graphs import Neo4jGraph
+from langchain_core.documents import Document
+from langchain_core.messages import AIMessage
+from langchain_experimental.graph_transformers import LLMGraphTransformer

 from home.constants.chat_models import *
 from home.constants.constants import rag_prompt_text
+from home.models.patient import Patient


@@ -16,36 +17,31 @@ def fetch_user_info():

 class RAGGraph:
     def __init__(self):
-        self.model = model_claude_3_5_sonnet
-        memory = MemorySaver()
-        self.graph = self.build_graph().compile(checkpointer=memory)
-
-    def query_knowledge_graph(self, state: MessagesState):
-        graph = Neo4jGraph()
-        chain = GraphCypherQAChain.from_llm(graph=graph, llm=self.model, verbose=True, validate_cypher=True,
-                                            allow_dangerous_requests=True, top_k=5,
+        self.patient = Patient.objects.first()
+        self.graph_query_model = model_claude_3_5_sonnet
+        self.llm_graph_transformer = LLMGraphTransformer(llm=model_claude_3_5_sonnet)
+        self.neo4j_graph = Neo4jGraph()
+
+    def rag_store_and_query(self, user_message) -> str:
+        self.__save_to_neo4j(user_message)  # store knowledge extracted from the user message in the Neo4j graph
+        return self.__query_knowledge_graph(user_message)
+
+    def __query_knowledge_graph(self, user_message):
+        chain = GraphCypherQAChain.from_llm(graph=self.neo4j_graph, llm=self.graph_query_model, verbose=False,
+                                            validate_cypher=True, allow_dangerous_requests=True, top_k=5,
                                             cypher_prompt=CYPHER_GENERATION_PROMPT)
-        response = chain.invoke({"query": state['messages']})
-
-        return {'messages': [AIMessage(response['result'])]}
-
-    def build_graph(self) -> StateGraph:
-        builder = StateGraph(MessagesState)
-        builder.add_node("query_knowledge_graph", self.query_knowledge_graph)
-
-        builder.add_edge(START, "query_knowledge_graph")
-        builder.add_edge("query_knowledge_graph", END)
-
-        return builder
-
-    def rag_inference(self, input_text: str, thread_id: str):
-        config = {"configurable": {"thread_id": thread_id}}
-        messages = [HumanMessage(content=
-                                 "Here is a summary chat information of " + fetch_user_info() + ":\n"
-                                 + input_text +
-                                 "\nGet Medical information analysis of the following person.")]
-
-        result = self.graph.invoke({"messages": messages}, config)
-        assistant_response = result['messages'][-1].content
-
-        return assistant_response
+        response = chain.invoke(rag_prompt_text +
+                                " Current patient name is: " + self.patient.first_name + " " +
+                                self.patient.last_name + ". Fetch all related data of the patient and make an "
+                                "analysis. Continue with what you have.",
+                                return_only_outputs=True)['result']
+        return response
+
+    def __save_to_neo4j(self, user_message):
+        context_prompt = ('The following message is from a patient having the following information: ' +
+                          str(self.patient))
+        documents = [Document(page_content=context_prompt + user_message)]
+        graph_documents = self.llm_graph_transformer.convert_to_graph_documents(documents)
+        self.neo4j_graph.add_graph_documents(graph_documents)
+        self.neo4j_graph.refresh_schema()
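
The new flow is two-phase: extract a knowledge graph from the incoming message and persist it, then answer over the accumulated graph with a Cypher QA chain. A minimal round-trip sketch (assumes the Neo4j connection settings that Neo4jGraph reads from the NEO4J_URI/NEO4J_USERNAME/NEO4J_PASSWORD environment variables, and an existing Patient row):

from home.llm.rag_graph import RAGGraph

rag = RAGGraph()
insight = rag.rag_store_and_query(
    "I've been getting frequent headaches since my medication dose changed."
)
print(insight)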
11 changes: 10 additions & 1 deletion home/models/patient.py
@@ -14,7 +14,16 @@ class Patient(models.Model):
     doctor_name = models.CharField(max_length=100)

     def __str__(self):
-        return f"{self.first_name} {self.last_name}"
+        return (f"Patient Name: {self.first_name} {self.last_name}, "
+                f"Date of Birth: {self.date_of_birth}, "
+                f"Phone: {self.phone_number}, "
+                f"Email: {self.email}, "
+                f"Medical Conditions: {self.medical_conditions}, "
+                f"Medication: {self.medication_regimen}, "
+                f"Last Appointment: {self.last_appointment}, "
+                f"Next Appointment: {self.next_appointment}, "
+                f"Current Doctor Name: {self.doctor_name}\n")

     class Meta:
         db_table = 'patient'
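
With this change, str(patient) serializes the full record that the prompts above embed, along these lines (hypothetical values for illustration):

Patient Name: Jane Doe, Date of Birth: 1985-03-12, Phone: 555-0100, Email: jane@example.com, Medical Conditions: hypertension, Medication: lisinopril 10 mg daily, Last Appointment: 2024-09-18, Next Appointment: 2024-10-16, Current Doctor Name: Dr. Smith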
28 changes: 0 additions & 28 deletions home/tests.py
@@ -35,31 +35,3 @@ def test_llm_tool_call_with_summary(self):
         ai_response, summary = self.llm_graph.chat_inference(user_message, history, "test_" + str(uuid.uuid4()))
         self.assertGreater(len(summary), 0)
-
-
-class KnowledgeGraphTestCase(TestCase):
-    def setUp(self):
-        self.graph = KnowledgeGraph()
-
-    def test_add_knowledge_and_query(self):
-        knowledge_text_data = knowledge_text
-        self.graph.add_knowledge_to_graph(knowledge_text_data)
-
-        user_message = "What awards Marie Curie had?"
-        response = self.graph.query_graph(user_message)
-        self.assertTrue(response.lower().__contains__(str("nobel").lower()))
-        self.assertTrue(response.lower().__contains__(str("presidential").lower()))
-
-
-
-# class RAGTestCase(TestCase):
-#     def setUp(self):
-#         self.graph = RAGGraph()
-#
-#     def test_rag_inference(self):
-#         knowledge_text_data = knowledge_text
-#         self.graph.add_knowledge_to_graph(knowledge_text_data)
-#
-#         user_message = "What awards Marie Curie had?"
-#         response = self.graph.query_graph(user_message)
-#         self.assertTrue(response.lower().__contains__(str("nobel").lower()))
-#         self.assertTrue(response.lower().__contains__(str("presidential").lower()))
9 changes: 3 additions & 6 deletions home/urls.py
@@ -2,12 +2,9 @@
 from . import views

 urlpatterns = [
-    path("chat/", views.inference, name='Patient Chat Home'),
     path('', views.index, name='Patient Chat Home'),
+    path("chat/", views.inference, name='Patient Chat Home'),
+    path('insight/', views.insight, name='insight'),
     path('user-info/', views.get_user_info, name='user_info'),
-
-    path('thread-id/', views.get_unique_thread_id, name='thread-id'),
-    path('get-rag/', views.get_rag, name='get-rag')
-
+    path('thread-id/', views.get_unique_thread_id, name='thread-id')
 ]

71 changes: 28 additions & 43 deletions home/views.py
@@ -23,43 +23,47 @@ def index(request):
 @require_http_methods(["POST"])
 def inference(request):
     data = json.loads(request.body)
-    message = data['message']
-    history = data.get('history', [])
-    user_timestamp = data.get('timestamp')
-    thread_id = data.get('threadId')
+    message, history, user_timestamp, thread_id = (
+        data['message'],
+        data.get('history', []),
+        data.get('timestamp'),
+        data.get('threadId')
+    )

     llm_graph = LLMGraph()

     if len(history) >= summarize_trigger_count:
         summary = get_latest_summary(patient_id=1, is_user=False, thread_id=thread_id)

         history = [
-            {
-                "role": "user",
-                "content": "Provide me with the conversation summary"
-            },
-            {
-                "role": "assistant",
-                "content": summary
-            }
+            {"role": "user", "content": "Provide me with the conversation summary"},
+            {"role": "assistant", "content": summary}
         ]

     response, summary, tools_called = llm_graph.chat_inference(message, history, thread_id)

     user_entry, ai_entry = save_chat_entries_db(message, response, summary, user_timestamp, thread_id)

-    summary = summary if summary else "No summary available yet. Chat more to get one."
-
-    # medical_insights = RAGGraph().rag_inference(summary, thread_id)
     return JsonResponse({
         'response': response,
         'user_timestamp': user_entry.timestamp.timestamp() * 1000,
         'ai_timestamp': ai_entry.timestamp.timestamp() * 1000,
-        'summary': summary,
+        'summary': summary or "No summary available yet. Chat more to get one.",
         'tools': tools_called
     })


+@csrf_exempt
+@require_http_methods(["POST"])
+def insight(request):
+    data = json.loads(request.body)
+    user_message = data['message']
+
+    response = RAGGraph().rag_store_and_query(user_message)
+    print("response: ", response)
+    return JsonResponse({
+        'insight': response
+    })


 def save_chat_entries_db(user_message, ai_response, summary, user_timestamp, thread_id):
     user_entry = ChatHistory.objects.create(
         patient_id=1,
@@ -83,11 +87,8 @@ def save_chat_entries_db(user_message, ai_response, summary, user_timestamp, thr

 def get_latest_summary(patient_id, is_user, thread_id):
     latest_chat = ChatHistory.objects.filter(
-        Q(patient_id=patient_id) &
-        Q(is_user=is_user) &
-        Q(thread_id=thread_id)
+        Q(patient_id=patient_id) & Q(is_user=is_user) & Q(thread_id=thread_id)
     ).order_by('-timestamp').first()

     return latest_chat.summary if latest_chat else None


@@ -106,30 +107,14 @@ def get_user_info(request):
         return JsonResponse({'error': 'No patient found'}, status=404)

     return JsonResponse({
-        'first_name': patient.first_name,
-        'last_name': patient.last_name,
-        'date_of_birth': patient.date_of_birth,
-        'phone_number': patient.phone_number,
-        'email': patient.email,
-        'medical_conditions': patient.medical_conditions,
-        'medication_regimen': patient.medication_regimen,
-        'last_appointment': patient.last_appointment,
-        'next_appointment': patient.next_appointment,
-        'doctor_name': patient.doctor_name,
+        field: getattr(patient, field) for field in
+        ['first_name', 'last_name', 'date_of_birth', 'phone_number', 'email',
+         'medical_conditions', 'medication_regimen', 'last_appointment',
+         'next_appointment', 'doctor_name']
     }, encoder=DateTimeEncoder)


 @csrf_exempt
 @require_http_methods(["GET"])
 def get_unique_thread_id(request):
-    return JsonResponse({
-        'thread_id': uuid.uuid4()
-    })
-
-
-@csrf_exempt
-@require_http_methods(["GET"])
-def get_rag(request):
-    return JsonResponse({
-        "data": RAGGraph().rag_inference("", uuid.uuid4())
-    })
+    return JsonResponse({'thread_id': uuid.uuid4()})
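
A quick way to exercise the new insight endpoint, sketched with Django's test client (assumes home/urls.py is mounted at the site root; the request and response shapes follow views.insight above):

import json

from django.test import Client

client = Client()
resp = client.post(
    "/insight/",
    data=json.dumps({"message": "Summarize my recent symptoms."}),
    content_type="application/json",
)
print(resp.json()["insight"])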