Skip to content

Commit

Permalink
feat: messages for map actions
Browse files Browse the repository at this point in the history
fix: code cleanup
  • Loading branch information
iisakkirotko committed Nov 9, 2023
1 parent 06e15bc commit 9d59157
Showing 1 changed file with 38 additions and 56 deletions.
94 changes: 38 additions & 56 deletions wanderlust.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import ipyleaflet
from openai import OpenAI, NotFoundError
from openai.types.beta import Thread
from openai.types.beta.threads import Run

import time

Expand All @@ -13,9 +12,7 @@
center_default = (0, 0)
zoom_default = 2

messages_default = []

messages = solara.reactive(messages_default)
messages = solara.reactive([])
zoom_level = solara.reactive(zoom_default)
center = solara.reactive(center_default)
markers = solara.reactive([])
Expand All @@ -25,6 +22,7 @@
model = "gpt-4-1106-preview"


# Declare tools for openai assistant to use
tools = [
{
"type": "function",
Expand Down Expand Up @@ -80,7 +78,6 @@


def update_map(longitude, latitude, zoom):
print("update_map", longitude, latitude, zoom)
center.set((latitude, longitude))
zoom_level.set(zoom)
return "Map updated"
Expand Down Expand Up @@ -111,12 +108,9 @@ def ai_call(tool_call):

@solara.component
def Map():
print("Map", zoom_level.value, center.value, markers.value)
ipyleaflet.Map.element( # type: ignore
zoom=zoom_level.value,
# on_zoom=zoom_level.set,
center=center.value,
# on_center=center.set,
scroll_wheel_zoom=True,
layers=[
ipyleaflet.TileLayer.element(url=url),
Expand All @@ -134,7 +128,6 @@ def ChatInterface():
run_id: solara.Reactive[str] = solara.use_reactive(None)

thread: Thread = solara.use_memo(openai.beta.threads.create, dependencies=[])
print("thread id:", thread.id)

def add_message(value: str):
if value == "":
Expand All @@ -149,7 +142,6 @@ def add_message(value: str):
assistant_id="asst_RqVKAzaybZ8un7chIwPCIQdH",
tools=tools,
).id
print("Run id:", run_id.value)

def poll():
if not run_id.value:
Expand All @@ -159,14 +151,16 @@ def poll():
try:
run = openai.beta.threads.runs.retrieve(
run_id.value, thread_id=thread.id
) # When run is complete
)
# Above will raise NotFoundError when run creation is still in progress
except NotFoundError:
continue
if run.status == "requires_action":
tool_outputs = []
for tool_call in run.required_action.submit_tool_outputs.tool_calls:
tool_output = ai_call(tool_call)
tool_outputs.append(tool_output)
messages.set([*messages.value, tool_output])
openai.beta.threads.runs.submit_tool_outputs(
thread_id=thread.id,
run_id=run_id.value,
Expand All @@ -182,27 +176,10 @@ def poll():
run_id.set(None)
completed = True
time.sleep(0.1)
retrieved_messages = openai.beta.threads.messages.list(thread_id=thread.id)
messages.set(retrieved_messages.data)

result = solara.use_thread(poll, dependencies=[run_id.value])

def handle_message(message):
print("handle", message)
messages = []
if message.role == "assistant":
tools_calls = message.get("tool_calls", [])
for tool_call in tools_calls:
messages.append(ai_call(tool_call))
return messages

def handle_initial():
print("handle initial", messages.value)
for message in messages.value:
handle_message(message)

solara.use_effect(handle_initial, [])
# result = solara.use_thread(ask, dependencies=[messages.value])
# Create DOM for chat interface
with solara.Column(
classes=["chat-interface"],
):
Expand All @@ -214,16 +191,25 @@ def handle_initial():
"overflow-y": "auto",
"height": "100px",
"flex-direction": "column-reverse",
}
},
classes=["chat-box"],
):
for message in reversed(messages.value):
with solara.Row(style={"align-items": "flex-start"}):
if message.role == "user":
# Catch "messages" that are actually tool calls
if isinstance(message, dict):
icon = (
"mdi-map"
if message["output"] == "Map updated"
else "mdi-map-marker"
)
solara.v.Icon(children=[icon], style_="padding-top: 10px;")
solara.Markdown(message["output"])
elif message.role == "user":
solara.Text(
message.content[0].text.value,
classes=["chat-message", "user-message"],
)
assert len(message.content) == 1
elif message.role == "assistant":
if message.content[0].text.value:
solara.v.Icon(
Expand All @@ -246,8 +232,6 @@ def handle_initial():
repr(message),
classes=["chat-message", "assistant-message"],
)
elif message["role"] == "tool":
pass # no need to display
else:
solara.v.Icon(
children=["mdi-compass-outline"],
Expand All @@ -272,21 +256,6 @@ def handle_initial():

@solara.component
def Page():
reset_counter, set_reset_counter = solara.use_state(0)
print("reset", reset_counter, f"chat-{reset_counter}")

def reset_ui():
set_reset_counter(reset_counter + 1)

def save():
with open("log.json", "w") as f:
json.dump(messages.value, f)

def load():
with open("log.json", "r") as f:
messages.set(json.load(f))
reset_ui()

with solara.Column(
classes=["ui-container"],
gap="5vh",
Expand All @@ -299,16 +268,12 @@ def load():
unsafe_innerHTML="Wanderlust",
style={"display": "inline-block"},
)
# with solara.Row(gap="10px"):
# solara.Button("Save", on_click=save)
# solara.Button("Load", on_click=load)
# solara.Button("Soft reset", on_click=reset_ui)
with solara.Row(
justify="space-between", style={"flex-grow": "1"}, classes=["container-row"]
):
ChatInterface().key(f"chat-{reset_counter}")
ChatInterface() # .key(f"chat-{reset_counter}")
with solara.Column(classes=["map-container"]):
Map() # .key(f"map-{reset_counter}")
Map()

solara.Style(
"""
Expand All @@ -335,13 +300,30 @@ def load():
height: 100%;
width: 38vw;
justify-content: center;
background: linear-gradient(0deg, transparent 75%, white 100%);
position: relative;
}
.chat-interface:after {
content: "";
position: absolute;
z-index: 1;
top: 0;
left: 0;
pointer-events: none;
background-image: linear-gradient(to top, rgba(255,255,255,0), rgba(255,255,255, 1) 100%);
width: 100%;
height: 15%;
}
.chat-box > :last-child{
padding-top: 7.5vh;
}
.map-container{
width: 50vw;
height: 100%;
justify-content: center;
}
.user-message{
font-weight: bold;
}
@media screen and (max-aspect-ratio: 1/1) {
.ui-container{
padding: 30px;
Expand Down

0 comments on commit 9d59157

Please sign in to comment.