-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpersonal_assistant.py
101 lines (90 loc) · 3.38 KB
/
personal_assistant.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
import re
import xml.etree.ElementTree as ET
from anthropic import Anthropic
from personal_assistant_utils import *
from personal_assistant_tools import TOOLS
# API client shared by every request in this script. The key is read from the
# ANTHROPIC environment variable (None is passed through when it is unset).
client = Anthropic(api_key=os.getenv("ANTHROPIC"))

# Model used for all requests.
# Alternative kept for reference: "claude-3-sonnet-20240229"
MODEL_NAME = "claude-3-opus-20240229"

# Running conversation log; handed to add_to_memory() on each turn.
memory = []
# Matches one complete <function_calls> block. DOTALL lets the block span
# several lines; searching (rather than matching the whole reply) skips any
# free-form text the model wrote around the block.
FUNCTION_CALLS_PATTERN = re.compile(r'<function_calls>.*?</function_calls>', re.DOTALL)


def close_function_calls(reply):
    """Return *reply* with a missing ``</function_calls>`` tag re-attached.

    ``</function_calls>`` is used as a stop sequence, so a reply that opens a
    call block arrives without its closing tag. The tag is appended only when
    a block is actually left open, so plain-text answers are returned
    unchanged and an already closed block is never closed twice.
    """
    if reply.count("<function_calls>") > reply.count("</function_calls>"):
        return reply + "</function_calls>"
    return reply


def parse_tool_calls(reply):
    """Extract ``(tool_name, user_input)`` pairs from the call blocks in *reply*.

    For every ``<function_calls>`` block the first ``tool_name`` and
    ``user_input`` descendants are read. Blocks that are not well-formed XML,
    or that lack either element, are skipped instead of raising.
    """
    calls = []
    for block in FUNCTION_CALLS_PATTERN.findall(reply):
        try:
            root = ET.fromstring(block)
        except ET.ParseError:
            # The model produced malformed XML; there is nothing to run.
            continue
        name_node = root.find('.//tool_name')
        input_node = root.find('.//user_input')
        if name_node is None or input_node is None or not name_node.text:
            continue
        # Strip the name so stray whitespace around it does not defeat the
        # lookup; the input is passed on exactly as written.
        calls.append((name_node.text.strip(), input_node.text))
    return calls


def run_tool_calls(calls, tool_functions):
    """Run every call whose tool name is a key of *tool_functions*.

    Args:
        calls: iterable of ``(tool_name, user_input)`` pairs.
        tool_functions: mapping of tool name to a one-argument callable.

    Returns:
        A list of ``{'tool_name': ..., 'tool_result': ...}`` dicts, one per
        executed call, in call order. Unknown tool names are ignored.
    """
    results = []
    for tool_name, user_input in calls:
        func = tool_functions.get(tool_name)
        if func is None:
            continue
        results.append({
            'tool_name': tool_name,
            'tool_result': func(user_input),
        })
    return results


def main():
    """Run the interactive prompt loop until the user enters ``q``.

    Each turn sends the question to the model. If the reply requests a known
    tool, the tool is run, its result is appended to the reply, and the model
    is asked to continue from that partial assistant message. If no tool was
    run, the first reply is printed as the answer.
    """
    # Built once: neither the tool table nor the system prompt depends on the
    # question being asked.
    tool_functions = {
        "wikipedia_search": wikipedia_search,
        "duckduckgo_search": duckduckgo_search,
        "save_note": save_note,
        "code_executer": code_executer,
        "calculator": do_pairwise_arithmetic,
        "search_youtube": search_youtube,
        "extract_text_from_file": extract_text_from_file,
    }
    names = [tool.name for tool in TOOLS]
    descriptions = [tool.description for tool in TOOLS]
    parameters = [tool.parameters for tool in TOOLS]
    all_tools = construct_format_tool_for_claude_prompt(names, descriptions, parameters)
    system_prompt = construct_tool_use_system_prompt([all_tools])

    while True:
        question = input("Ask Claude: ")
        if question.lower() == "q":
            break
        add_to_memory("user", question, memory)
        message = {
            "role": "user",
            "content": question,
        }

        reply = client.messages.create(
            model=MODEL_NAME,
            max_tokens=1024,
            messages=[message],
            system=system_prompt,
            stop_sequences=["\nHuman:", "\nAssistant", "</function_calls>"],
        ).content[0].text
        function_calling_message = close_function_calls(reply)

        tool_results = run_tool_calls(
            parse_tool_calls(function_calling_message), tool_functions
        )
        if not tool_results:
            # No tool ran (plain answer, unknown tool, or unparseable call):
            # there is nothing to inject, so the first reply is the answer.
            print(function_calling_message)
            continue

        for entry in tool_results:
            add_to_memory("assistant", entry['tool_result'], memory)

        function_results = construct_successful_function_run_injection_prompt(tool_results)
        # The call block is already closed, so only the results are appended.
        partial_assistant_message = function_calling_message + function_results

        final_message = client.messages.create(
            model=MODEL_NAME,
            max_tokens=1024,
            messages=[
                message,
                {
                    "role": "assistant",
                    "content": partial_assistant_message,
                },
            ],
            system=system_prompt,
        ).content[0].text
        print(partial_assistant_message + final_message)


if __name__ == "__main__":
    main()