From 7e6ef71eec8be2e804286cc4140d0cbdf84f1206 Mon Sep 17 00:00:00 2001
From: Liang Wu
Date: Mon, 15 Dec 2025 17:32:24 -0800
Subject: [PATCH] feat: Introduce new database schema for DatabaseSessionService

Part 1 of https://github.com/google/adk-python/discussions/3605.

This change adds a new schema that uses JSON serialization to store
Event data in the database. A new "adk_internal_metadata" table is also
added to store information like the schema version.

Since we want to keep supporting existing DBs, we fork from the
original schema and call it "v0", while the new one is called "v1". The
change is a no-op for existing users. In a later change, the new schema
will be used for new databases, and migration scripts will be provided
for existing databases.

Co-authored-by: Liang Wu
PiperOrigin-RevId: 844986248
---
 .../adk/sessions/database_session_service.py | 441 ++----
 .../migrate_from_sqlalchemy_sqlite.py        |  14 +-
 src/google/adk/sessions/schemas/shared.py    |  67 +++
 src/google/adk/sessions/schemas/v0.py        | 373 +++++++++++++++
 src/google/adk/sessions/schemas/v1.py        | 239 ++++++++++
 .../sessions/test_dynamic_pickle_type.py     |   2 +-
 6 files changed, 730 insertions(+), 406 deletions(-)
 create mode 100644 src/google/adk/sessions/schemas/shared.py
 create mode 100644 src/google/adk/sessions/schemas/v0.py
 create mode 100644 src/google/adk/sessions/schemas/v1.py
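Note: the "adk_internal_metadata" table mentioned above is defined in v1.py as
StorageMetadata. A rough sketch of the intended use follows; the key name and
the read/initialize logic are illustrative assumptions, since this patch only
defines the table and the service does not touch it yet.

from sqlalchemy.ext.asyncio import AsyncSession

from google.adk.sessions.schemas.v1 import StorageMetadata

SCHEMA_VERSION_KEY = "schema_version"  # hypothetical key name


async def get_or_init_schema_version(sql_session: AsyncSession) -> str:
  """Returns the stored schema version, recording "v1" on first use."""
  row = await sql_session.get(StorageMetadata, SCHEMA_VERSION_KEY)
  if row is None:
    row = StorageMetadata(key=SCHEMA_VERSION_KEY, value="v1")
    sql_session.add(row)
    await sql_session.commit()
  return row.value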
diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py
index 047340c55e..11718a653c 100644
--- a/src/google/adk/sessions/database_session_service.py
+++ b/src/google/adk/sessions/database_session_service.py
@@ -16,420 +16,59 @@
 import asyncio
 import copy
 from datetime import datetime
-from datetime import timezone
-import json
 import logging
-import pickle
 from typing import Any
 from typing import Optional
-import uuid

-from google.genai import types
-from sqlalchemy import Boolean
 from sqlalchemy import delete
-from sqlalchemy import Dialect
 from sqlalchemy import event
-from sqlalchemy import ForeignKeyConstraint
-from sqlalchemy import func
 from sqlalchemy import select
-from sqlalchemy import Text
-from sqlalchemy.dialects import mysql
-from sqlalchemy.dialects import postgresql
 from sqlalchemy.exc import ArgumentError
 from sqlalchemy.ext.asyncio import async_sessionmaker
 from sqlalchemy.ext.asyncio import AsyncEngine
 from sqlalchemy.ext.asyncio import AsyncSession as DatabaseSessionFactory
 from sqlalchemy.ext.asyncio import create_async_engine
-from sqlalchemy.ext.mutable import MutableDict
-from sqlalchemy.inspection import inspect
-from sqlalchemy.orm import DeclarativeBase
-from sqlalchemy.orm import Mapped
-from sqlalchemy.orm import mapped_column
-from sqlalchemy.orm import relationship
 from sqlalchemy.schema import MetaData
-from sqlalchemy.types import DateTime
-from sqlalchemy.types import PickleType
-from sqlalchemy.types import String
-from sqlalchemy.types import TypeDecorator
 from typing_extensions import override
 from tzlocal import get_localzone

 from . import _session_util
 from ..errors.already_exists_error import AlreadyExistsError
 from ..events.event import Event
-from ..events.event_actions import EventActions
 from .base_session_service import BaseSessionService
 from .base_session_service import GetSessionConfig
 from .base_session_service import ListSessionsResponse
+from .schemas.v0 import Base as BaseV0
+from .schemas.v0 import StorageAppState as StorageAppStateV0
+from .schemas.v0 import StorageEvent as StorageEventV0
+from .schemas.v0 import StorageSession as StorageSessionV0
+from .schemas.v0 import StorageUserState as StorageUserStateV0
 from .session import Session
 from .state import State

 logger = logging.getLogger("google_adk." + __name__)

-DEFAULT_MAX_KEY_LENGTH = 128
-DEFAULT_MAX_VARCHAR_LENGTH = 256
-
-
-class DynamicJSON(TypeDecorator):
-  """A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON serialization for other databases."""
-
-  impl = Text  # Default implementation is TEXT
-
-  def load_dialect_impl(self, dialect: Dialect):
-    if dialect.name == "postgresql":
-      return dialect.type_descriptor(postgresql.JSONB)
-    if dialect.name == "mysql":
-      # Use LONGTEXT for MySQL to address the data too long issue
-      return dialect.type_descriptor(mysql.LONGTEXT)
-    return dialect.type_descriptor(Text)  # Default to Text for other dialects
-
-  def process_bind_param(self, value, dialect: Dialect):
-    if value is not None:
-      if dialect.name == "postgresql":
-        return value  # JSONB handles dict directly
-      return json.dumps(value)  # Serialize to JSON string for TEXT
-    return value
-
-  def process_result_value(self, value, dialect: Dialect):
-    if value is not None:
-      if dialect.name == "postgresql":
-        return value  # JSONB returns dict directly
-      else:
-        return json.loads(value)  # Deserialize from JSON string for TEXT
-    return value
-
-
-class PreciseTimestamp(TypeDecorator):
-  """Represents a timestamp precise to the microsecond."""
-
-  impl = DateTime
-  cache_ok = True
-
-  def load_dialect_impl(self, dialect):
-    if dialect.name == "mysql":
-      return dialect.type_descriptor(mysql.DATETIME(fsp=6))
-    return self.impl
-
-
-class DynamicPickleType(TypeDecorator):
-  """Represents a type that can be pickled."""
-
-  impl = PickleType
-
-  def load_dialect_impl(self, dialect):
-    if dialect.name == "mysql":
-      return dialect.type_descriptor(mysql.LONGBLOB)
-    if dialect.name == "spanner+spanner":
-      from google.cloud.sqlalchemy_spanner.sqlalchemy_spanner import SpannerPickleType
-
-      return dialect.type_descriptor(SpannerPickleType)
-    return self.impl
-
-  def process_bind_param(self, value, dialect):
-    """Ensures the pickled value is a bytes object before passing it to the database dialect."""
-    if value is not None:
-      if dialect.name in ("spanner+spanner", "mysql"):
-        return pickle.dumps(value)
-    return value
-
-  def process_result_value(self, value, dialect):
-    """Ensures the raw bytes from the database are unpickled back into a Python object."""
-    if value is not None:
-      if dialect.name in ("spanner+spanner", "mysql"):
-        return pickle.loads(value)
-    return value
-
-
-class Base(DeclarativeBase):
-  """Base class for database tables."""
-
-  pass
-
-
-class StorageSession(Base):
-  """Represents a session stored in the database."""
-
-  __tablename__ = "sessions"
-
-  app_name: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
-  )
-  user_id: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
-  )
-  id: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_KEY_LENGTH),
-      primary_key=True,
-      default=lambda: str(uuid.uuid4()),
-  )
-
-  state: Mapped[MutableDict[str, Any]] = mapped_column(
-      MutableDict.as_mutable(DynamicJSON), default={}
-  )
-
-  create_time: Mapped[datetime] = mapped_column(
-      PreciseTimestamp, default=func.now()
-  )
-  update_time: Mapped[datetime] = mapped_column(
-      PreciseTimestamp, default=func.now(), onupdate=func.now()
-  )
-
-  storage_events: Mapped[list[StorageEvent]] = relationship(
-      "StorageEvent",
-      back_populates="storage_session",
-  )
-
-  def __repr__(self):
-    return f"<StorageSession(id={self.id}, update_time={self.update_time})>"
-
-  @property
-  def _dialect_name(self) -> Optional[str]:
-    session = inspect(self).session
-    return session.bind.dialect.name if session else None
-
-  @property
-  def update_timestamp_tz(self) -> datetime:
-    """Returns the time zone aware update timestamp."""
-    if self._dialect_name == "sqlite":
-      # SQLite does not support timezone. SQLAlchemy returns a naive datetime
-      # object without timezone information. We need to convert it to UTC
-      # manually.
-      return self.update_time.replace(tzinfo=timezone.utc).timestamp()
-    return self.update_time.timestamp()
-
-  def to_session(
-      self,
-      state: dict[str, Any] | None = None,
-      events: list[Event] | None = None,
-  ) -> Session:
-    """Converts the storage session to a session object."""
-    if state is None:
-      state = {}
-    if events is None:
-      events = []
-
-    return Session(
-        app_name=self.app_name,
-        user_id=self.user_id,
-        id=self.id,
-        state=state,
-        events=events,
-        last_update_time=self.update_timestamp_tz,
-    )
-
-
-class StorageEvent(Base):
-  """Represents an event stored in the database."""
-
-  __tablename__ = "events"
-
-  id: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
-  )
-  app_name: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
-  )
-  user_id: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
-  )
-  session_id: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
-  )
-
-  invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
-  author: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
-  actions: Mapped[MutableDict[str, Any]] = mapped_column(DynamicPickleType)
-  long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(
-      Text, nullable=True
-  )
-  branch: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
-  )
-  timestamp: Mapped[PreciseTimestamp] = mapped_column(
-      PreciseTimestamp, default=func.now()
-  )
-
-  # === Fields from llm_response.py ===
-  content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
-  grounding_metadata: Mapped[dict[str, Any]] = mapped_column(
-      DynamicJSON, nullable=True
-  )
-  custom_metadata: Mapped[dict[str, Any]] = mapped_column(
-      DynamicJSON, nullable=True
-  )
-  usage_metadata: Mapped[dict[str, Any]] = mapped_column(
-      DynamicJSON, nullable=True
-  )
-  citation_metadata: Mapped[dict[str, Any]] = mapped_column(
-      DynamicJSON, nullable=True
-  )
-
-  partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
-  turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
-  error_code: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
-  )
-  error_message: Mapped[str] = mapped_column(Text, nullable=True)
-  interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)
-  input_transcription: Mapped[dict[str, Any]] = mapped_column(
-      DynamicJSON, nullable=True
-  )
-  output_transcription: Mapped[dict[str, Any]] = mapped_column(
-      DynamicJSON, nullable=True
-  )
-
-  storage_session: Mapped[StorageSession] = relationship(
-      "StorageSession",
-      back_populates="storage_events",
-  )
-
-  __table_args__ = (
-      ForeignKeyConstraint(
-          ["app_name", "user_id", "session_id"],
-          ["sessions.app_name", "sessions.user_id", "sessions.id"],
-          ondelete="CASCADE",
-      ),
-  )
-
-  @property
-  def long_running_tool_ids(self) -> set[str]:
-    return (
-        set(json.loads(self.long_running_tool_ids_json))
-        if self.long_running_tool_ids_json
-        else set()
-    )
-
-  @long_running_tool_ids.setter
-  def long_running_tool_ids(self, value: set[str]):
-    if value is None:
-      self.long_running_tool_ids_json = None
-    else:
-      self.long_running_tool_ids_json = json.dumps(list(value))
-
-  @classmethod
-  def from_event(cls, session: Session, event: Event) -> StorageEvent:
-    storage_event = StorageEvent(
-        id=event.id,
-        invocation_id=event.invocation_id,
-        author=event.author,
-        branch=event.branch,
-        actions=event.actions,
-        session_id=session.id,
-        app_name=session.app_name,
-        user_id=session.user_id,
-        timestamp=datetime.fromtimestamp(event.timestamp),
-        long_running_tool_ids=event.long_running_tool_ids,
-        partial=event.partial,
-        turn_complete=event.turn_complete,
-        error_code=event.error_code,
-        error_message=event.error_message,
-        interrupted=event.interrupted,
-    )
-    if event.content:
-      storage_event.content = event.content.model_dump(
-          exclude_none=True, mode="json"
-      )
-    if event.grounding_metadata:
-      storage_event.grounding_metadata = event.grounding_metadata.model_dump(
-          exclude_none=True, mode="json"
-      )
-    if event.custom_metadata:
-      storage_event.custom_metadata = event.custom_metadata
-    if event.usage_metadata:
-      storage_event.usage_metadata = event.usage_metadata.model_dump(
-          exclude_none=True, mode="json"
-      )
-    if event.citation_metadata:
-      storage_event.citation_metadata = event.citation_metadata.model_dump(
-          exclude_none=True, mode="json"
-      )
-    if event.input_transcription:
-      storage_event.input_transcription = event.input_transcription.model_dump(
-          exclude_none=True, mode="json"
-      )
-    if event.output_transcription:
-      storage_event.output_transcription = (
-          event.output_transcription.model_dump(exclude_none=True, mode="json")
-      )
-    return storage_event
-
-  def to_event(self) -> Event:
-    return Event(
-        id=self.id,
-        invocation_id=self.invocation_id,
-        author=self.author,
-        branch=self.branch,
-        # This is needed as previous ADK version pickled actions might not have
-        # value defined in the current version of the EventActions model.
-        actions=EventActions().model_copy(update=self.actions.model_dump()),
-        timestamp=self.timestamp.timestamp(),
-        long_running_tool_ids=self.long_running_tool_ids,
-        partial=self.partial,
-        turn_complete=self.turn_complete,
-        error_code=self.error_code,
-        error_message=self.error_message,
-        interrupted=self.interrupted,
-        custom_metadata=self.custom_metadata,
-        content=_session_util.decode_model(self.content, types.Content),
-        grounding_metadata=_session_util.decode_model(
-            self.grounding_metadata, types.GroundingMetadata
-        ),
-        usage_metadata=_session_util.decode_model(
-            self.usage_metadata, types.GenerateContentResponseUsageMetadata
-        ),
-        citation_metadata=_session_util.decode_model(
-            self.citation_metadata, types.CitationMetadata
-        ),
-        input_transcription=_session_util.decode_model(
-            self.input_transcription, types.Transcription
-        ),
-        output_transcription=_session_util.decode_model(
-            self.output_transcription, types.Transcription
-        ),
-    )
-
-
-class StorageAppState(Base):
-  """Represents an app state stored in the database."""
-
-  __tablename__ = "app_states"
-
-  app_name: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
-  )
-  state: Mapped[MutableDict[str, Any]] = mapped_column(
-      MutableDict.as_mutable(DynamicJSON), default={}
-  )
-  update_time: Mapped[datetime] = mapped_column(
-      PreciseTimestamp, default=func.now(), onupdate=func.now()
-  )
-
-
-class StorageUserState(Base):
-  """Represents a user state stored in the database."""
-
-  __tablename__ = "user_states"
-
-  app_name: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
-  )
-  user_id: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
-  )
-  state: Mapped[MutableDict[str, Any]] = mapped_column(
-      MutableDict.as_mutable(DynamicJSON), default={}
-  )
-  update_time: Mapped[datetime] = mapped_column(
-      PreciseTimestamp, default=func.now(), onupdate=func.now()
-  )
-
-
-def set_sqlite_pragma(dbapi_connection, connection_record):
+def _set_sqlite_pragma(dbapi_connection, connection_record):
   cursor = dbapi_connection.cursor()
   cursor.execute("PRAGMA foreign_keys=ON")
   cursor.close()


+def _merge_state(
+    app_state: dict[str, Any],
+    user_state: dict[str, Any],
+    session_state: dict[str, Any],
+) -> dict[str, Any]:
+  """Merge app, user, and session states into a single state dictionary."""
+  merged_state = copy.deepcopy(session_state)
+  for key in app_state.keys():
+    merged_state[State.APP_PREFIX + key] = app_state[key]
+  for key in user_state.keys():
+    merged_state[State.USER_PREFIX + key] = user_state[key]
+  return merged_state
+
+
 class DatabaseSessionService(BaseSessionService):
   """A session service that uses a database for storage."""
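For reference, a worked example of the _merge_state helper moved above. It
assumes the usual State.APP_PREFIX == "app:" and State.USER_PREFIX == "user:"
values; the sample dicts are illustrative.

# _merge_state copies the session state, then overlays prefixed app/user keys.
app_state = {"theme": "dark"}
user_state = {"lang": "en"}
session_state = {"step": 3}

merged = _merge_state(app_state, user_state, session_state)
assert merged == {"step": 3, "app:theme": "dark", "user:lang": "en"}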
@@ -440,10 +79,9 @@ def __init__(self, db_url: str, **kwargs: Any):
     # 3. Initialize all properties
     try:
       db_engine = create_async_engine(db_url, **kwargs)
-
       if db_engine.dialect.name == "sqlite":
         # Set sqlite pragma to enable foreign keys constraints
-        event.listen(db_engine.sync_engine, "connect", set_sqlite_pragma)
+        event.listen(db_engine.sync_engine, "connect", _set_sqlite_pragma)

     except Exception as e:
       if isinstance(e, ArgumentError):
@@ -485,8 +123,8 @@ async def _ensure_tables_created(self):
     if not self._tables_created:
       async with self.db_engine.begin() as conn:
         # Uncomment to recreate DB every time
-        # await conn.run_sync(Base.metadata.drop_all)
-        await conn.run_sync(Base.metadata.create_all)
+        # await conn.run_sync(BaseV0.metadata.drop_all)
+        await conn.run_sync(BaseV0.metadata.create_all)
       self._tables_created = True

 @override
@@ -505,6 +143,9 @@ async def create_session(
     # 5. Return the session
     await self._ensure_tables_created()
     async with self.database_session_factory() as sql_session:
+      StorageSession = StorageSessionV0
+      StorageAppState = StorageAppStateV0
+      StorageUserState = StorageUserStateV0
       if session_id and await sql_session.get(
           StorageSession, (app_name, user_id, session_id)
       )
@@ -573,6 +214,11 @@ async def get_session(
     # 2. Get all the events based on session id and filtering config
     # 3. Convert and return the session
     async with self.database_session_factory() as sql_session:
+      StorageSession = StorageSessionV0
+      StorageEvent = StorageEventV0
+      StorageAppState = StorageAppStateV0
+      StorageUserState = StorageUserStateV0
+
       storage_session = await sql_session.get(
           StorageSession, (app_name, user_id, session_id)
       )
@@ -622,6 +268,10 @@ async def list_sessions(
   ) -> ListSessionsResponse:
     await self._ensure_tables_created()
     async with self.database_session_factory() as sql_session:
+      StorageSession = StorageSessionV0
+      StorageAppState = StorageAppStateV0
+      StorageUserState = StorageUserStateV0
+
       stmt = select(StorageSession).filter(StorageSession.app_name == app_name)
       if user_id is not None:
         stmt = stmt.filter(StorageSession.user_id == user_id)
@@ -664,6 +314,8 @@ async def delete_session(
   ) -> None:
     await self._ensure_tables_created()
     async with self.database_session_factory() as sql_session:
+      StorageSession = StorageSessionV0
+
       stmt = delete(StorageSession).where(
           StorageSession.app_name == app_name,
           StorageSession.user_id == user_id,
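The repeated StorageSession = StorageSessionV0 aliases pin each method body to
the v0 models without rewriting the queries. Presumably a later change will
select the schema module per database; a hypothetical sketch of such dispatch
(nothing in this patch implements it):

from google.adk.sessions.schemas import v0
from google.adk.sessions.schemas import v1


def _storage_models(schema_version: str):
  # Hypothetical: return the schema module matching the stored version.
  return v1 if schema_version == "v1" else v0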
@@ -685,6 +337,11 @@ async def append_event(self, session: Session, event: Event) -> Event:
     # 2. Update session attributes based on event config
     # 3. Store event to table
     async with self.database_session_factory() as sql_session:
+      StorageSession = StorageSessionV0
+      StorageEvent = StorageEventV0
+      StorageAppState = StorageAppStateV0
+      StorageUserState = StorageUserStateV0
+
       storage_session = await sql_session.get(
           StorageSession, (session.app_name, session.user_id, session.id)
       )
@@ -738,17 +395,3 @@ async def append_event(self, session: Session, event: Event) -> Event:
     # Also update the in-memory session
     await super().append_event(session=session, event=event)
     return event
-
-
-def _merge_state(
-    app_state: dict[str, Any],
-    user_state: dict[str, Any],
-    session_state: dict[str, Any],
-) -> dict[str, Any]:
-  """Merge app, user, and session states into a single state dictionary."""
-  merged_state = copy.deepcopy(session_state)
-  for key in app_state.keys():
-    merged_state[State.APP_PREFIX + key] = app_state[key]
-  for key in user_state.keys():
-    merged_state[State.USER_PREFIX + key] = user_state[key]
-  return merged_state
diff --git a/src/google/adk/sessions/migration/migrate_from_sqlalchemy_sqlite.py b/src/google/adk/sessions/migration/migrate_from_sqlalchemy_sqlite.py
index 30e77d5048..a0dd3a84a1 100644
--- a/src/google/adk/sessions/migration/migrate_from_sqlalchemy_sqlite.py
+++ b/src/google/adk/sessions/migration/migrate_from_sqlalchemy_sqlite.py
@@ -22,8 +22,8 @@
 import sqlite3
 import sys

-from google.adk.sessions import database_session_service as dss
 from google.adk.sessions import sqlite_session_service as sss
+from google.adk.sessions.schemas import v0 as v0_schema
 from sqlalchemy import create_engine
 from sqlalchemy.orm import sessionmaker

@@ -35,7 +35,9 @@ def migrate(source_db_url: str, dest_db_path: str):
   logger.info(f"Connecting to source database: {source_db_url}")
   try:
     engine = create_engine(source_db_url)
-    dss.Base.metadata.create_all(engine)  # Ensure tables exist for inspection
+    v0_schema.Base.metadata.create_all(
+        engine
+    )  # Ensure tables exist for inspection
     SourceSession = sessionmaker(bind=engine)
     source_session = SourceSession()
   except Exception as e:
@@ -55,7 +57,7 @@ def migrate(source_db_url: str, dest_db_path: str):
   try:
     # Migrate app_states
     logger.info("Migrating app_states...")
-    app_states = source_session.query(dss.StorageAppState).all()
+    app_states = source_session.query(v0_schema.StorageAppState).all()
     for item in app_states:
       dest_cursor.execute(
           "INSERT INTO app_states (app_name, state, update_time) VALUES (?,"
@@ -70,7 +72,7 @@ def migrate(source_db_url: str, dest_db_path: str):

     # Migrate user_states
     logger.info("Migrating user_states...")
-    user_states = source_session.query(dss.StorageUserState).all()
+    user_states = source_session.query(v0_schema.StorageUserState).all()
     for item in user_states:
       dest_cursor.execute(
           "INSERT INTO user_states (app_name, user_id, state, update_time)"
@@ -86,7 +88,7 @@ def migrate(source_db_url: str, dest_db_path: str):

     # Migrate sessions
     logger.info("Migrating sessions...")
-    sessions = source_session.query(dss.StorageSession).all()
+    sessions = source_session.query(v0_schema.StorageSession).all()
     for item in sessions:
       dest_cursor.execute(
           "INSERT INTO sessions (app_name, user_id, id, state, create_time,"
@@ -104,7 +106,7 @@ def migrate(source_db_url: str, dest_db_path: str):

     # Migrate events
     logger.info("Migrating events...")
-    events = source_session.query(dss.StorageEvent).all()
+    events = source_session.query(v0_schema.StorageEvent).all()
     for item in events:
       try:
         event_obj = item.to_event()
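A minimal invocation of the migration helper updated above, assuming a local
SQLite source database; both paths are illustrative.

from google.adk.sessions.migration.migrate_from_sqlalchemy_sqlite import migrate

# The source takes a SQLAlchemy-style URL; the destination is a plain sqlite3
# file path, matching the migrate() signature shown in the diff.
migrate(
    source_db_url="sqlite:///./old_sessions.db",
    dest_db_path="./new_sessions.db",
)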
diff --git a/src/google/adk/sessions/schemas/shared.py b/src/google/adk/sessions/schemas/shared.py
new file mode 100644
index 0000000000..37fdf6b8c6
--- /dev/null
+++ b/src/google/adk/sessions/schemas/shared.py
@@ -0,0 +1,67 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import json
+
+from sqlalchemy import Dialect
+from sqlalchemy import Text
+from sqlalchemy.dialects import mysql
+from sqlalchemy.dialects import postgresql
+from sqlalchemy.types import DateTime
+from sqlalchemy.types import TypeDecorator
+
+DEFAULT_MAX_KEY_LENGTH = 128
+DEFAULT_MAX_VARCHAR_LENGTH = 256
+
+
+class DynamicJSON(TypeDecorator):
+  """A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON serialization for other databases."""
+
+  impl = Text  # Default implementation is TEXT
+
+  def load_dialect_impl(self, dialect: Dialect):
+    if dialect.name == "postgresql":
+      return dialect.type_descriptor(postgresql.JSONB)
+    if dialect.name == "mysql":
+      # Use LONGTEXT for MySQL to avoid "Data too long" errors
+      return dialect.type_descriptor(mysql.LONGTEXT)
+    return dialect.type_descriptor(Text)  # Default to Text for other dialects
+
+  def process_bind_param(self, value, dialect: Dialect):
+    if value is not None:
+      if dialect.name == "postgresql":
+        return value  # JSONB handles dict directly
+      return json.dumps(value)  # Serialize to JSON string for TEXT
+    return value
+
+  def process_result_value(self, value, dialect: Dialect):
+    if value is not None:
+      if dialect.name == "postgresql":
+        return value  # JSONB returns dict directly
+      else:
+        return json.loads(value)  # Deserialize from JSON string for TEXT
+    return value
+
+
+class PreciseTimestamp(TypeDecorator):
+  """Represents a timestamp precise to the microsecond."""
+
+  impl = DateTime
+  cache_ok = True
+
+  def load_dialect_impl(self, dialect):
+    if dialect.name == "mysql":
+      return dialect.type_descriptor(mysql.DATETIME(fsp=6))
+    return self.impl
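The shared type decorators can be exercised outside the ADK models; a minimal
sketch with a throwaway table follows. The "demo_items" model is illustrative
only, not part of the ADK schema.

from datetime import datetime
from typing import Any

from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.types import String

from google.adk.sessions.schemas.shared import DynamicJSON
from google.adk.sessions.schemas.shared import PreciseTimestamp


class _DemoBase(DeclarativeBase):
  pass


class DemoItem(_DemoBase):
  """Illustrative table: a JSON payload plus a microsecond timestamp."""

  __tablename__ = "demo_items"

  id: Mapped[str] = mapped_column(String(64), primary_key=True)
  data: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, default={})
  created_at: Mapped[datetime] = mapped_column(PreciseTimestamp)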
diff --git a/src/google/adk/sessions/schemas/v0.py b/src/google/adk/sessions/schemas/v0.py
new file mode 100644
index 0000000000..16a11218d7
--- /dev/null
+++ b/src/google/adk/sessions/schemas/v0.py
@@ -0,0 +1,373 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""V0 database schema for ADK versions from 1.19.0 to 1.21.0.
+
+This module defines SQLAlchemy models for storing session and event data
+in a relational database, with the EventActions object stored using pickle
+serialization. To migrate from the schemas in earlier ADK versions to this
+v0 schema, see
+https://github.com/google/adk-python/blob/main/docs/upgrading_from_1_22_0.md.
+
+The latest schema is defined in `v1.py`. That module uses JSON serialization
+for the EventActions data as well as other fields in the `events` table. See
+https://github.com/google/adk-python/discussions/3605 for more details.
+"""
+
+from __future__ import annotations
+
+from datetime import datetime
+from datetime import timezone
+import json
+import pickle
+from typing import Any
+from typing import Optional
+import uuid
+
+from google.genai import types
+from sqlalchemy import Boolean
+from sqlalchemy import ForeignKeyConstraint
+from sqlalchemy import func
+from sqlalchemy import Text
+from sqlalchemy.dialects import mysql
+from sqlalchemy.ext.mutable import MutableDict
+from sqlalchemy.inspection import inspect
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import relationship
+from sqlalchemy.types import PickleType
+from sqlalchemy.types import String
+from sqlalchemy.types import TypeDecorator
+
+from .. import _session_util
+from ...events.event import Event
+from ...events.event_actions import EventActions
+from ..session import Session
+from .shared import DEFAULT_MAX_KEY_LENGTH
+from .shared import DEFAULT_MAX_VARCHAR_LENGTH
+from .shared import DynamicJSON
+from .shared import PreciseTimestamp
+
+
+class DynamicPickleType(TypeDecorator):
+  """Represents a type that can be pickled."""
+
+  impl = PickleType
+
+  def load_dialect_impl(self, dialect):
+    if dialect.name == "mysql":
+      return dialect.type_descriptor(mysql.LONGBLOB)
+    if dialect.name == "spanner+spanner":
+      from google.cloud.sqlalchemy_spanner.sqlalchemy_spanner import SpannerPickleType
+
+      return dialect.type_descriptor(SpannerPickleType)
+    return self.impl
+
+  def process_bind_param(self, value, dialect):
+    """Ensures the pickled value is a bytes object before passing it to the database dialect."""
+    if value is not None:
+      if dialect.name in ("spanner+spanner", "mysql"):
+        return pickle.dumps(value)
+    return value
+
+  def process_result_value(self, value, dialect):
+    """Ensures the raw bytes from the database are unpickled back into a Python object."""
+    if value is not None:
+      if dialect.name in ("spanner+spanner", "mysql"):
+        return pickle.loads(value)
+    return value
+
+
+class Base(DeclarativeBase):
+  """Base class for v0 database tables."""
+
+  pass
+
+
+class StorageSession(Base):
+  """Represents a session stored in the database."""
+
+  __tablename__ = "sessions"
+
+  app_name: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+  user_id: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+  id: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH),
+      primary_key=True,
+      default=lambda: str(uuid.uuid4()),
+  )
+
+  state: Mapped[MutableDict[str, Any]] = mapped_column(
+      MutableDict.as_mutable(DynamicJSON), default={}
+  )
+
+  create_time: Mapped[datetime] = mapped_column(
+      PreciseTimestamp, default=func.now()
+  )
+  update_time: Mapped[datetime] = mapped_column(
+      PreciseTimestamp, default=func.now(), onupdate=func.now()
+  )
+
+  storage_events: Mapped[list[StorageEvent]] = relationship(
+      "StorageEvent",
+      back_populates="storage_session",
+  )
+
+  def __repr__(self):
+    return f"<StorageSession(id={self.id}, update_time={self.update_time})>"
+
+  @property
+  def _dialect_name(self) -> Optional[str]:
+    session = inspect(self).session
+    return session.bind.dialect.name if session else None
+
+  @property
+  def update_timestamp_tz(self) -> datetime:
+    """Returns the time zone aware update timestamp."""
+    if self._dialect_name == "sqlite":
+      # SQLite does not support timezone. SQLAlchemy returns a naive datetime
+      # object without timezone information. We need to convert it to UTC
+      # manually.
+      return self.update_time.replace(tzinfo=timezone.utc).timestamp()
+    return self.update_time.timestamp()
+
+  def to_session(
+      self,
+      state: dict[str, Any] | None = None,
+      events: list[Event] | None = None,
+  ) -> Session:
+    """Converts the storage session to a session object."""
+    if state is None:
+      state = {}
+    if events is None:
+      events = []
+
+    return Session(
+        app_name=self.app_name,
+        user_id=self.user_id,
+        id=self.id,
+        state=state,
+        events=events,
+        last_update_time=self.update_timestamp_tz,
+    )
+
+
+class StorageEvent(Base):
+  """Represents an event stored in the database."""
+
+  __tablename__ = "events"
+
+  id: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+  app_name: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+  user_id: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+  session_id: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+
+  invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
+  author: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
+  actions: Mapped[MutableDict[str, Any]] = mapped_column(DynamicPickleType)
+  long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(
+      Text, nullable=True
+  )
+  branch: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
+  )
+  timestamp: Mapped[PreciseTimestamp] = mapped_column(
+      PreciseTimestamp, default=func.now()
+  )
+
+  # === Fields from llm_response.py ===
+  content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
+  grounding_metadata: Mapped[dict[str, Any]] = mapped_column(
+      DynamicJSON, nullable=True
+  )
+  custom_metadata: Mapped[dict[str, Any]] = mapped_column(
+      DynamicJSON, nullable=True
+  )
+  usage_metadata: Mapped[dict[str, Any]] = mapped_column(
+      DynamicJSON, nullable=True
+  )
+  citation_metadata: Mapped[dict[str, Any]] = mapped_column(
+      DynamicJSON, nullable=True
+  )
+
+  partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
+  turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
+  error_code: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
+  )
+  error_message: Mapped[str] = mapped_column(Text, nullable=True)
+  interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)
+  input_transcription: Mapped[dict[str, Any]] = mapped_column(
+      DynamicJSON, nullable=True
+  )
+  output_transcription: Mapped[dict[str, Any]] = mapped_column(
+      DynamicJSON, nullable=True
+  )
+
+  storage_session: Mapped[StorageSession] = relationship(
+      "StorageSession",
+      back_populates="storage_events",
+  )
+
+  __table_args__ = (
+      ForeignKeyConstraint(
+          ["app_name", "user_id", "session_id"],
+          ["sessions.app_name", "sessions.user_id", "sessions.id"],
+          ondelete="CASCADE",
+      ),
+  )
+
+  @property
+  def long_running_tool_ids(self) -> set[str]:
+    return (
+        set(json.loads(self.long_running_tool_ids_json))
+        if self.long_running_tool_ids_json
+        else set()
+    )
+
+  @long_running_tool_ids.setter
+  def long_running_tool_ids(self, value: set[str]):
+    if value is None:
+      self.long_running_tool_ids_json = None
+    else:
+      self.long_running_tool_ids_json = json.dumps(list(value))
+
+  @classmethod
+  def from_event(cls, session: Session, event: Event) -> StorageEvent:
+    storage_event = StorageEvent(
+        id=event.id,
+        invocation_id=event.invocation_id,
+        author=event.author,
+        branch=event.branch,
+        actions=event.actions,
+        session_id=session.id,
+        app_name=session.app_name,
+        user_id=session.user_id,
+        timestamp=datetime.fromtimestamp(event.timestamp),
+        long_running_tool_ids=event.long_running_tool_ids,
+        partial=event.partial,
+        turn_complete=event.turn_complete,
+        error_code=event.error_code,
+        error_message=event.error_message,
+        interrupted=event.interrupted,
+    )
+    if event.content:
+      storage_event.content = event.content.model_dump(
+          exclude_none=True, mode="json"
+      )
+    if event.grounding_metadata:
+      storage_event.grounding_metadata = event.grounding_metadata.model_dump(
+          exclude_none=True, mode="json"
+      )
+    if event.custom_metadata:
+      storage_event.custom_metadata = event.custom_metadata
+    if event.usage_metadata:
+      storage_event.usage_metadata = event.usage_metadata.model_dump(
+          exclude_none=True, mode="json"
+      )
+    if event.citation_metadata:
+      storage_event.citation_metadata = event.citation_metadata.model_dump(
+          exclude_none=True, mode="json"
+      )
+    if event.input_transcription:
+      storage_event.input_transcription = event.input_transcription.model_dump(
+          exclude_none=True, mode="json"
+      )
+    if event.output_transcription:
+      storage_event.output_transcription = (
+          event.output_transcription.model_dump(exclude_none=True, mode="json")
+      )
+    return storage_event
+
+  def to_event(self) -> Event:
+    return Event(
+        id=self.id,
+        invocation_id=self.invocation_id,
+        author=self.author,
+        branch=self.branch,
+        # This is needed because actions pickled by a previous ADK version
+        # might not have every field defined in the current version of the
+        # EventActions model.
+        actions=EventActions().model_copy(update=self.actions.model_dump()),
+        timestamp=self.timestamp.timestamp(),
+        long_running_tool_ids=self.long_running_tool_ids,
+        partial=self.partial,
+        turn_complete=self.turn_complete,
+        error_code=self.error_code,
+        error_message=self.error_message,
+        interrupted=self.interrupted,
+        custom_metadata=self.custom_metadata,
+        content=_session_util.decode_model(self.content, types.Content),
+        grounding_metadata=_session_util.decode_model(
+            self.grounding_metadata, types.GroundingMetadata
+        ),
+        usage_metadata=_session_util.decode_model(
+            self.usage_metadata, types.GenerateContentResponseUsageMetadata
+        ),
+        citation_metadata=_session_util.decode_model(
+            self.citation_metadata, types.CitationMetadata
+        ),
+        input_transcription=_session_util.decode_model(
+            self.input_transcription, types.Transcription
+        ),
+        output_transcription=_session_util.decode_model(
+            self.output_transcription, types.Transcription
+        ),
+    )
+
+
+class StorageAppState(Base):
+  """Represents an app state stored in the database."""
+
+  __tablename__ = "app_states"
+
+  app_name: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+  state: Mapped[MutableDict[str, Any]] = mapped_column(
+      MutableDict.as_mutable(DynamicJSON), default={}
+  )
+  update_time: Mapped[datetime] = mapped_column(
+      PreciseTimestamp, default=func.now(), onupdate=func.now()
+  )
+
+
+class StorageUserState(Base):
+  """Represents a user state stored in the database."""
+
+  __tablename__ = "user_states"
+
+  app_name: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+  user_id: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+  state: Mapped[MutableDict[str, Any]] = mapped_column(
+      MutableDict.as_mutable(DynamicJSON), default={}
+  )
+  update_time: Mapped[datetime] = mapped_column(
+      PreciseTimestamp, default=func.now(), onupdate=func.now()
+  )
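The crux of the v0 events table is the pickled actions column plus the
JSON-encoded tool-id list; the round trips those columns rely on are plain
pickle/json symmetry. The actions dict below is a stand-in for a dumped
EventActions object.

import json
import pickle

# What DynamicPickleType does for the actions column on MySQL/Spanner:
actions = {"state_delta": {"k": "v"}}
assert pickle.loads(pickle.dumps(actions)) == actions

# What the long_running_tool_ids property/setter pair does:
ids = {"tool-1", "tool-2"}
assert set(json.loads(json.dumps(list(ids)))) == ids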
+""" + +from __future__ import annotations + +from datetime import datetime +from datetime import timezone +from typing import Any +from typing import Optional +import uuid + +from sqlalchemy import ForeignKeyConstraint +from sqlalchemy import func +from sqlalchemy.ext.mutable import MutableDict +from sqlalchemy.inspection import inspect +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import relationship +from sqlalchemy.types import String + +from ...events.event import Event +from ..session import Session +from .shared import DEFAULT_MAX_KEY_LENGTH +from .shared import DEFAULT_MAX_VARCHAR_LENGTH +from .shared import DynamicJSON +from .shared import PreciseTimestamp + + +class Base(DeclarativeBase): + """Base class for v1 database tables.""" + + pass + + +class StorageMetadata(Base): + """Represents ADK internal metadata stored in the database. + + This table is used to store internal information like the schema version. + The DatabaseSessionService will populate and utilize this table to manage + database compatibility and migrations. + """ + + __tablename__ = "adk_internal_metadata" + key: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + value: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH)) + + +class StorageSession(Base): + """Represents a session stored in the database.""" + + __tablename__ = "sessions" + + app_name: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + user_id: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + id: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), + primary_key=True, + default=lambda: str(uuid.uuid4()), + ) + + state: Mapped[MutableDict[str, Any]] = mapped_column( + MutableDict.as_mutable(DynamicJSON), default={} + ) + + create_time: Mapped[datetime] = mapped_column( + PreciseTimestamp, default=func.now() + ) + update_time: Mapped[datetime] = mapped_column( + PreciseTimestamp, default=func.now(), onupdate=func.now() + ) + + storage_events: Mapped[list[StorageEvent]] = relationship( + "StorageEvent", + back_populates="storage_session", + # Deleting a session will now automatically delete its associated events + cascade="all, delete-orphan", + ) + + def __repr__(self): + return f"" + + @property + def _dialect_name(self) -> Optional[str]: + session = inspect(self).session + return session.bind.dialect.name if session else None + + @property + def update_timestamp_tz(self) -> datetime: + """Returns the time zone aware update timestamp.""" + if self._dialect_name == "sqlite": + # SQLite does not support timezone. SQLAlchemy returns a naive datetime + # object without timezone information. We need to convert it to UTC + # manually. 
+
+
+class StorageEvent(Base):
+  """Represents an event stored in the database."""
+
+  __tablename__ = "events"
+
+  id: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+  app_name: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+  user_id: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+  session_id: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+
+  invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
+  timestamp: Mapped[datetime] = mapped_column(
+      PreciseTimestamp, default=func.now()
+  )
+  # The event_data column uses JSON serialization to store the Event data,
+  # replacing the various per-field columns used previously.
+  event_data: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
+
+  storage_session: Mapped[StorageSession] = relationship(
+      "StorageSession",
+      back_populates="storage_events",
+  )
+
+  __table_args__ = (
+      ForeignKeyConstraint(
+          ["app_name", "user_id", "session_id"],
+          ["sessions.app_name", "sessions.user_id", "sessions.id"],
+          ondelete="CASCADE",
+      ),
+  )
+
+  @classmethod
+  def from_event(cls, session: Session, event: Event) -> StorageEvent:
+    """Creates a StorageEvent from an Event."""
+    return StorageEvent(
+        id=event.id,
+        invocation_id=event.invocation_id,
+        session_id=session.id,
+        app_name=session.app_name,
+        user_id=session.user_id,
+        timestamp=datetime.fromtimestamp(event.timestamp),
+        event_data=event.model_dump(exclude_none=True, mode="json"),
+    )
+
+  def to_event(self) -> Event:
+    """Converts the StorageEvent to an Event."""
+    return Event.model_validate({
+        **self.event_data,
+        "id": self.id,
+        "invocation_id": self.invocation_id,
+        "timestamp": self.timestamp.timestamp(),
+    })
+
+
+class StorageAppState(Base):
+  """Represents an app state stored in the database."""
+
+  __tablename__ = "app_states"
+
+  app_name: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+  state: Mapped[MutableDict[str, Any]] = mapped_column(
+      MutableDict.as_mutable(DynamicJSON), default={}
+  )
+  update_time: Mapped[datetime] = mapped_column(
+      PreciseTimestamp, default=func.now(), onupdate=func.now()
+  )
+
+
+class StorageUserState(Base):
+  """Represents a user state stored in the database."""
+
+  __tablename__ = "user_states"
+
+  app_name: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+  user_id: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+  state: Mapped[MutableDict[str, Any]] = mapped_column(
+      MutableDict.as_mutable(DynamicJSON), default={}
+  )
+  update_time: Mapped[datetime] = mapped_column(
+      PreciseTimestamp, default=func.now(), onupdate=func.now()
+  )
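By contrast, v1 collapses the per-field columns into event_data; the round
trip mirrors the from_event/to_event methods above. A small sketch follows,
with illustrative field values.

from google.adk.events.event import Event

event = Event(invocation_id="inv-1", author="user")
data = event.model_dump(exclude_none=True, mode="json")  # stored in event_data
restored = Event.model_validate({
    **data,
    "id": event.id,
    "invocation_id": event.invocation_id,
    "timestamp": event.timestamp,
})
assert restored.author == "user"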
diff --git a/tests/unittests/sessions/test_dynamic_pickle_type.py b/tests/unittests/sessions/test_dynamic_pickle_type.py
index e4eb084f88..5164d665c0 100644
--- a/tests/unittests/sessions/test_dynamic_pickle_type.py
+++ b/tests/unittests/sessions/test_dynamic_pickle_type.py
@@ -17,7 +17,7 @@
 import pickle
 from unittest import mock

-from google.adk.sessions.database_session_service import DynamicPickleType
+from google.adk.sessions.schemas.v0 import DynamicPickleType
 import pytest
 from sqlalchemy import create_engine
 from sqlalchemy.dialects import mysql
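As the commit message notes, this change is a no-op for existing users:
DatabaseSessionService still creates and queries the v0 tables. A minimal
usage sketch follows; the database URL and names are illustrative.

import asyncio

from google.adk.sessions.database_session_service import DatabaseSessionService


async def main() -> None:
  # An async driver (e.g. aiosqlite) is required by create_async_engine.
  service = DatabaseSessionService("sqlite+aiosqlite:///./sessions.db")
  session = await service.create_session(app_name="my_app", user_id="user-1")
  print(session.id)


asyncio.run(main())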