Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions src/google/adk/sessions/migration/_schema_check_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,28 @@ def get_db_schema_version_from_connection(connection) -> str:
return _get_schema_version_impl(inspector, connection)


def _to_sync_url(db_url: str) -> str:
"""Removes '+driver' from SQLAlchemy URL."""
def to_sync_url(db_url: str) -> str:
"""Removes '+driver' from SQLAlchemy URL.

This is useful when you need to use a synchronous SQLAlchemy engine with
a database URL that specifies an async driver (e.g., postgresql+asyncpg://
or sqlite+aiosqlite://).

Args:
db_url: The database URL, potentially with a driver specification.

Returns:
The database URL with the driver specification removed (e.g.,
'postgresql+asyncpg://host/db' becomes 'postgresql://host/db').

Examples:
>>> to_sync_url('postgresql+asyncpg://localhost/mydb')
'postgresql://localhost/mydb'
>>> to_sync_url('sqlite+aiosqlite:///path/to/db.sqlite')
'sqlite:///path/to/db.sqlite'
>>> to_sync_url('mysql://localhost/mydb') # No driver, returns unchanged
'mysql://localhost/mydb'
"""
if "://" in db_url:
scheme, _, rest = db_url.partition("://")
if "+" in scheme:
Expand All @@ -106,7 +126,7 @@ def get_db_schema_version(db_url: str) -> str:
"""
engine = None
try:
engine = create_sync_engine(_to_sync_url(db_url))
engine = create_sync_engine(to_sync_url(db_url))
with engine.connect() as connection:
inspector = inspect(connection)
return _get_schema_version_impl(inspector, connection)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,17 +165,23 @@ def _get_state_dict(state_val: Any) -> dict:
# --- Migration Logic ---
def migrate(source_db_url: str, dest_db_url: str):
"""Migrates data from old pickle schema to new JSON schema."""
# Convert async driver URLs to sync URLs for SQLAlchemy's synchronous engine.
# This allows users to provide URLs like 'postgresql+asyncpg://...' and have
# them automatically converted to 'postgresql://...' for migration.
source_sync_url = _schema_check_utils.to_sync_url(source_db_url)
dest_sync_url = _schema_check_utils.to_sync_url(dest_db_url)

logger.info(f"Connecting to source database: {source_db_url}")
try:
source_engine = create_engine(source_db_url)
source_engine = create_engine(source_sync_url)
SourceSession = sessionmaker(bind=source_engine)
except Exception as e:
logger.error(f"Failed to connect to source database: {e}")
raise RuntimeError(f"Failed to connect to source database: {e}") from e

logger.info(f"Connecting to destination database: {dest_db_url}")
try:
dest_engine = create_engine(dest_db_url)
dest_engine = create_engine(dest_sync_url)
v1.Base.metadata.create_all(dest_engine)
DestSession = sessionmaker(bind=dest_engine)
except Exception as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import sys

from google.adk.sessions import sqlite_session_service as sss
from google.adk.sessions.migration import _schema_check_utils
from google.adk.sessions.schemas import v0 as v0_schema
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
Expand All @@ -32,9 +33,14 @@

def migrate(source_db_url: str, dest_db_path: str):
"""Migrates data from a SQLAlchemy-based SQLite DB to the new schema."""
# Convert async driver URLs to sync URLs for SQLAlchemy's synchronous engine.
# This allows users to provide URLs like 'sqlite+aiosqlite://...' and have
# them automatically converted to 'sqlite://...' for migration.
source_sync_url = _schema_check_utils.to_sync_url(source_db_url)

logger.info(f"Connecting to source database: {source_db_url}")
try:
engine = create_engine(source_db_url)
engine = create_engine(source_sync_url)
v0_schema.Base.metadata.create_all(
engine
) # Ensure tables exist for inspection
Expand Down
141 changes: 141 additions & 0 deletions tests/unittests/sessions/migration/test_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,88 @@
from google.adk.sessions.migration import migrate_from_sqlalchemy_pickle as mfsp
from google.adk.sessions.schemas import v0
from google.adk.sessions.schemas import v1
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker


class TestToSyncUrl:
"""Tests for the to_sync_url function."""

@pytest.mark.parametrize(
"input_url,expected_url",
[
# PostgreSQL async drivers
(
"postgresql+asyncpg://localhost/mydb",
"postgresql://localhost/mydb",
),
(
"postgresql+asyncpg://user:pass@localhost:5432/mydb",
"postgresql://user:pass@localhost:5432/mydb",
),
# PostgreSQL sync drivers (should still strip)
(
"postgresql+psycopg2://localhost/mydb",
"postgresql://localhost/mydb",
),
# MySQL async drivers
(
"mysql+aiomysql://localhost/mydb",
"mysql://localhost/mydb",
),
(
"mysql+asyncmy://user:pass@localhost:3306/mydb",
"mysql://user:pass@localhost:3306/mydb",
),
# SQLite async driver
(
"sqlite+aiosqlite:///path/to/db.sqlite",
"sqlite:///path/to/db.sqlite",
),
(
"sqlite+aiosqlite:///:memory:",
"sqlite:///:memory:",
),
# URLs without driver specification (unchanged)
(
"postgresql://localhost/mydb",
"postgresql://localhost/mydb",
),
(
"mysql://localhost/mydb",
"mysql://localhost/mydb",
),
(
"sqlite:///path/to/db.sqlite",
"sqlite:///path/to/db.sqlite",
),
# Edge cases
(
"sqlite:///:memory:",
"sqlite:///:memory:",
),
# Complex URL with query parameters
(
"postgresql+asyncpg://user:pass@host/db?ssl=require",
"postgresql://user:pass@host/db?ssl=require",
),
],
)
def test_to_sync_url(self, input_url, expected_url):
"""Test that async driver specifications are correctly removed."""
assert _schema_check_utils.to_sync_url(input_url) == expected_url

def test_to_sync_url_no_scheme_separator(self):
"""Test that URLs without :// are returned unchanged."""
# This is an invalid URL but the function should handle it gracefully
assert _schema_check_utils.to_sync_url("not-a-url") == "not-a-url"

def test_to_sync_url_empty_string(self):
"""Test that empty string is returned unchanged."""
assert _schema_check_utils.to_sync_url("") == ""


def test_migrate_from_sqlalchemy_pickle(tmp_path):
"""Tests for migrate_from_sqlalchemy_pickle."""
source_db_path = tmp_path / "source_pickle.db"
Expand Down Expand Up @@ -104,3 +182,66 @@ def test_migrate_from_sqlalchemy_pickle(tmp_path):
assert event_res.event_data["actions"]["state_delta"] == {"skey": 4}

dest_session.close()


def test_migrate_from_sqlalchemy_pickle_with_async_driver_urls(tmp_path):
"""Tests that migration works with async driver URLs (fixes issue #4176).

Users often provide async driver URLs (e.g., postgresql+asyncpg://) since
that's what ADK requires at runtime. The migration tool should handle these
by automatically converting them to sync URLs.
"""
source_db_path = tmp_path / "source_pickle_async.db"
dest_db_path = tmp_path / "dest_json_async.db"
# Use async driver URLs like users would typically provide
source_db_url = f"sqlite+aiosqlite:///{source_db_path}"
dest_db_url = f"sqlite+aiosqlite:///{dest_db_path}"

# Set up source DB with old pickle schema using sync URL
sync_source_url = f"sqlite:///{source_db_path}"
source_engine = create_engine(sync_source_url)
v0.Base.metadata.create_all(source_engine)
SourceSession = sessionmaker(bind=source_engine)
source_session = SourceSession()

# Populate source data
now = datetime.now(timezone.utc)
app_state = v0.StorageAppState(
app_name="async_app", state={"key": "value"}, update_time=now
)
session = v0.StorageSession(
app_name="async_app",
user_id="async_user",
id="async_session",
state={},
create_time=now,
update_time=now,
)
source_session.add_all([app_state, session])
source_session.commit()
source_session.close()

# This should NOT raise an error about async drivers (the fix for #4176)
mfsp.migrate(source_db_url, dest_db_url)

# Verify destination DB
sync_dest_url = f"sqlite:///{dest_db_path}"
dest_engine = create_engine(sync_dest_url)
DestSession = sessionmaker(bind=dest_engine)
dest_session = DestSession()

metadata = dest_session.query(v1.StorageMetadata).first()
assert metadata is not None
assert metadata.key == _schema_check_utils.SCHEMA_VERSION_KEY
assert metadata.value == _schema_check_utils.SCHEMA_VERSION_1_JSON

app_state_res = dest_session.query(v1.StorageAppState).first()
assert app_state_res is not None
assert app_state_res.app_name == "async_app"
assert app_state_res.state == {"key": "value"}

session_res = dest_session.query(v1.StorageSession).first()
assert session_res is not None
assert session_res.id == "async_session"

dest_session.close()
Loading