-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b3a7b53
commit d07002e
Showing
9 changed files
with
529 additions
and
13 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from langchain_anthropic import ChatAnthropic | ||
|
||
|
||
class LLM:
    """Wraps an Anthropic chat model with the patient-request tools bound to it."""

    def __init__(self, model_name: str = "claude-3-haiku-20240307"):
        """Instantiate the chat model and bind the tool functions.

        Args:
            model_name: Anthropic model identifier to instantiate.
        """
        self.model_name = model_name
        # Keep a handle to the raw (unbound) model too; the original dropped it
        # in a local, making it unreachable after __init__.
        self.llm = ChatAnthropic(model=model_name)

        # Tool functions are defined at module level below; LangChain derives
        # each tool's schema from the function signature and docstring.
        self.tools = [request_medication_change, make_appointment, request_appointment_change]
        self.llm_with_tools = self.llm.bind_tools(self.tools)
|
||
|
||
def request_medication_change(previous_medication: str) -> str:
    """Puts a request to the doctor for medication change.

    Returns a string with the name of current medication and Change request submitted.

    Args:
        previous_medication: name of the medication currently prescribed
    """
    return f"Change request submitted for {previous_medication}"
|
||
|
||
def make_appointment(date: str, reason: str) -> str:
    """Puts an appointment request on the specified date and reason.

    Returns a string with the reason for the requested appointment and the
    date for the appointment.

    Args:
        date: requested date for the appointment
        reason: reason for seeing the doctor
    """
    return f"Appointment requested on {date} for {reason}"
|
||
|
||
def request_appointment_change(past_date: str, requested_date: str) -> str:
    """Puts an appointment change request from a past date to a requested date.

    Returns a confirmation string naming the previous and the requested date,
    in the same "request submitted" style as the sibling tools — the original
    returned only the two bare dates, which told the model nothing about the
    outcome.

    Args:
        past_date: date of the currently scheduled appointment
        requested_date: date the patient wants to move the appointment to
    """
    return f"Appointment change requested from {past_date} to {requested_date}"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from langgraph.graph import START, StateGraph | ||
from langgraph.graph.state import CompiledStateGraph | ||
from langgraph.prebuilt import tools_condition, ToolNode | ||
|
||
from langgraph.graph import MessagesState | ||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage | ||
|
||
from home.llm.llm import LLM | ||
|
||
|
||
class LLMGraph:
    # System prompt injected ahead of the conversation on every assistant turn;
    # restricts the model to health-related topics.
    prompt_text = '''
    You are a helpful AI medical assistant. You should only respond to health-related topics such as:
    - General human health and lifestyle inquiries.
    - Questions about the patient’s medical condition, medication regimen, diet, etc.
    - Various requests from the patient to their doctor such as make appointments, modify appointments and medication changes.
    You should filter out and ignore any unrelated, sensitive, or controversial topics.
    '''

    def __init__(self):
        """Create the tool-bound LLM wrapper and compile the conversation graph."""
        self.llm = LLM()
        self.graph = self.build_graph()

    def assistant(self, state: MessagesState):
        """Graph node: invoke the tool-bound model on the system prompt plus history.

        Args:
            state: current graph state carrying the running message list.

        Returns:
            A dict with a "messages" list holding the model's new response,
            which langgraph appends to the state.
        """
        # Prompt message
        sys_msg = SystemMessage(content=self.prompt_text)
        return {"messages": [self.llm.llm_with_tools.invoke([sys_msg] + state["messages"])]}

    def build_graph(self) -> CompiledStateGraph:
        """Assemble and compile the assistant/tools state graph.

        Control flow loops assistant -> tools -> assistant until the model
        stops emitting tool calls, at which point tools_condition routes to END.
        """
        builder = StateGraph(MessagesState)

        # Define nodes: these do the work
        builder.add_node("assistant", self.assistant)
        builder.add_node("tools", ToolNode(self.llm.tools))

        # Define edges: these determine how the control flow moves
        builder.add_edge(START, "assistant")
        builder.add_conditional_edges(
            "assistant",
            # If the latest message (result) from assistant is a tool call -> tools_condition routes to tools
            # If the latest message (result) from assistant is a not a tool call -> tools_condition routes to END
            tools_condition,
        )
        builder.add_edge("tools", "assistant")

        return builder.compile()

    def inference(self, user_message: str) -> str:
        """Run one user message through the graph and return the final reply text.

        Args:
            user_message: the user's message as plain text.

        Returns:
            The content of the last message produced by the graph run.
        """
        messages = [HumanMessage(content=user_message)]
        messages = self.graph.invoke({"messages": messages})
        # Debug aid: dump the full exchange (including tool calls) to stdout.
        for m in messages['messages']:
            m.pretty_print()

        print()
        return messages['messages'][-1].content
Oops, something went wrong.