-
Notifications
You must be signed in to change notification settings - Fork 0
/
assistant_utils.py
183 lines (137 loc) · 5.89 KB
/
assistant_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
"""
Author: Buddhi W
Date: 07/25/2024
Functions for an AI assistant that computes the current travel duration between two locations for a given mode of travel.
"""
from openai import OpenAI
import openai
import requests
from dotenv import load_dotenv
import os
# Load variables from a local .env file into the process environment so the
# API keys read below via the os module are available.
load_dotenv()

# Module-wide OpenAI client shared by every chat-completion helper in this file.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
def get_user_input(user_query: str) -> str:
    """
    Natural language processing of the user query. Identifies relevant information and irrelevant queries.

    Parameters:
    user_query (str): Input obtained through web UI.

    Returns:
    str: 'OutOfContext' when the query is unrelated to travel duration or the model
    returns no text. Otherwise it is the model's reply, which the prompt requests as
    exactly three lines "Origin: ...", "Destination: ..." and "Mode: ...", with None
    written for any missing value.
    """
    # The reply layout is spelled out because parse_user_input reads the reply line by
    # line and splits each line on ':' to get origin, destination and mode.
    # The mode list matches the travel modes the Distance Matrix request accepts.
    extraction_request = (
        f"Extract the origin location, destination location and mode of travel from this query: {user_query}\n"
        "Reply with exactly these three lines and nothing else:\n"
        "Origin: <origin location>\n"
        "Destination: <destination location>\n"
        "Mode: <mode of travel>\n"
        "The mode of travel must be one of: driving, walking, bicycling, transit. "
        "Write None for any value that is missing from the query. "
        "If the query is unrelated to travel duration, reply with only the word OutOfContext."
    )
    # OpenAI API call; temperature 0 makes the extraction output more consistent.
    completion = client.chat.completions.create(
        model="gpt-4o-mini",
        temperature=0,
        messages=[
            {"role": "system", "content": "You are an assistant that processes user queries about current travel duration between two locations for a given mode of transportation"},
            {"role": "user", "content": extraction_request},
        ],
    )
    # message.content may be None (e.g. a refusal). With nothing to parse, the query is
    # treated as out of context instead of crashing on None.strip().
    content = completion.choices[0].message.content or ""
    return content.strip() or "OutOfContext"
def parse_user_input(user_query: str) -> (tuple[str, str, str] | str):
    """
    Process user query and extract origin, destination and mode of travel. Handle irrelevant queries.

    Parameters:
    user_query (str): Input obtained through web UI.

    Returns:
    tuple | str: (origin, destination, mode) as strings, with the string 'None' for any
    value that could not be extracted. Returns the string 'OutOfContext' when the query
    is unrelated to travel duration.
    """
    # API call reads the user query and returns origin, destination and mode in a single string
    reply = get_user_input(user_query)
    # Accept the marker even if the model wraps it in quotes, backticks, bold or a period.
    if reply.strip(" \t\r\n'\"`.*").casefold() == 'outofcontext':
        return 'OutOfContext'
    return _parse_extraction_reply(reply)


def _clean_field_value(raw: str) -> str:
    """Strip markdown/quote decoration from one extracted value; map empty or null-like values to 'None'."""
    value = raw.strip().strip("*\"'`").strip().rstrip(".")
    # get_travel_duration tests for the literal string 'None' to detect a missing value.
    if value.casefold() in ("", "none", "null", "n/a"):
        return 'None'
    return value


def _parse_extraction_reply(reply: str) -> tuple[str, str, str]:
    """Pull (origin, destination, mode) out of 'Label: value' lines, matching by label first and line order second."""
    labelled = []  # (normalized label, cleaned value) for every line containing ':'
    for line in reply.splitlines():
        # partition splits only at the first ':' so values that contain ':' stay intact.
        label, separator, value = line.partition(':')
        if separator:
            labelled.append((label.strip(" -*#`").casefold(), _clean_field_value(value)))

    field_keywords = (("origin", "start"), ("destination",), ("mode",))
    fields = []
    for position, keywords in enumerate(field_keywords):
        value = next(
            (val for label, val in labelled if any(word in label for word in keywords)),
            None,
        )
        if value is None and position < len(labelled):
            # Unrecognized labels: fall back to the origin/destination/mode line order.
            value = labelled[position][1]
        fields.append('None' if value is None else value)
    return fields[0], fields[1], fields[2]
def get_travel_duration(
    origin: str, destination: str, mode: str, timeout: float = 10.0
) -> tuple[str | None, list[str]]:
    """
    Compute travel duration using the Google Maps Distance Matrix API.

    Parameters:
    origin (str): Starting location.
    destination (str): Travel destination.
    mode (str): Mode of travel. Matched case-insensitively against driving, walking,
        bicycling and transit. Any other value, including the placeholder 'None', is
        reported as a missing mode.
    timeout (float): Seconds to wait for the Google Maps response before requests
        raises a Timeout. Defaults to 10.

    Returns:
    tuple: (duration, error_messages). duration is the duration text from the
    response, or None when it could not be computed. error_messages holds
    user-facing explanations and is non-empty whenever duration is None.

    Raises:
    requests.RequestException: If the HTTP request fails or times out.
    """
    url = "https://maps.googleapis.com/maps/api/distancematrix/json"
    supported_modes = {"driving", "walking", "bicycling", "transit"}
    normalized_mode = (mode or "").strip().lower()
    mode_is_valid = normalized_mode in supported_modes

    params = {
        "origins": origin,
        "destinations": destination,
        "key": os.getenv("GOOGLE_MAPS_API_KEY"),
    }
    # Only forward a mode from the supported set. An unrecognized value could get the
    # whole request rejected, which would hide whether the two addresses were valid.
    if mode_is_valid:
        params["mode"] = normalized_mode

    # Google Maps API call; the timeout keeps a stalled connection from hanging the caller.
    response = requests.get(url, params=params, timeout=timeout).json()
    status = response.get("status")

    # Key, quota and server-side failures are not problems with the user's query.
    if status not in ("OK", "INVALID_REQUEST"):
        return None, ["The travel duration service could not process the request. Please try again later."]

    # Handling incomplete queries and recording corresponding error messages.
    # The address lists may come back empty when the request is rejected, so a missing
    # entry is treated like an address that could not be resolved (instead of IndexError).
    error_messages = []
    if not (response.get("destination_addresses") or [""])[0]:
        error_messages.append('Please provide valid destination address.')
    if not (response.get("origin_addresses") or [""])[0]:
        error_messages.append('Please provide valid starting address.')
    if not mode_is_valid:
        error_messages.append('Please provide valid mode of transportation')

    # Computing duration only for fully valid inputs.
    if status == "OK" and mode_is_valid:
        rows = response.get("rows") or [{}]
        element = (rows[0].get("elements") or [{}])[0]
        if element.get("status") == "OK":
            return element["duration"]["text"], error_messages

    if not error_messages:
        # e.g. both addresses resolved but no route exists for this mode (ZERO_RESULTS);
        # without this, callers would get an empty error list and nothing to report.
        error_messages.append("No route could be found between the given locations for this mode of travel.")
    return None, error_messages
def generate_output(origin: str, destination: str, mode: str, duration: str) -> str:
    """
    Turn a successfully computed trip into a short natural-language answer.

    Parameters:
    origin (str): Starting location.
    destination (str): Travel destination.
    mode (str): Mode of travel.
    duration (str): Duration of travel.

    Returns:
    str: The chat model's reply text, returned unmodified.
    """
    system_prompt = "You are an assistant that combines given four inputs into a short, clear, to the point natural language output"
    user_prompt = (
        "Combine following information into a response: "
        f"Origin location: {origin}, destination location: {destination}, "
        f"travel mode:{mode} and duration: {duration}."
    )
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    # A single chat completion phrases the four facts as one response.
    reply = client.chat.completions.create(model="gpt-4o-mini", messages=conversation)
    return reply.choices[0].message.content
def generate_output_error(error_messages: list) -> str:
    """
    Ask the chat model to merge validation errors into one user-facing message.

    Parameters:
    error_messages (list[str]): Messages describing the invalid or missing parts of the query.

    Returns:
    str: The chat model's reply text, returned unmodified.
    """
    # The list is interpolated with its Python repr, brackets and quotes included.
    conversation = [
        {"role": "system", "content": "You are an assistant that lets the user know about invalid inputs."},
        {"role": "user", "content": f"Combine the given error messages: {error_messages} into a clear and concise error message"},
    ]
    response = client.chat.completions.create(messages=conversation, model="gpt-4o-mini")
    return response.choices[0].message.content
def run_assistant(user_query: str) -> str:
    """
    Run the full assistant pipeline for one query.

    Steps: extract origin/destination/mode, look up the travel duration, then phrase
    either the result or the collected errors in natural language.

    Parameters:
    user_query (str): Input obtained through web UI.

    Returns:
    str: Text to display on the web UI.
    """
    parsed = parse_user_input(user_query)
    if parsed == 'OutOfContext':
        return "I'm sorry, I did not understand your question. Please input a query related to travel duration calculation."

    origin, destination, mode = parsed[0], parsed[1], parsed[2]
    duration, errors = get_travel_duration(origin, destination, mode)
    if not duration:
        # No duration could be computed, so explain what was wrong with the query instead.
        return generate_output_error(errors)
    return generate_output(origin, destination, mode, duration)