From 4b29d15b3e5df65f3503daffa6bc7af85159507b Mon Sep 17 00:00:00 2001 From: Liang Wu Date: Mon, 19 Jan 2026 19:38:52 -0800 Subject: [PATCH] fix: Handle async driver URLs in migration tool The migration tool uses synchronous SQLAlchemy engines but users often provide async driver URLs (e.g., postgresql+asyncpg://) since that's what ADK requires at runtime. This fix: - Makes `to_sync_url()` public in `_schema_check_utils.py` for reuse - Updates `migrate_from_sqlalchemy_pickle.py` to convert async URLs - Updates `migrate_from_sqlalchemy_sqlite.py` to convert async URLs - Adds comprehensive unit tests for `to_sync_url()` function - Adds integration test for migration with async driver URLs Fixes #4176 Co-authored-by: Liang Wu PiperOrigin-RevId: 858359061 --- .../sessions/migration/_schema_check_utils.py | 26 +++- .../migrate_from_sqlalchemy_pickle.py | 10 +- .../migrate_from_sqlalchemy_sqlite.py | 8 +- .../sessions/migration/test_migration.py | 141 ++++++++++++++++++ 4 files changed, 179 insertions(+), 6 deletions(-) diff --git a/src/google/adk/sessions/migration/_schema_check_utils.py b/src/google/adk/sessions/migration/_schema_check_utils.py index 249161c84c..3223847b6b 100644 --- a/src/google/adk/sessions/migration/_schema_check_utils.py +++ b/src/google/adk/sessions/migration/_schema_check_utils.py @@ -82,8 +82,28 @@ def get_db_schema_version_from_connection(connection) -> str: return _get_schema_version_impl(inspector, connection) -def _to_sync_url(db_url: str) -> str: - """Removes '+driver' from SQLAlchemy URL.""" +def to_sync_url(db_url: str) -> str: + """Removes '+driver' from SQLAlchemy URL. + + This is useful when you need to use a synchronous SQLAlchemy engine with + a database URL that specifies an async driver (e.g., postgresql+asyncpg:// + or sqlite+aiosqlite://). + + Args: + db_url: The database URL, potentially with a driver specification. + + Returns: + The database URL with the driver specification removed (e.g., + 'postgresql+asyncpg://host/db' becomes 'postgresql://host/db'). + + Examples: + >>> to_sync_url('postgresql+asyncpg://localhost/mydb') + 'postgresql://localhost/mydb' + >>> to_sync_url('sqlite+aiosqlite:///path/to/db.sqlite') + 'sqlite:///path/to/db.sqlite' + >>> to_sync_url('mysql://localhost/mydb') # No driver, returns unchanged + 'mysql://localhost/mydb' + """ if "://" in db_url: scheme, _, rest = db_url.partition("://") if "+" in scheme: @@ -106,7 +126,7 @@ def get_db_schema_version(db_url: str) -> str: """ engine = None try: - engine = create_sync_engine(_to_sync_url(db_url)) + engine = create_sync_engine(to_sync_url(db_url)) with engine.connect() as connection: inspector = inspect(connection) return _get_schema_version_impl(inspector, connection) diff --git a/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py b/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py index d24f71f682..b6ad673d90 100644 --- a/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py +++ b/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py @@ -165,9 +165,15 @@ def _get_state_dict(state_val: Any) -> dict: # --- Migration Logic --- def migrate(source_db_url: str, dest_db_url: str): """Migrates data from old pickle schema to new JSON schema.""" + # Convert async driver URLs to sync URLs for SQLAlchemy's synchronous engine. + # This allows users to provide URLs like 'postgresql+asyncpg://...' and have + # them automatically converted to 'postgresql://...' for migration. 
+ source_sync_url = _schema_check_utils.to_sync_url(source_db_url) + dest_sync_url = _schema_check_utils.to_sync_url(dest_db_url) + logger.info(f"Connecting to source database: {source_db_url}") try: - source_engine = create_engine(source_db_url) + source_engine = create_engine(source_sync_url) SourceSession = sessionmaker(bind=source_engine) except Exception as e: logger.error(f"Failed to connect to source database: {e}") @@ -175,7 +181,7 @@ def migrate(source_db_url: str, dest_db_url: str): logger.info(f"Connecting to destination database: {dest_db_url}") try: - dest_engine = create_engine(dest_db_url) + dest_engine = create_engine(dest_sync_url) v1.Base.metadata.create_all(dest_engine) DestSession = sessionmaker(bind=dest_engine) except Exception as e: diff --git a/src/google/adk/sessions/migration/migrate_from_sqlalchemy_sqlite.py b/src/google/adk/sessions/migration/migrate_from_sqlalchemy_sqlite.py index a0dd3a84a1..28830a8bc8 100644 --- a/src/google/adk/sessions/migration/migrate_from_sqlalchemy_sqlite.py +++ b/src/google/adk/sessions/migration/migrate_from_sqlalchemy_sqlite.py @@ -23,6 +23,7 @@ import sys from google.adk.sessions import sqlite_session_service as sss +from google.adk.sessions.migration import _schema_check_utils from google.adk.sessions.schemas import v0 as v0_schema from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker @@ -32,9 +33,14 @@ def migrate(source_db_url: str, dest_db_path: str): """Migrates data from a SQLAlchemy-based SQLite DB to the new schema.""" + # Convert async driver URLs to sync URLs for SQLAlchemy's synchronous engine. + # This allows users to provide URLs like 'sqlite+aiosqlite://...' and have + # them automatically converted to 'sqlite://...' for migration. + source_sync_url = _schema_check_utils.to_sync_url(source_db_url) + logger.info(f"Connecting to source database: {source_db_url}") try: - engine = create_engine(source_db_url) + engine = create_engine(source_sync_url) v0_schema.Base.metadata.create_all( engine ) # Ensure tables exist for inspection diff --git a/tests/unittests/sessions/migration/test_migration.py b/tests/unittests/sessions/migration/test_migration.py index f51356ec32..f8d18bbd01 100644 --- a/tests/unittests/sessions/migration/test_migration.py +++ b/tests/unittests/sessions/migration/test_migration.py @@ -23,10 +23,88 @@ from google.adk.sessions.migration import migrate_from_sqlalchemy_pickle as mfsp from google.adk.sessions.schemas import v0 from google.adk.sessions.schemas import v1 +import pytest from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker +class TestToSyncUrl: + """Tests for the to_sync_url function.""" + + @pytest.mark.parametrize( + "input_url,expected_url", + [ + # PostgreSQL async drivers + ( + "postgresql+asyncpg://localhost/mydb", + "postgresql://localhost/mydb", + ), + ( + "postgresql+asyncpg://user:pass@localhost:5432/mydb", + "postgresql://user:pass@localhost:5432/mydb", + ), + # PostgreSQL sync drivers (should still strip) + ( + "postgresql+psycopg2://localhost/mydb", + "postgresql://localhost/mydb", + ), + # MySQL async drivers + ( + "mysql+aiomysql://localhost/mydb", + "mysql://localhost/mydb", + ), + ( + "mysql+asyncmy://user:pass@localhost:3306/mydb", + "mysql://user:pass@localhost:3306/mydb", + ), + # SQLite async driver + ( + "sqlite+aiosqlite:///path/to/db.sqlite", + "sqlite:///path/to/db.sqlite", + ), + ( + "sqlite+aiosqlite:///:memory:", + "sqlite:///:memory:", + ), + # URLs without driver specification (unchanged) + ( + 
"postgresql://localhost/mydb", + "postgresql://localhost/mydb", + ), + ( + "mysql://localhost/mydb", + "mysql://localhost/mydb", + ), + ( + "sqlite:///path/to/db.sqlite", + "sqlite:///path/to/db.sqlite", + ), + # Edge cases + ( + "sqlite:///:memory:", + "sqlite:///:memory:", + ), + # Complex URL with query parameters + ( + "postgresql+asyncpg://user:pass@host/db?ssl=require", + "postgresql://user:pass@host/db?ssl=require", + ), + ], + ) + def test_to_sync_url(self, input_url, expected_url): + """Test that async driver specifications are correctly removed.""" + assert _schema_check_utils.to_sync_url(input_url) == expected_url + + def test_to_sync_url_no_scheme_separator(self): + """Test that URLs without :// are returned unchanged.""" + # This is an invalid URL but the function should handle it gracefully + assert _schema_check_utils.to_sync_url("not-a-url") == "not-a-url" + + def test_to_sync_url_empty_string(self): + """Test that empty string is returned unchanged.""" + assert _schema_check_utils.to_sync_url("") == "" + + def test_migrate_from_sqlalchemy_pickle(tmp_path): """Tests for migrate_from_sqlalchemy_pickle.""" source_db_path = tmp_path / "source_pickle.db" @@ -104,3 +182,66 @@ def test_migrate_from_sqlalchemy_pickle(tmp_path): assert event_res.event_data["actions"]["state_delta"] == {"skey": 4} dest_session.close() + + +def test_migrate_from_sqlalchemy_pickle_with_async_driver_urls(tmp_path): + """Tests that migration works with async driver URLs (fixes issue #4176). + + Users often provide async driver URLs (e.g., postgresql+asyncpg://) since + that's what ADK requires at runtime. The migration tool should handle these + by automatically converting them to sync URLs. + """ + source_db_path = tmp_path / "source_pickle_async.db" + dest_db_path = tmp_path / "dest_json_async.db" + # Use async driver URLs like users would typically provide + source_db_url = f"sqlite+aiosqlite:///{source_db_path}" + dest_db_url = f"sqlite+aiosqlite:///{dest_db_path}" + + # Set up source DB with old pickle schema using sync URL + sync_source_url = f"sqlite:///{source_db_path}" + source_engine = create_engine(sync_source_url) + v0.Base.metadata.create_all(source_engine) + SourceSession = sessionmaker(bind=source_engine) + source_session = SourceSession() + + # Populate source data + now = datetime.now(timezone.utc) + app_state = v0.StorageAppState( + app_name="async_app", state={"key": "value"}, update_time=now + ) + session = v0.StorageSession( + app_name="async_app", + user_id="async_user", + id="async_session", + state={}, + create_time=now, + update_time=now, + ) + source_session.add_all([app_state, session]) + source_session.commit() + source_session.close() + + # This should NOT raise an error about async drivers (the fix for #4176) + mfsp.migrate(source_db_url, dest_db_url) + + # Verify destination DB + sync_dest_url = f"sqlite:///{dest_db_path}" + dest_engine = create_engine(sync_dest_url) + DestSession = sessionmaker(bind=dest_engine) + dest_session = DestSession() + + metadata = dest_session.query(v1.StorageMetadata).first() + assert metadata is not None + assert metadata.key == _schema_check_utils.SCHEMA_VERSION_KEY + assert metadata.value == _schema_check_utils.SCHEMA_VERSION_1_JSON + + app_state_res = dest_session.query(v1.StorageAppState).first() + assert app_state_res is not None + assert app_state_res.app_name == "async_app" + assert app_state_res.state == {"key": "value"} + + session_res = dest_session.query(v1.StorageSession).first() + assert session_res is not None + assert 
session_res.id == "async_session" + + dest_session.close()
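
Usage note (illustrative, not part of the change): with this patch, the same async
driver URL that ADK uses at runtime can be passed straight to the migration entry
points. The sketch below mirrors the new unit test; the sqlite file paths are
placeholders chosen for the example.

    from google.adk.sessions.migration import _schema_check_utils
    from google.adk.sessions.migration import migrate_from_sqlalchemy_pickle as mfsp

    # Async driver URLs, as typically configured for ADK at runtime
    # (placeholder paths for illustration only).
    source_db_url = "sqlite+aiosqlite:///source_pickle.db"
    dest_db_url = "sqlite+aiosqlite:///dest_json.db"

    # to_sync_url() strips the '+driver' part of the scheme...
    assert (
        _schema_check_utils.to_sync_url(source_db_url)
        == "sqlite:///source_pickle.db"
    )

    # ...and migrate() now performs the same conversion internally, so the
    # async URLs can be passed as-is. (Assumes source_pickle.db already holds
    # data in the old v0 pickle schema, as in the test above.)
    mfsp.migrate(source_db_url, dest_db_url)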