Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simpler connection pool support. #50

Merged
merged 1 commit into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
570 changes: 20 additions & 550 deletions langgraph-tests/tests/__snapshots__/test_large_cases.ambr

Large diffs are not rendered by default.

182 changes: 8 additions & 174 deletions langgraph-tests/tests/__snapshots__/test_pregel.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

'''
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge[pymysql_sqlalchemy_pool]
# name: test_in_one_fan_out_state_graph_waiting_edge[pymysql_pool]
'''
graph TD;
__start__ --> rewrite_query;
Expand All @@ -38,19 +38,6 @@

'''
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge[pymysql_callable]
'''
graph TD;
__start__ --> rewrite_query;
analyzer_one --> retriever_one;
qa --> __end__;
retriever_one --> qa;
retriever_two --> qa;
rewrite_query --> analyzer_one;
rewrite_query --> retriever_two;

'''
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic1[pymysql]
'''
graph TD;
Expand Down Expand Up @@ -121,77 +108,7 @@
'type': 'object',
})
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic1[pymysql_sqlalchemy_pool]
'''
graph TD;
__start__ --> rewrite_query;
analyzer_one --> retriever_one;
qa --> __end__;
retriever_one --> qa;
retriever_two --> qa;
rewrite_query --> analyzer_one;
rewrite_query -.-> retriever_two;

'''
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic1[pymysql_sqlalchemy_pool].1
dict({
'definitions': dict({
'InnerObject': dict({
'properties': dict({
'yo': dict({
'title': 'Yo',
'type': 'integer',
}),
}),
'required': list([
'yo',
]),
'title': 'InnerObject',
'type': 'object',
}),
}),
'properties': dict({
'inner': dict({
'$ref': '#/definitions/InnerObject',
}),
'query': dict({
'title': 'Query',
'type': 'string',
}),
}),
'required': list([
'query',
'inner',
]),
'title': 'Input',
'type': 'object',
})
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic1[pymysql_sqlalchemy_pool].2
dict({
'properties': dict({
'answer': dict({
'title': 'Answer',
'type': 'string',
}),
'docs': dict({
'items': dict({
'type': 'string',
}),
'title': 'Docs',
'type': 'array',
}),
}),
'required': list([
'answer',
'docs',
]),
'title': 'Output',
'type': 'object',
})
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic1[pymysql_callable]
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic1[pymysql_pool]
'''
graph TD;
__start__ --> rewrite_query;
Expand All @@ -204,7 +121,7 @@

'''
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic1[pymysql_callable].1
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic1[pymysql_pool].1
dict({
'definitions': dict({
'InnerObject': dict({
Expand Down Expand Up @@ -238,7 +155,7 @@
'type': 'object',
})
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic1[pymysql_callable].2
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic1[pymysql_pool].2
dict({
'properties': dict({
'answer': dict({
Expand Down Expand Up @@ -401,7 +318,7 @@
'type': 'object',
})
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[pymysql_sqlalchemy_pool]
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[pymysql_pool]
'''
graph TD;
__start__ --> rewrite_query;
Expand All @@ -414,7 +331,7 @@

'''
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[pymysql_sqlalchemy_pool].1
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[pymysql_pool].1
dict({
'$defs': dict({
'InnerObject': dict({
Expand Down Expand Up @@ -448,77 +365,7 @@
'type': 'object',
})
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[pymysql_sqlalchemy_pool].2
dict({
'properties': dict({
'answer': dict({
'title': 'Answer',
'type': 'string',
}),
'docs': dict({
'items': dict({
'type': 'string',
}),
'title': 'Docs',
'type': 'array',
}),
}),
'required': list([
'answer',
'docs',
]),
'title': 'Output',
'type': 'object',
})
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[pymysql_callable]
'''
graph TD;
__start__ --> rewrite_query;
analyzer_one --> retriever_one;
qa --> __end__;
retriever_one --> qa;
retriever_two --> qa;
rewrite_query --> analyzer_one;
rewrite_query -.-> retriever_two;

'''
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[pymysql_callable].1
dict({
'$defs': dict({
'InnerObject': dict({
'properties': dict({
'yo': dict({
'title': 'Yo',
'type': 'integer',
}),
}),
'required': list([
'yo',
]),
'title': 'InnerObject',
'type': 'object',
}),
}),
'properties': dict({
'inner': dict({
'$ref': '#/$defs/InnerObject',
}),
'query': dict({
'title': 'Query',
'type': 'string',
}),
}),
'required': list([
'query',
'inner',
]),
'title': 'Input',
'type': 'object',
})
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[pymysql_callable].2
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[pymysql_pool].2
dict({
'properties': dict({
'answer': dict({
Expand Down Expand Up @@ -624,20 +471,7 @@

'''
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge_via_branch[pymysql_sqlalchemy_pool]
'''
graph TD;
__start__ --> rewrite_query;
analyzer_one --> retriever_one;
qa --> __end__;
retriever_one --> qa;
retriever_two --> qa;
rewrite_query --> analyzer_one;
rewrite_query -.-> retriever_two;

'''
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge_via_branch[pymysql_callable]
# name: test_in_one_fan_out_state_graph_waiting_edge_via_branch[pymysql_pool]
'''
graph TD;
__start__ --> rewrite_query;
Expand Down
71 changes: 13 additions & 58 deletions langgraph-tests/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from contextlib import asynccontextmanager, closing
from typing import AsyncIterator, Optional, cast
from contextlib import asynccontextmanager
from typing import AsyncIterator, Optional
from uuid import UUID, uuid4

import aiomysql # type: ignore
Expand All @@ -9,7 +9,7 @@
from langchain_core import __version__ as core_version
from packaging import version
from pytest_mock import MockerFixture
from sqlalchemy import Pool, create_pool_from_url
from sqlalchemy import Engine, create_engine

from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.checkpoint.mysql.aio import AIOMySQLSaver, ShallowAIOMySQLSaver
Expand All @@ -28,9 +28,9 @@
SHOULD_CHECK_SNAPSHOTS = IS_LANGCHAIN_CORE_030_OR_GREATER


def get_pymysql_sqlalchemy_pool(uri: str) -> Pool:
def get_pymysql_sqlalchemy_engine(uri: str) -> Engine:
updated_uri = uri.replace("mysql://", "mysql+pymysql://")
return create_pool_from_url(updated_uri)
return create_engine(updated_uri)


@pytest.fixture
Expand Down Expand Up @@ -91,39 +91,16 @@ def checkpointer_pymysql_shallow():


@pytest.fixture(scope="function")
def checkpointer_pymysql_sqlalchemy_pool():
def checkpointer_pymysql_pool():
database = f"test_{uuid4().hex[:16]}"

# create unique db
with pymysql.connect(**PyMySQLSaver.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True) as conn:
with conn.cursor() as cursor:
cursor.execute(f"CREATE DATABASE {database}")
try:
checkpointer = PyMySQLSaver(get_pymysql_sqlalchemy_pool(DEFAULT_MYSQL_URI + database))
checkpointer.setup()
yield checkpointer
finally:
# drop unique db
with pymysql.connect(**PyMySQLSaver.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True) as conn:
with conn.cursor() as cursor:
cursor.execute(f"DROP DATABASE {database}")


@pytest.fixture(scope="function")
def checkpointer_pymysql_callable():
database = f"test_{uuid4().hex[:16]}"

# create unique db
with pymysql.connect(**PyMySQLSaver.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True) as conn:
with conn.cursor() as cursor:
cursor.execute(f"CREATE DATABASE {database}")
try:
pool = get_pymysql_sqlalchemy_pool(DEFAULT_MYSQL_URI + database)

def callable() -> pymysql.Connection:
return cast(pymysql.Connection, closing(pool.connect()))

checkpointer = PyMySQLSaver(callable)
pool = get_pymysql_sqlalchemy_engine(DEFAULT_MYSQL_URI + database)
checkpointer = PyMySQLSaver(pool.raw_connection)
checkpointer.setup()
yield checkpointer
finally:
Expand Down Expand Up @@ -255,27 +232,7 @@ def store_pymysql():


@pytest.fixture(scope="function")
def store_pymysql_sqlalchemy_pool():
database = f"test_{uuid4().hex[:16]}"

# create unique db
with pymysql.connect(**PyMySQLStore.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True) as conn:
with conn.cursor() as cursor:
cursor.execute(f"CREATE DATABASE {database}")
try:
# yield store
store = PyMySQLStore(get_pymysql_sqlalchemy_pool(DEFAULT_MYSQL_URI + database))
store.setup()
yield store
finally:
# drop unique db
with pymysql.connect(**PyMySQLStore.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True) as conn:
with conn.cursor() as cursor:
cursor.execute(f"DROP DATABASE {database}")


@pytest.fixture(scope="function")
def store_pymysql_callable():
def store_pymysql_pool():
database = f"test_{uuid4().hex[:16]}"

# create unique db
Expand All @@ -284,9 +241,8 @@ def store_pymysql_callable():
cursor.execute(f"CREATE DATABASE {database}")
try:
# yield store
pool = get_pymysql_sqlalchemy_pool(DEFAULT_MYSQL_URI + database)
callable = lambda: cast(pymysql.Connection, closing(pool.connect()))
store = PyMySQLStore(callable)
engine = get_pymysql_sqlalchemy_engine(DEFAULT_MYSQL_URI + database)
store = PyMySQLStore(engine.raw_connection)
store.setup()
yield store
finally:
Expand Down Expand Up @@ -386,8 +342,7 @@ async def awith_store(store_name: Optional[str]) -> AsyncIterator[BaseStore]:
SHALLOW_CHECKPOINTERS_SYNC = ["pymysql_shallow"]
REGULAR_CHECKPOINTERS_SYNC = [
"pymysql",
"pymysql_sqlalchemy_pool",
"pymysql_callable"
"pymysql_pool",
]
ALL_CHECKPOINTERS_SYNC = [
*REGULAR_CHECKPOINTERS_SYNC,
Expand All @@ -399,5 +354,5 @@ async def awith_store(store_name: Optional[str]) -> AsyncIterator[BaseStore]:
*REGULAR_CHECKPOINTERS_ASYNC,
*SHALLOW_CHECKPOINTERS_ASYNC,
]
ALL_STORES_SYNC = ["pymysql", "pymysql_sqlalchemy_pool", "pymysql_callable"]
ALL_STORES_SYNC = ["pymysql", "pymysql_pool"]
ALL_STORES_ASYNC = ["aiomysql", "aiomysql_pool"]
Loading
Loading