diff --git a/home/constants/constants.py b/home/constants/constants.py
index 7d3dd80..bdedf83 100644
--- a/home/constants/constants.py
+++ b/home/constants/constants.py
@@ -13,10 +13,10 @@ modify appointments and medication changes.
 You should filter out and ignore any unrelated, overly sensitive, or controversial topics.'''
 
-rag_prompt_text = '''You are a helpful AI medical assistant specialized in making medical insights. You also take
-account patient name, date of birth, medical conditions and medications that are taken by a patient. You will build
-your responses based on these information and other provided medical knowledge. Now use your medical expertise to
-provide a comprehensive medical insight for the texts below:'''
+rag_prompt_text = '''You are a helpful AI medical assistant specialized in summarizing personal medical data,
+health issues, and their connection to reported symptoms. You also take into account the patient's name, age,
+medical conditions, and medications. You will build your responses based on this information and other provided
+medical knowledge.'''
 
 history_one_turn = history = [
     {
diff --git a/home/llm/llm_graph.py b/home/llm/llm_graph.py
index 26cc57a..5d45b1e 100644
--- a/home/llm/llm_graph.py
+++ b/home/llm/llm_graph.py
@@ -1,3 +1,6 @@
+from langchain_community.graphs import Neo4jGraph
+from langchain_core.documents import Document
+from langchain_experimental.graph_transformers import LLMGraphTransformer
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import START, StateGraph, MessagesState, END
 from langgraph.prebuilt import tools_condition, ToolNode
@@ -5,7 +8,7 @@ from typing import Literal, List
 
 from home.constants import constants
-from home.constants.chat_models import model_claude_3_haiku
+from home.constants.chat_models import model_claude_3_haiku, model_claude_3_5_sonnet
 from home.constants.constants import summary_prompt, summarize_trigger_count
 from home.llm.function_tools.tools import Tools
 from home.models.patient import Patient
@@ -22,13 +25,16 @@ class State(MessagesState):
 
 class LLMGraph:
     def __init__(self):
+        self.patient = Patient.objects.first()
+
         self.model = model_claude_3_haiku
         self.model = self.model.bind_tools(tool_list)
         memory = MemorySaver()
         self.graph = self.build_graph().compile(checkpointer=memory)
 
     def ai_agent(self, state: State):
-        sys_msg = SystemMessage(content=constants.llm_prompt_text)
+        sys_msg = SystemMessage(content=constants.llm_prompt_text + " Currently, you are chatting with a patient "
+                                        "with the following information: " + str(self.patient))
         return {"messages": [self.model.invoke([sys_msg] + state["messages"])]}
 
     def build_summarize_subgraph(self) -> StateGraph:
@@ -61,6 +67,7 @@ def build_graph(self) -> StateGraph:
         return builder
 
     def chat_inference(self, user_message: str, history: List[dict], thread_id: str):
+        config = {"configurable": {"thread_id": thread_id}}
        messages = self.convert_history_to_messages(history)
        messages.append(HumanMessage(content=user_message))
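For orientation, the patient-aware graph above can be driven directly. A minimal hypothetical sketch — assuming Django settings are loaded, a Patient row is seeded, and the three-value unpacking used in views.inference further down:

# Hypothetical driver for LLMGraph (names mirror the diff above).
# thread_id keys the MemorySaver checkpoint, so reusing it continues a conversation.
import uuid

from home.llm.llm_graph import LLMGraph

graph = LLMGraph()
history = [{"role": "user", "content": "Hi, I've been getting headaches."}]
response, summary, tools_called = graph.chat_inference(
    "Could any of my medications explain them?", history, str(uuid.uuid4())
)
print(response)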
diff --git a/home/llm/rag_graph.py b/home/llm/rag_graph.py
index b2ba84c..380d697 100644
--- a/home/llm/rag_graph.py
+++ b/home/llm/rag_graph.py
@@ -1,11 +1,12 @@
-from langgraph.graph import START, StateGraph, MessagesState, END
-from langchain_core.messages import HumanMessage, AIMessage
-from langgraph.checkpoint.memory import MemorySaver
-from langchain_community.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT
 from langchain.chains import GraphCypherQAChain
+from langchain_community.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT
 from langchain_community.graphs import Neo4jGraph
+from langchain_core.documents import Document
+from langchain_core.messages import AIMessage
+from langchain_experimental.graph_transformers import LLMGraphTransformer
 
 from home.constants.chat_models import *
+from home.constants.constants import rag_prompt_text
 from home.models.patient import Patient
 
@@ -16,36 +17,31 @@ def fetch_user_info():
 
 class RAGGraph:
     def __init__(self):
-        self.model = model_claude_3_5_sonnet
-        memory = MemorySaver()
-        self.graph = self.build_graph().compile(checkpointer=memory)
-
-    def query_knowledge_graph(self, state: MessagesState):
-        graph = Neo4jGraph()
-        chain = GraphCypherQAChain.from_llm(graph=graph, llm=self.model, verbose=True, validate_cypher=True,
-                                            allow_dangerous_requests=True, top_k=5,
+        self.patient = Patient.objects.first()
+        self.graph_query_model = model_claude_3_5_sonnet
+        self.llm_graph_transformer = LLMGraphTransformer(llm=model_claude_3_5_sonnet)
+        self.neo4j_graph = Neo4jGraph()
+
+    def rag_store_and_query(self, user_message) -> str:
+        self.__save_to_neo4j(user_message)  # Store knowledge extracted from the user message in the Neo4j graph
+        return self.__query_knowledge_graph(user_message)
+
+    def __query_knowledge_graph(self, user_message):
+        chain = GraphCypherQAChain.from_llm(graph=self.neo4j_graph, llm=self.graph_query_model, verbose=False,
+                                            validate_cypher=True, allow_dangerous_requests=True, top_k=5,
                                             cypher_prompt=CYPHER_GENERATION_PROMPT)
-        response = chain.invoke({"query": state['messages']})
-
-        return {'messages': [AIMessage(response['result'])]}
-
-    def build_graph(self) -> StateGraph:
-        builder = StateGraph(MessagesState)
-        builder.add_node("query_knowledge_graph", self.query_knowledge_graph)
-
-        builder.add_edge(START, "query_knowledge_graph")
-        builder.add_edge("query_knowledge_graph", END)
-
-        return builder
-
-    def rag_inference(self, input_text: str, thread_id: str):
-        config = {"configurable": {"thread_id": thread_id}}
-        messages = [HumanMessage(content=
-                                 "Here is a summary chat information of " + fetch_user_info() + ":\n"
-                                 + input_text +
-                                 "\nGet Medical information analysis of the following person.")]
-
-        result = self.graph.invoke({"messages": messages}, config)
-        assistant_response = result['messages'][-1].content
-        return assistant_response
+        response = chain.invoke(rag_prompt_text +
+                                " Current patient name is: " + self.patient.first_name + " " + self.patient.last_name +
+                                ". Fetch all related data of the patient and make an analysis. "
+                                "Continue with what you have.",
+                                return_only_outputs=True)['result']
+        return response
+
+    def __save_to_neo4j(self, user_message):
+        context_prompt = ('The following message is from a patient with the following information: ' +
+                          str(self.patient))
+        documents = [Document(page_content=context_prompt + user_message)]
+        graph_documents = self.llm_graph_transformer.convert_to_graph_documents(documents)
+        self.neo4j_graph.add_graph_documents(graph_documents)
+        self.neo4j_graph.refresh_schema()
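The new RAGGraph replaces the LangGraph state machine with a two-step round trip: LLMGraphTransformer extracts entities and relations from each message into Neo4j, then GraphCypherQAChain generates Cypher to read them back. A rough usage sketch, assuming NEO4J_URI, NEO4J_USERNAME, and NEO4J_PASSWORD are set (Neo4jGraph falls back to these environment variables) and a Patient row exists:

from home.llm.rag_graph import RAGGraph

rag = RAGGraph()
# Step 1 writes graph documents extracted from the message to Neo4j;
# step 2 queries the graph and summarizes the hits under rag_prompt_text.
insight = rag.rag_store_and_query("I started metformin last month and feel dizzy at night.")
print(insight)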
diff --git a/home/models/patient.py b/home/models/patient.py
index 84d3780..e97cbae 100644
--- a/home/models/patient.py
+++ b/home/models/patient.py
@@ -14,7 +14,16 @@ class Patient(models.Model):
     doctor_name = models.CharField(max_length=100)
 
     def __str__(self):
-        return f"{self.first_name} {self.last_name}"
+        return (f"Patient Name: {self.first_name} {self.last_name}, "
+                f"Date of Birth: {self.date_of_birth}, "
+                f"Phone: {self.phone_number}, "
+                f"Email: {self.email}, "
+                f"Medical Conditions: {self.medical_conditions}, "
+                f"Medication: {self.medication_regimen}, "
+                f"Last Appointment: {self.last_appointment}, "
+                f"Next Appointment: {self.next_appointment}, "
+                f"Current Doctor Name: {self.doctor_name}\n")
+
 
     class Meta:
         db_table = 'patient'
diff --git a/home/tests.py b/home/tests.py
index 8091249..1b7804a 100644
--- a/home/tests.py
+++ b/home/tests.py
@@ -35,31 +35,3 @@ def test_llm_tool_call_with_summary(self):
         ai_response, summary = self.llm_graph.chat_inference(user_message, history, "test_" + str(uuid.uuid4()))
 
         self.assertGreater(len(summary), 0)
-
-
-class KnowledgeGraphTestCase(TestCase):
-    def setUp(self):
-        self.graph = KnowledgeGraph()
-
-    def test_add_knowledge_and_query(self):
-        knowledge_text_data = knowledge_text
-        self.graph.add_knowledge_to_graph(knowledge_text_data)
-
-        user_message = "What awards Marie Curie had?"
-        response = self.graph.query_graph(user_message)
-        self.assertTrue( response.lower().__contains__( str("nobel").lower() ))
-        self.assertTrue(response.lower().__contains__(str("presidential").lower()))
-
-
-
-# class RAGTestCase(TestCase):
-#     def setUp(self):
-#         self.graph = RAGGraph()
-#
-#     def test_rag_inference(self):
-#         knowledge_text_data = knowledge_text
-#         self.graph.add_knowledge_to_graph(knowledge_text_data)
-#
-#         user_message = "What awards Marie Curie had?"
-#         response = self.graph.query_graph(user_message)
-#         self.assertTrue( response.lower().__contains__( str("nobel").lower() ))
-#         self.assertTrue(response.lower().__contains__(str("presidential").lower()))
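Since both graphs splice Patient.__str__ into their prompts, it is worth seeing the string the models actually receive. An illustrative snippet, with made-up field values:

from home.models.patient import Patient

patient = Patient(
    first_name="Jane", last_name="Doe", date_of_birth="1980-01-01",
    phone_number="555-0100", email="jane@example.com",
    medical_conditions="Type 2 diabetes", medication_regimen="Metformin 500 mg",
    doctor_name="Dr. Smith",
)
# Prints "Patient Name: Jane Doe, Date of Birth: 1980-01-01, Phone: 555-0100, ..."
print(patient)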
diff --git a/home/urls.py b/home/urls.py
index cd2222e..0811fe3 100644
--- a/home/urls.py
+++ b/home/urls.py
@@ -2,12 +2,9 @@ from . import views
 
 urlpatterns = [
-    path("chat/", views.inference, name='Patient Chat Home'),
     path('', views.index, name='Patient Chat Home'),
+    path("chat/", views.inference, name='Patient Chat Home'),
+    path('insight/', views.insight, name='insight'),
     path('user-info/', views.get_user_info, name='user_info'),
-
-    path('thread-id/', views.get_unique_thread_id, name='thread-id'),
-    path('get-rag/', views.get_rag, name='get-rag')
-
+    path('thread-id/', views.get_unique_thread_id, name='thread-id')
 ]
-
diff --git a/home/views.py b/home/views.py
index e799a3c..bf0e7ec 100644
--- a/home/views.py
+++ b/home/views.py
@@ -23,43 +23,47 @@ def index(request):
 @require_http_methods(["POST"])
 def inference(request):
     data = json.loads(request.body)
-    message = data['message']
-    history = data.get('history', [])
-    user_timestamp = data.get('timestamp')
-    thread_id = data.get('threadId')
+    message, history, user_timestamp, thread_id = (
+        data['message'],
+        data.get('history', []),
+        data.get('timestamp'),
+        data.get('threadId')
+    )
 
     llm_graph = LLMGraph()
 
     if len(history) >= summarize_trigger_count:
         summary = get_latest_summary(patient_id=1, is_user=False, thread_id=thread_id)
-        history = [
-            {
-                "role": "user",
-                "content": "Provide me with the conversation summary"
-            },
-            {
-                "role": "assistant",
-                "content": summary
-            }
+        history = [
+            {"role": "user", "content": "Provide me with the conversation summary"},
+            {"role": "assistant", "content": summary}
         ]
 
     response, summary, tools_called = llm_graph.chat_inference(message, history, thread_id)
-
     user_entry, ai_entry = save_chat_entries_db(message, response, summary, user_timestamp, thread_id)
 
-    summary = summary if summary else "No summary available yet. Chat more to get one."
-
-    # medical_insights = RAGGraph().rag_inference(summary, thread_id)
-
     return JsonResponse({
         'response': response,
         'user_timestamp': user_entry.timestamp.timestamp() * 1000,
         'ai_timestamp': ai_entry.timestamp.timestamp() * 1000,
-        'summary': summary,
+        'summary': summary or "No summary available yet. Chat more to get one.",
         'tools': tools_called
     })
 
 
+@csrf_exempt
+@require_http_methods(["POST"])
+def insight(request):
+    data = json.loads(request.body)
+    user_message = data['message']
+
+    response = RAGGraph().rag_store_and_query(user_message)
+    return JsonResponse({
+        'insight': response
+    })
+
+
 def save_chat_entries_db(user_message, ai_response, summary, user_timestamp, thread_id):
     user_entry = ChatHistory.objects.create(
         patient_id=1,
@@ -83,11 +87,8 @@ def save_chat_entries_db(user_message, ai_response, summary, user_timestamp, thread_id):
 
 def get_latest_summary(patient_id, is_user, thread_id):
     latest_chat = ChatHistory.objects.filter(
-        Q(patient_id=patient_id) &
-        Q(is_user=is_user) &
-        Q(thread_id=thread_id)
+        Q(patient_id=patient_id) & Q(is_user=is_user) & Q(thread_id=thread_id)
     ).order_by('-timestamp').first()
-
     return latest_chat.summary if latest_chat else None
 
 
@@ -106,30 +107,14 @@ def get_user_info(request):
         return JsonResponse({'error': 'No patient found'}, status=404)
 
     return JsonResponse({
-        'first_name': patient.first_name,
-        'last_name': patient.last_name,
-        'date_of_birth': patient.date_of_birth,
-        'phone_number': patient.phone_number,
-        'email': patient.email,
-        'medical_conditions': patient.medical_conditions,
-        'medication_regimen': patient.medication_regimen,
-        'last_appointment': patient.last_appointment,
-        'next_appointment': patient.next_appointment,
-        'doctor_name': patient.doctor_name,
+        field: getattr(patient, field) for field in
+        ['first_name', 'last_name', 'date_of_birth', 'phone_number', 'email',
+         'medical_conditions', 'medication_regimen', 'last_appointment',
+         'next_appointment', 'doctor_name']
    }, encoder=DateTimeEncoder)
 
 
 @csrf_exempt
 @require_http_methods(["GET"])
 def get_unique_thread_id(request):
-    return JsonResponse({
-        'thread_id': uuid.uuid4()
-    })
-
-
-@csrf_exempt
-@require_http_methods(["GET"])
-def get_rag(request):
-    return JsonResponse({
-        "data": RAGGraph().rag_inference("", uuid.uuid4())
-    })
\ No newline at end of file
+    return JsonResponse({'thread_id': uuid.uuid4()})
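The new insight endpoint can be smoke-tested in the style of home/tests.py. A hypothetical check, assuming Neo4j and model credentials are available to the test run:

import json

from django.test import Client

client = Client()  # insight() is csrf_exempt, so no token is needed
resp = client.post("/insight/",
                   json.dumps({"message": "My blood sugar runs high after breakfast."}),
                   content_type="application/json")
assert resp.status_code == 200
print(resp.json()["insight"])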
diff --git a/static/js/apiService.js b/static/js/apiService.js
index fc2d68d..487cbe9 100644
--- a/static/js/apiService.js
+++ b/static/js/apiService.js
@@ -1,57 +1,61 @@
-// apiService.js
 import { getCookie } from './utils.js';
 
-export async function sendMessageToAPI(message, chatHistory, userTimestamp, threadId) {
-    // Use array destructuring with the rest element at the end
-    const [lastElement, ...historyWithoutLast] = [...chatHistory].reverse();
-
-    const data = {
-        userType: 'patient',
-        message: message,
-        history: historyWithoutLast.reverse(), // Reverse back to original order
-        timestamp: userTimestamp,
-        threadId: threadId
-    };
-
-    const response = await fetch('/chat/', {
-        method: 'POST',
-        headers: {
-            'Content-Type': 'application/json',
-            'X-CSRFToken': getCookie('csrftoken')
-        },
-        body: JSON.stringify(data)
-    });
-
-    if (!response.ok) {
-        throw new Error(`HTTP error! status: ${response.status}`);
-    }
-
-    return await response.json();
-}
-
-export async function fetchUserInfo() {
-    try {
-        const response = await fetch('/user-info/');
-        if (!response.ok) {
-            throw new Error(`HTTP error! status: ${response.status}`);
-        }
-        return await response.json();
-    } catch (error) {
-        console.error('Error fetching user info:', error);
-        throw error;
-    }
-}
-
-
-export async function fetchThreadId() {
-    try {
-        const response = await fetch('/thread-id/');
-        if (!response.ok) {
-            throw new Error(`HTTP error! status: ${response.status}`);
-        }
-        return await response.json();
-    } catch (error) {
-        console.error('Error fetching user info:', error);
-        throw error;
-    }
-}
\ No newline at end of file
+const BASE_URL = ''; // Add your base URL here if needed
+
+const handleResponse = async (response) => {
+    if (!response.ok) {
+        throw new Error(`HTTP error! status: ${response.status}`);
+    }
+    return response.json();
+};
+
+const fetchWithErrorHandling = async (url, options = {}) => {
+    try {
+        const response = await fetch(BASE_URL + url, options);
+        return await handleResponse(response);
+    } catch (error) {
+        console.error(`Error fetching ${url}:`, error);
+        throw error;
+    }
+};
+
+export const sendMessageToAPI = async (message, chatHistory, userTimestamp, threadId) => {
+    const [lastElement, ...historyWithoutLast] = [...chatHistory].reverse();
+
+    const data = {
+        userType: 'patient',
+        message,
+        history: historyWithoutLast.reverse(),
+        timestamp: userTimestamp,
+        threadId
+    };
+
+    return fetchWithErrorHandling('/chat/', {
+        method: 'POST',
+        headers: {
+            'Content-Type': 'application/json',
+            'X-CSRFToken': getCookie('csrftoken')
+        },
+        body: JSON.stringify(data)
+    });
+};
+
+
+export const sendInsightAPI = async (message) => {
+    const data = {
+        message: message
+    };
+
+    return fetchWithErrorHandling('/insight/', {
+        method: 'POST',
+        headers: {
+            'Content-Type': 'application/json',
+            'X-CSRFToken': getCookie('csrftoken')
+        },
+        body: JSON.stringify(data)
+    });
+};
+
+export const fetchUserInfo = () => fetchWithErrorHandling('/user-info/');
+
+export const fetchThreadId = () => fetchWithErrorHandling('/thread-id/');
\ No newline at end of file
diff --git a/static/js/main.js b/static/js/main.js
index 252dbf8..5f8ed26 100644
--- a/static/js/main.js
+++ b/static/js/main.js
@@ -1,113 +1,83 @@
-import {userInfoElement, userInput, sendButton, conversationSummary, medicalInsights, actions} from './uiElements.js';
+import {
+    userInfoElement, userInput, sendButton, conversationSummary, actions, medicalInsights
+} from './uiElements.js';
 import { addMessage, updateAIMessage } from './messageHandler.js';
-import { sendMessageToAPI, fetchUserInfo, fetchThreadId } from './apiService.js';
+import { sendMessageToAPI, fetchUserInfo, fetchThreadId, sendInsightAPI } from './apiService.js';
 
 let chatHistory = [];
 let threadId = null;
 
-async function loadUserInfo() {
-    try {
-        const userInfo = await fetchUserInfo();
-        const userInfoElement = document.getElementById('user-info');
-        userInfoElement.innerHTML = `
-            Date of Birth: ${userInfo.date_of_birth}
-            Phone: ${userInfo.phone_number}
-            Email: ${userInfo.email}
-            Medical Conditions: ${userInfo.medical_conditions}
-            Medication: ${userInfo.medication_regimen}
-            Last Appointment: ${userInfo.last_appointment}
-            Next Appointment: ${userInfo.next_appointment}
-            Doctor: ${userInfo.doctor_name}
-        `;
-    } catch (error) {
-        console.error('Failed to load user info:', error);
-        // Optionally, display an error message to the user
-    }
-}
-
-async function saveThreadId() {
-    try {
-        const response = await fetchThreadId();
-        threadId = response['thread_id']
-    } catch (error) {
-        console.error('Failed to get thread id:', error);
-        // Optionally, display an error message to the user
-    }
-}
-
-async function sendMessage() {
-    const message = userInput.value.trim();
-    if (message) {
-        const userTimestamp = Date.now();
-        addMessage(message, true, userTimestamp);
-        userInput.value = '';
-
-        chatHistory.push({ role: 'user', content: message });
-
-        const aiMessageElement = addMessage('', false, null);
-
-        try {
-            const responseData = await sendMessageToAPI(message, chatHistory, userTimestamp, threadId);
-            const aiMessage = responseData.response;
-            const aiTimestamp = responseData.ai_timestamp;
-            updateAIMessage(aiMessageElement, aiMessage, aiTimestamp);
-            conversationSummary.innerText = responseData.summary;
-            console.log(responseData.tools);
-            updateActionsCards(responseData.tools);
-            console.log(actions);
-            console.log(actions.innerText);
-
-            chatHistory.push({ role: 'assistant', content: aiMessage });
-
-        } catch (error) {
-            console.error('Error:', error);
-            updateAIMessage(aiMessageElement, "Sorry, there was an error processing your request.", Date.now());
-        }
-    }
-}
+const loadUserInfo = async () => {
+    try {
+        const userInfo = await fetchUserInfo();
+        userInfoElement.innerHTML = `
+            ${Object.keys(userInfo).map(key => `
+                ${key.replace(/_/g, ' ').charAt(0).toUpperCase() + key.slice(1).replace(/_/g, ' ')}: ${userInfo[key]}
+            `)
+            .join('')}
+        `;
+    } catch (error) {
+        console.error('Failed to load user info:', error);
+    }
+};
+
+const saveThreadId = async () => {
+    try {
+        const response = await fetchThreadId();
+        threadId = response.thread_id;
+    } catch (error) {
+        console.error('Failed to get thread id:', error);
+    }
+};
+
+const sendMessage = async () => {
+    const message = userInput.value.trim();
+    if (!message) return;
+
+    const userTimestamp = Date.now();
+    addMessage(message, true, userTimestamp);
+    userInput.value = '';
+
+    chatHistory.push({ role: 'user', content: message });
+
+    const aiMessageElement = addMessage('', false, null);
+
+    try {
+        // Run both API calls in parallel
+        const [responseData, insightResponse] = await Promise.all([
+            sendMessageToAPI(message, chatHistory, userTimestamp, threadId),
+            sendInsightAPI(message)
+        ]);
+
+        updateAIMessage(aiMessageElement, responseData.response, responseData.ai_timestamp);
+        conversationSummary.innerText = responseData.summary;
+        medicalInsights.innerText = insightResponse.insight;
+        updateActionsCards(responseData.tools);
+
+        chatHistory.push({ role: 'assistant', content: responseData.response });
+    } catch (error) {
+        console.error('Error:', error);
+        updateAIMessage(aiMessageElement, "Sorry, there was an error processing your request.", Date.now());
+    }
+};
+
+const updateActionsCards = (tools) => {
+    actions.innerHTML = tools && tools.length
+        ? tools.map(tool => `
+