diff --git a/src/ipylangchat/__init__.py b/src/ipylangchat/__init__.py
index 20205ba..0bb5af1 100644
--- a/src/ipylangchat/__init__.py
+++ b/src/ipylangchat/__init__.py
@@ -1,3 +1,4 @@
+import asyncio
 import importlib.metadata
 import pathlib
 
@@ -24,19 +25,80 @@ def __init__(self, chain, **kwargs):
         self.chain = chain
         self.chat_history = []
 
-        def handle_user_question(change):
+        def on_user_msg(change):
             self.chat_history.extend(
                 [
                     HumanMessage(content=self.user_msg),
                     AIMessage(content=self.ai_msg),
                 ]
             )
-            self.send({ "type": "create" })
+            self.send({"type": "create"})
             for chunk in chain.stream(
                 {"input": change.new, "chat_history": self.chat_history}
             ):
                 if "answer" in chunk:
                     self.send({"type": "append", "text": chunk["answer"]})
-            self.send({ "type": "finish" })
+            self.send({"type": "finish"})
 
-        self.observe(handle_user_question, names=["user_msg"])
+        self.observe(on_user_msg, names=["user_msg"])
+
+
+class AsyncChatUIWidget(anywidget.AnyWidget):
+    """
+    Chat UI widget that uses an event loop to process astream events.
+
+    Notes
+    -----
+    There doesn't seem to be a vetted solution for running a separate event
+    loop in Jupyter, since Jupyter is already running in its own event loop.
+
+    https://github.com/python/cpython/issues/66435
+
+    The workaround is to use the `nest_asyncio` package, which monkeypatches
+    asyncio to allow nested event loops, but it is no longer maintained.
+
+    ```
+    import nest_asyncio
+    nest_asyncio.apply()
+    ```
+    """
+    _esm = pathlib.Path(__file__).parent / "static" / "widget.js"
+    _css = pathlib.Path(__file__).parent / "static" / "widget.css"
+    user_msg = traitlets.Unicode().tag(sync=True)
+    ai_msg = traitlets.Unicode().tag(sync=True)
+
+    def __init__(self, chain, version="v1", event_loop=None, **kwargs):
+        super().__init__(**kwargs)
+
+        self.chain = chain
+        self.chat_history = []
+        self.version = version
+        if event_loop is None:
+            self.event_loop = asyncio.get_event_loop()
+        else:
+            self.event_loop = event_loop
+
+        async def process_user_input(user_input):
+            async for event in chain.astream_events(
+                {"input": user_input, "chat_history": self.chat_history},
+                version=self.version,
+            ):
+                if (
+                    event["event"] == "on_chat_model_stream"
+                    and "seq:step:3" in event["tags"]  # TODO: find another way to filter for the output chat model
+                ):
+                    chunk = event["data"]["chunk"]
+                    self.send({"type": "append", "text": f"{chunk.content}"})
+
+        def on_user_msg(change):
+            self.chat_history.extend(
+                [
+                    HumanMessage(content=self.user_msg),
+                    AIMessage(content=self.ai_msg),
+                ]
+            )
+            self.send({"type": "create"})
+            self.event_loop.run_until_complete(process_user_input(change.new))
+            self.send({"type": "finish"})
+
+        self.observe(on_user_msg, names=["user_msg"])
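
For reference, a minimal usage sketch, not part of this diff. The retrieval chain below is an assumption: `create_stuff_documents_chain` places the chat model at the third step of its internal sequence, which is what the `seq:step:3` tag filter above relies on, and its output chunks carry the `"answer"` key that the synchronous `ChatUIWidget` reads. The toy retriever, the model name, and OpenAI credentials in the environment are all placeholders.

```
# Usage sketch under assumed dependencies; not part of this diff.
import nest_asyncio

nest_asyncio.apply()  # allow run_until_complete inside Jupyter's running loop

from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda
from langchain_openai import ChatOpenAI

from ipylangchat import AsyncChatUIWidget

# Toy stand-in for a real retriever: any Runnable mapping the input dict
# to a list of Documents works with create_retrieval_chain.
docs = [Document(page_content="ipylangchat renders a chat UI in Jupyter.")]
retriever = RunnableLambda(lambda _: docs)

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "Answer using this context:\n\n{context}"),
        ("placeholder", "{chat_history}"),
        ("human", "{input}"),
    ]
)
llm = ChatOpenAI(model="gpt-4o-mini")  # model name is a placeholder

# The chat model sits at seq:step:3 inside the stuff-documents sequence,
# so its on_chat_model_stream events pass the widget's tag filter.
chain = create_retrieval_chain(retriever, create_stuff_documents_chain(llm, prompt))

AsyncChatUIWidget(chain)  # display as the last expression in a notebook cell
```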