From 9d591573bd0b085bdc4c33863cc8c8ba8f751f46 Mon Sep 17 00:00:00 2001 From: Iisakki Rotko Date: Thu, 9 Nov 2023 12:12:22 +0100 Subject: [PATCH] feat: messages for map actions fix: code cleanup --- wanderlust.py | 94 +++++++++++++++++++++------------------------------ 1 file changed, 38 insertions(+), 56 deletions(-) diff --git a/wanderlust.py b/wanderlust.py index e0641af..0b75466 100644 --- a/wanderlust.py +++ b/wanderlust.py @@ -4,7 +4,6 @@ import ipyleaflet from openai import OpenAI, NotFoundError from openai.types.beta import Thread -from openai.types.beta.threads import Run import time @@ -13,9 +12,7 @@ center_default = (0, 0) zoom_default = 2 -messages_default = [] - -messages = solara.reactive(messages_default) +messages = solara.reactive([]) zoom_level = solara.reactive(zoom_default) center = solara.reactive(center_default) markers = solara.reactive([]) @@ -25,6 +22,7 @@ model = "gpt-4-1106-preview" +# Declare tools for openai assistant to use tools = [ { "type": "function", @@ -80,7 +78,6 @@ def update_map(longitude, latitude, zoom): - print("update_map", longitude, latitude, zoom) center.set((latitude, longitude)) zoom_level.set(zoom) return "Map updated" @@ -111,12 +108,9 @@ def ai_call(tool_call): @solara.component def Map(): - print("Map", zoom_level.value, center.value, markers.value) ipyleaflet.Map.element( # type: ignore zoom=zoom_level.value, - # on_zoom=zoom_level.set, center=center.value, - # on_center=center.set, scroll_wheel_zoom=True, layers=[ ipyleaflet.TileLayer.element(url=url), @@ -134,7 +128,6 @@ def ChatInterface(): run_id: solara.Reactive[str] = solara.use_reactive(None) thread: Thread = solara.use_memo(openai.beta.threads.create, dependencies=[]) - print("thread id:", thread.id) def add_message(value: str): if value == "": @@ -149,7 +142,6 @@ def add_message(value: str): assistant_id="asst_RqVKAzaybZ8un7chIwPCIQdH", tools=tools, ).id - print("Run id:", run_id.value) def poll(): if not run_id.value: @@ -159,7 +151,8 @@ def poll(): try: run = openai.beta.threads.runs.retrieve( run_id.value, thread_id=thread.id - ) # When run is complete + ) + # Above will raise NotFoundError when run creation is still in progress except NotFoundError: continue if run.status == "requires_action": @@ -167,6 +160,7 @@ def poll(): for tool_call in run.required_action.submit_tool_outputs.tool_calls: tool_output = ai_call(tool_call) tool_outputs.append(tool_output) + messages.set([*messages.value, tool_output]) openai.beta.threads.runs.submit_tool_outputs( thread_id=thread.id, run_id=run_id.value, @@ -182,27 +176,10 @@ def poll(): run_id.set(None) completed = True time.sleep(0.1) - retrieved_messages = openai.beta.threads.messages.list(thread_id=thread.id) - messages.set(retrieved_messages.data) result = solara.use_thread(poll, dependencies=[run_id.value]) - def handle_message(message): - print("handle", message) - messages = [] - if message.role == "assistant": - tools_calls = message.get("tool_calls", []) - for tool_call in tools_calls: - messages.append(ai_call(tool_call)) - return messages - - def handle_initial(): - print("handle initial", messages.value) - for message in messages.value: - handle_message(message) - - solara.use_effect(handle_initial, []) - # result = solara.use_thread(ask, dependencies=[messages.value]) + # Create DOM for chat interface with solara.Column( classes=["chat-interface"], ): @@ -214,16 +191,25 @@ def handle_initial(): "overflow-y": "auto", "height": "100px", "flex-direction": "column-reverse", - } + }, + classes=["chat-box"], ): for message in reversed(messages.value): with solara.Row(style={"align-items": "flex-start"}): - if message.role == "user": + # Catch "messages" that are actually tool calls + if isinstance(message, dict): + icon = ( + "mdi-map" + if message["output"] == "Map updated" + else "mdi-map-marker" + ) + solara.v.Icon(children=[icon], style_="padding-top: 10px;") + solara.Markdown(message["output"]) + elif message.role == "user": solara.Text( message.content[0].text.value, classes=["chat-message", "user-message"], ) - assert len(message.content) == 1 elif message.role == "assistant": if message.content[0].text.value: solara.v.Icon( @@ -246,8 +232,6 @@ def handle_initial(): repr(message), classes=["chat-message", "assistant-message"], ) - elif message["role"] == "tool": - pass # no need to display else: solara.v.Icon( children=["mdi-compass-outline"], @@ -272,21 +256,6 @@ def handle_initial(): @solara.component def Page(): - reset_counter, set_reset_counter = solara.use_state(0) - print("reset", reset_counter, f"chat-{reset_counter}") - - def reset_ui(): - set_reset_counter(reset_counter + 1) - - def save(): - with open("log.json", "w") as f: - json.dump(messages.value, f) - - def load(): - with open("log.json", "r") as f: - messages.set(json.load(f)) - reset_ui() - with solara.Column( classes=["ui-container"], gap="5vh", @@ -299,16 +268,12 @@ def load(): unsafe_innerHTML="Wanderlust", style={"display": "inline-block"}, ) - # with solara.Row(gap="10px"): - # solara.Button("Save", on_click=save) - # solara.Button("Load", on_click=load) - # solara.Button("Soft reset", on_click=reset_ui) with solara.Row( justify="space-between", style={"flex-grow": "1"}, classes=["container-row"] ): - ChatInterface().key(f"chat-{reset_counter}") + ChatInterface() # .key(f"chat-{reset_counter}") with solara.Column(classes=["map-container"]): - Map() # .key(f"map-{reset_counter}") + Map() solara.Style( """ @@ -335,13 +300,30 @@ def load(): height: 100%; width: 38vw; justify-content: center; - background: linear-gradient(0deg, transparent 75%, white 100%); + position: relative; + } + .chat-interface:after { + content: ""; + position: absolute; + z-index: 1; + top: 0; + left: 0; + pointer-events: none; + background-image: linear-gradient(to top, rgba(255,255,255,0), rgba(255,255,255, 1) 100%); + width: 100%; + height: 15%; + } + .chat-box > :last-child{ + padding-top: 7.5vh; } .map-container{ width: 50vw; height: 100%; justify-content: center; } + .user-message{ + font-weight: bold; + } @media screen and (max-aspect-ratio: 1/1) { .ui-container{ padding: 30px;