Skip to content

Commit

Permalink
Chat history context added
Browse files Browse the repository at this point in the history
  • Loading branch information
sajedjalil committed Sep 27, 2024
1 parent a2dbd75 commit cefd176
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 80 deletions.
55 changes: 27 additions & 28 deletions home/llm/llm_graph.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,55 @@
from langgraph.graph import START, StateGraph
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage

from langgraph.graph import MessagesState
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

import logging
from home.llm.llm import LLM

logger = logging.getLogger(__name__)


class LLMGraph:
prompt_text = '''
You are a helpful AI medical assistant. You should only respond to health-related topics such as:
- General human health and lifestyle inquiries.
- Questions about the patient’s medical condition, medication regimen, diet, etc.
- Various requests from the patient to their doctor such as make appointments, modify appointments and medication changes.
You should filter out and ignore any unrelated, sensitive, or controversial topics.
'''
prompt_text = '''You are a helpful AI medical assistant namely Patient Chat and are developed by a software
engineer named Sajed.
You should only respond to health-related topics such as:
- General human health and lifestyle inquiries.
- Questions about men, women and children health
- Questions about the patient's medical condition, medication regimen, diet, etc.
- Various requests from the patient to their doctor such as make appointments, modify appointments and medication changes.
You should filter out and ignore any unrelated, overly sensitive, or controversial topics.'''

def __init__(self):
self.llm = LLM()
self.graph = self.build_graph()


def assistant(self, state: MessagesState):
# Prompt message
sys_msg = SystemMessage(content=self.prompt_text)
return {"messages": [self.llm.llm_with_tools.invoke([sys_msg] + state["messages"])]}

def build_graph(self) -> CompiledStateGraph:
builder = StateGraph(MessagesState)

# Define nodes: these do the work
builder.add_node("assistant", self.assistant)
builder.add_node("tools", ToolNode(self.llm.tools))

# Define edges: these determine how the control flow moves
builder.add_edge(START, "assistant")
builder.add_conditional_edges(
"assistant",
# If the latest message (result) from assistant is a tool call -> tools_condition routes to tools
# If the latest message (result) from assistant is a not a tool call -> tools_condition routes to END
tools_condition,
)
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")

return builder.compile()

def inference(self, user_message) -> str:
messages = [HumanMessage(content=user_message)]
messages = self.graph.invoke({"messages": messages})
for m in messages['messages']:
m.pretty_print()
def inference(self, user_message, history) -> str:
messages = []
for msg in history:
if msg['role'] == 'user':
messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant':
messages.append(AIMessage(content=msg['content']))

messages.append(HumanMessage(content=user_message))

result = self.graph.invoke({"messages": messages})
logger.debug(result)

print()
return messages['messages'][-1].content
return result['messages'][-1].content
26 changes: 9 additions & 17 deletions home/views.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,25 @@
from django.shortcuts import render, HttpResponse

from django.shortcuts import render
from home.llm.llm_graph import LLMGraph


def inference(request):
return HttpResponse(LLMGraph().inference("How to avoid diabetes."))
from django.http import JsonResponse
from django.views.decorators.csrf import csrf_exempt
import json


def index(request):
return render(request, 'index.html')


from django.http import JsonResponse
from django.views.decorators.csrf import ensure_csrf_cookie
import json


@ensure_csrf_cookie
@csrf_exempt
def inference(request):
if request.method == 'POST':
data = json.loads(request.body)
user_type = data.get('userType')
message = data.get('message')
history = data.get('history')
history = data.get('history', [])

# Process the message and generate a response
# This is where you'd call your AI model or processing logic
response = LLMGraph().inference(message)
llm_graph = LLMGraph()
response = llm_graph.inference(message, history)

return JsonResponse({'response': response})
else:
return JsonResponse({'error': 'Only POST requests are allowed'}, status=405)
return JsonResponse({'error': 'There was an error in the request to the server.'}, status=405)
91 changes: 87 additions & 4 deletions static/css/style.css
Original file line number Diff line number Diff line change
Expand Up @@ -5,45 +5,63 @@ body {
display: flex;
justify-content: center;
align-items: center;
height: 100vh;
min-height: 100vh;
background-color: #f0f0f0;
}

.chat-container {
width: 80%;
width: 100%;
max-width: 600px;
height: 100vh;
max-height: 800px;
display: flex;
flex-direction: column;
background-color: white;
border-radius: 8px;
box-shadow: 0 0 10px rgba(0,0,0,0.1);
overflow: hidden;
}

.chat-messages {
height: 400px;
flex-grow: 1;
overflow-y: auto;
padding: 20px;
display: flex;
flex-direction: column;
}

.message {
max-width: 80%;
margin-bottom: 10px;
padding: 10px;
border-radius: 5px;
word-wrap: break-word;
}

.user-message {
background-color: #e6f2ff;
text-align: right;
align-self: flex-end;
}

.ai-message {
background-color: #f0f0f0;
align-self: flex-start;
}

.input-area {
display: flex;
padding: 20px;
border-top: 1px solid #e0e0e0;
}

#user-input {
flex-grow: 1;
padding: 10px;
border: 1px solid #ccc;
border-radius: 4px;
font-size: 16px;
}

#send-button {
padding: 10px 20px;
background-color: #007bff;
Expand All @@ -52,4 +70,69 @@ body {
border-radius: 4px;
margin-left: 10px;
cursor: pointer;
font-size: 16px;
}

/* Markdown Styles */
.ai-message p {
margin: 0 0 10px 0;
}

.ai-message h1, .ai-message h2, .ai-message h3, .ai-message h4, .ai-message h5, .ai-message h6 {
margin: 15px 0 10px 0;
}

.ai-message ul, .ai-message ol {
margin: 10px 0;
padding-left: 20px;
}

.ai-message code {
background-color: #f8f8f8;
padding: 2px 4px;
border-radius: 4px;
}

.ai-message pre {
background-color: #f8f8f8;
padding: 10px;
border-radius: 4px;
overflow-x: auto;
}

.ai-message blockquote {
border-left: 4px solid #ccc;
margin: 10px 0;
padding-left: 10px;
color: #666;
}

/* Responsive adjustments */
@media (max-width: 768px) {
body {
align-items: flex-start;
}

.chat-container {
width: 100%;
height: 100vh;
max-height: none;
border-radius: 0;
}

.message {
max-width: 90%;
}
}

@media (max-width: 480px) {
.input-area {
flex-direction: column;
}

#send-button {
margin-left: 0;
margin-top: 10px;
width: 100%;
}
}
61 changes: 31 additions & 30 deletions static/js/script.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,25 @@ function addMessage(message, isUser) {
const messageElement = document.createElement('div');
messageElement.classList.add('message');
messageElement.classList.add(isUser ? 'user-message' : 'ai-message');
messageElement.textContent = message;

if (isUser) {
messageElement.innerHTML = marked.parse(message);
} else {
messageElement.innerHTML = '<div class="typing-indicator">AI is typing...</div>';
}

chatMessages.appendChild(messageElement);
chatMessages.scrollTop = chatMessages.scrollHeight;
return messageElement;
}

function updateAIMessage(messageElement, content) {
const typingIndicator = messageElement.querySelector('.typing-indicator');
if (typingIndicator) {
typingIndicator.remove();
}
messageElement.innerHTML = marked.parse(content);
chatMessages.scrollTop = chatMessages.scrollHeight;
}

async function sendMessage() {
Expand All @@ -19,19 +35,22 @@ async function sendMessage() {
addMessage(message, true);
userInput.value = '';

// Prepare the data to send to the API
// Add user message to chat history
chatHistory.push({ role: 'user', content: message });

const data = {
userType: 'patient', // You can change this as needed
userType: 'patient',
message: message,
history: chatHistory
};

const aiMessageElement = addMessage('', false);

try {
const response = await fetch('/chat/', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'X-CSRFToken': getCookie('csrftoken') // Function to get CSRF token
'Content-Type': 'application/json'
},
body: JSON.stringify(data)
});
Expand All @@ -40,18 +59,16 @@ async function sendMessage() {
throw new Error(`HTTP error! status: ${response.status}`);
}

const result = await response.json();
const responseData = await response.json();
const aiMessage = responseData.response;
updateAIMessage(aiMessageElement, aiMessage);

// Add the AI's response to the chat
addMessage(result.response, false);

// Update chat history
chatHistory.push({ role: 'user', content: message });
chatHistory.push({ role: 'assistant', content: result.response });
// Add AI message to chat history
chatHistory.push({ role: 'assistant', content: aiMessage });

} catch (error) {
console.error('Error:', error);
addMessage("Sorry, there was an error processing your request.", false);
updateAIMessage(aiMessageElement, "Sorry, there was an error processing your request.");
}
}
}
Expand All @@ -61,20 +78,4 @@ userInput.addEventListener('keypress', (e) => {
if (e.key === 'Enter') {
sendMessage();
}
});

// Function to get CSRF token from cookies
function getCookie(name) {
let cookieValue = null;
if (document.cookie && document.cookie !== '') {
const cookies = document.cookie.split(';');
for (let i = 0; i < cookies.length; i++) {
const cookie = cookies[i].trim();
if (cookie.substring(0, name.length + 1) === (name + '=')) {
cookieValue = decodeURIComponent(cookie.substring(name.length + 1));
break;
}
}
}
return cookieValue;
}
});
3 changes: 2 additions & 1 deletion templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>AI Doctor</title>
<title>AI Chat Application</title>
<link rel="stylesheet" href="{% static 'css/style.css' %}">
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
</head>
<body>
<div class="chat-container">
Expand Down

0 comments on commit cefd176

Please sign in to comment.