Skip to content

Commit

Permalink
Merge pull request #46 from tjni/vb/async-list-checkpoints
Browse files Browse the repository at this point in the history
Handle calling .list on async checkpointer.
  • Loading branch information
tjni authored Jan 15, 2025
2 parents 8d30a85 + 0c9175c commit 5be6be1
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 2 deletions.
14 changes: 13 additions & 1 deletion langgraph/checkpoint/mysql/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,18 @@ def list(
Yields:
Iterator[CheckpointTuple]: An iterator of matching checkpoint tuples.
"""
try:
# check if we are in the main thread, only bg threads can block
# we don't check in other methods to avoid the overhead
if asyncio.get_running_loop() is self.loop:
raise asyncio.InvalidStateError(
"Synchronous calls to AsyncSqliteSaver are only allowed from a "
"different thread. From the main thread, use the async interface. "
"For example, use `checkpointer.alist(...)` or `await "
"graph.ainvoke(...)`."
)
except RuntimeError:
pass
aiter_ = self.alist(config, filter=filter, before=before, limit=limit)
while True:
try:
Expand Down Expand Up @@ -407,7 +419,7 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
if asyncio.get_running_loop() is self.loop:
raise asyncio.InvalidStateError(
"Synchronous calls to AIOMySQLSaver are only allowed from a "
"different thread. From the main thread, use the async interface."
"different thread. From the main thread, use the async interface. "
"For example, use `await checkpointer.aget_tuple(...)` or `await "
"graph.ainvoke(...)`."
)
Expand Down
14 changes: 13 additions & 1 deletion langgraph/checkpoint/mysql/shallow.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,18 @@ def list(
on the provided config. For shallow savers, this method returns a list with
ONLY the most recent checkpoint.
"""
try:
# check if we are in the main thread, only bg threads can block
# we don't check in other methods to avoid the overhead
if asyncio.get_running_loop() is self.loop:
raise asyncio.InvalidStateError(
"Synchronous calls to AsyncSqliteSaver are only allowed from a "
"different thread. From the main thread, use the async interface. "
"For example, use `checkpointer.alist(...)` or `await "
"graph.ainvoke(...)`."
)
except RuntimeError:
pass
aiter_ = self.alist(config, filter=filter, before=before, limit=limit)
while True:
try:
Expand All @@ -745,7 +757,7 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
if asyncio.get_running_loop() is self.loop:
raise asyncio.InvalidStateError(
"Synchronous calls to asynchronous shallow savers are only allowed from a "
"different thread. From the main thread, use the async interface."
"different thread. From the main thread, use the async interface. "
"For example, use `await checkpointer.aget_tuple(...)` or `await "
"graph.ainvoke(...)`."
)
Expand Down
21 changes: 21 additions & 0 deletions tests/test_async.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from copy import deepcopy
Expand All @@ -17,6 +18,7 @@
)
from langgraph.checkpoint.mysql.aio import AIOMySQLSaver, ShallowAIOMySQLSaver
from langgraph.checkpoint.serde.types import TASKS
from langgraph.graph import END, START, MessagesState, StateGraph
from tests.conftest import DEFAULT_BASE_URI


Expand Down Expand Up @@ -369,3 +371,22 @@ async def test_write_with_same_checkpoint_ns_updates(
results = [c async for c in saver.alist({})]

assert len(results) == 1


@pytest.mark.parametrize("saver_name", ["base", "pool", "shallow"])
async def test_graph_sync_get_state_history_raises(saver_name: str) -> None:
"""Regression test for https://github.com/langchain-ai/langgraph/issues/2992"""

builder = StateGraph(MessagesState)
builder.add_node("foo", lambda _: None)
builder.add_edge(START, "foo")
builder.add_edge("foo", END)

async with _saver(saver_name) as saver:
graph = builder.compile(checkpointer=saver)
config: RunnableConfig = {"configurable": {"thread_id": "1"}}
await graph.ainvoke({"messages": []}, config)

# this method should not hang
with pytest.raises(asyncio.exceptions.InvalidStateError):
next(graph.get_state_history(config))

0 comments on commit 5be6be1

Please sign in to comment.