diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py
index 047340c55e..11718a653c 100644
--- a/src/google/adk/sessions/database_session_service.py
+++ b/src/google/adk/sessions/database_session_service.py
@@ -16,420 +16,59 @@
 import asyncio
 import copy
 from datetime import datetime
-from datetime import timezone
-import json
 import logging
-import pickle
 from typing import Any
 from typing import Optional
-import uuid
 
-from google.genai import types
-from sqlalchemy import Boolean
 from sqlalchemy import delete
-from sqlalchemy import Dialect
 from sqlalchemy import event
-from sqlalchemy import ForeignKeyConstraint
-from sqlalchemy import func
 from sqlalchemy import select
-from sqlalchemy import Text
-from sqlalchemy.dialects import mysql
-from sqlalchemy.dialects import postgresql
 from sqlalchemy.exc import ArgumentError
 from sqlalchemy.ext.asyncio import async_sessionmaker
 from sqlalchemy.ext.asyncio import AsyncEngine
 from sqlalchemy.ext.asyncio import AsyncSession as DatabaseSessionFactory
 from sqlalchemy.ext.asyncio import create_async_engine
-from sqlalchemy.ext.mutable import MutableDict
-from sqlalchemy.inspection import inspect
-from sqlalchemy.orm import DeclarativeBase
-from sqlalchemy.orm import Mapped
-from sqlalchemy.orm import mapped_column
-from sqlalchemy.orm import relationship
 from sqlalchemy.schema import MetaData
-from sqlalchemy.types import DateTime
-from sqlalchemy.types import PickleType
-from sqlalchemy.types import String
-from sqlalchemy.types import TypeDecorator
 from typing_extensions import override
 from tzlocal import get_localzone
 
 from . import _session_util
 from ..errors.already_exists_error import AlreadyExistsError
 from ..events.event import Event
-from ..events.event_actions import EventActions
 from .base_session_service import BaseSessionService
 from .base_session_service import GetSessionConfig
 from .base_session_service import ListSessionsResponse
+from .schemas.v0 import Base as BaseV0
+from .schemas.v0 import StorageAppState as StorageAppStateV0
+from .schemas.v0 import StorageEvent as StorageEventV0
+from .schemas.v0 import StorageSession as StorageSessionV0
+from .schemas.v0 import StorageUserState as StorageUserStateV0
 from .session import Session
 from .state import State
 
 logger = logging.getLogger("google_adk." + __name__)
 
-DEFAULT_MAX_KEY_LENGTH = 128
-DEFAULT_MAX_VARCHAR_LENGTH = 256
-
-
-class DynamicJSON(TypeDecorator):
-  """A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON serialization for other databases."""
-
-  impl = Text  # Default implementation is TEXT
-
-  def load_dialect_impl(self, dialect: Dialect):
-    if dialect.name == "postgresql":
-      return dialect.type_descriptor(postgresql.JSONB)
-    if dialect.name == "mysql":
-      # Use LONGTEXT for MySQL to address the data too long issue
-      return dialect.type_descriptor(mysql.LONGTEXT)
-    return dialect.type_descriptor(Text)  # Default to Text for other dialects
-
-  def process_bind_param(self, value, dialect: Dialect):
-    if value is not None:
-      if dialect.name == "postgresql":
-        return value  # JSONB handles dict directly
-      return json.dumps(value)  # Serialize to JSON string for TEXT
-    return value
-
-  def process_result_value(self, value, dialect: Dialect):
-    if value is not None:
-      if dialect.name == "postgresql":
-        return value  # JSONB returns dict directly
-      else:
-        return json.loads(value)  # Deserialize from JSON string for TEXT
-    return value
-
-
-class PreciseTimestamp(TypeDecorator):
-  """Represents a timestamp precise to the microsecond."""
-
-  impl = DateTime
-  cache_ok = True
-
-  def load_dialect_impl(self, dialect):
-    if dialect.name == "mysql":
-      return dialect.type_descriptor(mysql.DATETIME(fsp=6))
-    return self.impl
-
-
-class DynamicPickleType(TypeDecorator):
-  """Represents a type that can be pickled."""
-
-  impl = PickleType
-
-  def load_dialect_impl(self, dialect):
-    if dialect.name == "mysql":
-      return dialect.type_descriptor(mysql.LONGBLOB)
-    if dialect.name == "spanner+spanner":
-      from google.cloud.sqlalchemy_spanner.sqlalchemy_spanner import SpannerPickleType
-
-      return dialect.type_descriptor(SpannerPickleType)
-    return self.impl
-
-  def process_bind_param(self, value, dialect):
-    """Ensures the pickled value is a bytes object before passing it to the database dialect."""
-    if value is not None:
-      if dialect.name in ("spanner+spanner", "mysql"):
-        return pickle.dumps(value)
-    return value
-
-  def process_result_value(self, value, dialect):
-    """Ensures the raw bytes from the database are unpickled back into a Python object."""
-    if value is not None:
-      if dialect.name in ("spanner+spanner", "mysql"):
-        return pickle.loads(value)
-    return value
-
-
-class Base(DeclarativeBase):
-  """Base class for database tables."""
-
-  pass
-
-
-class StorageSession(Base):
-  """Represents a session stored in the database."""
-
-  __tablename__ = "sessions"
-
-  app_name: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
-  )
-  user_id: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
-  )
-  id: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_KEY_LENGTH),
-      primary_key=True,
-      default=lambda: str(uuid.uuid4()),
-  )
-
-  state: Mapped[MutableDict[str, Any]] = mapped_column(
-      MutableDict.as_mutable(DynamicJSON), default={}
-  )
-
-  create_time: Mapped[datetime] = mapped_column(
-      PreciseTimestamp, default=func.now()
-  )
-  update_time: Mapped[datetime] = mapped_column(
-      PreciseTimestamp, default=func.now(), onupdate=func.now()
-  )
-
-  storage_events: Mapped[list[StorageEvent]] = relationship(
-      "StorageEvent",
-      back_populates="storage_session",
-  )
-
-  def __repr__(self):
-    return f"<StorageSession(id={self.id}, update_time={self.update_time})>"
-
-  @property
-  def _dialect_name(self) -> Optional[str]:
-    session = inspect(self).session
-    return session.bind.dialect.name if session else None
-
-  @property
-  def update_timestamp_tz(self) -> datetime:
-    """Returns the time zone aware update timestamp."""
-    if self._dialect_name == "sqlite":
-      # SQLite does not support timezone. SQLAlchemy returns a naive datetime
-      # object without timezone information. We need to convert it to UTC
-      # manually.
-      return self.update_time.replace(tzinfo=timezone.utc).timestamp()
-    return self.update_time.timestamp()
-
-  def to_session(
-      self,
-      state: dict[str, Any] | None = None,
-      events: list[Event] | None = None,
-  ) -> Session:
-    """Converts the storage session to a session object."""
-    if state is None:
-      state = {}
-    if events is None:
-      events = []
-
-    return Session(
-        app_name=self.app_name,
-        user_id=self.user_id,
-        id=self.id,
-        state=state,
-        events=events,
-        last_update_time=self.update_timestamp_tz,
-    )
-
-
-class StorageEvent(Base):
-  """Represents an event stored in the database."""
-
-  __tablename__ = "events"
-
-  id: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
-  )
-  app_name: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
-  )
-  user_id: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
-  )
-  session_id: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
-  )
-
-  invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
-  author: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
-  actions: Mapped[MutableDict[str, Any]] = mapped_column(DynamicPickleType)
-  long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(
-      Text, nullable=True
-  )
-  branch: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
-  )
-  timestamp: Mapped[PreciseTimestamp] = mapped_column(
-      PreciseTimestamp, default=func.now()
-  )
-
-  # === Fields from llm_response.py ===
-  content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
-  grounding_metadata: Mapped[dict[str, Any]] = mapped_column(
-      DynamicJSON, nullable=True
-  )
-  custom_metadata: Mapped[dict[str, Any]] = mapped_column(
-      DynamicJSON, nullable=True
-  )
-  usage_metadata: Mapped[dict[str, Any]] = mapped_column(
-      DynamicJSON, nullable=True
-  )
-  citation_metadata: Mapped[dict[str, Any]] = mapped_column(
-      DynamicJSON, nullable=True
-  )
-
-  partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
-  turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
-  error_code: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
-  )
-  error_message: Mapped[str] = mapped_column(Text, nullable=True)
-  interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)
-  input_transcription: Mapped[dict[str, Any]] = mapped_column(
-      DynamicJSON, nullable=True
-  )
-  output_transcription: Mapped[dict[str, Any]] = mapped_column(
-      DynamicJSON, nullable=True
-  )
-
-  storage_session: Mapped[StorageSession] = relationship(
-      "StorageSession",
-      back_populates="storage_events",
-  )
-
-  __table_args__ = (
-      ForeignKeyConstraint(
-          ["app_name", "user_id", "session_id"],
-          ["sessions.app_name", "sessions.user_id", "sessions.id"],
-          ondelete="CASCADE",
-      ),
-  )
-
-  @property
-  def long_running_tool_ids(self) -> set[str]:
-    return (
-        set(json.loads(self.long_running_tool_ids_json))
-        if self.long_running_tool_ids_json
-        else set()
-    )
-
-  @long_running_tool_ids.setter
-  def long_running_tool_ids(self, value: set[str]):
-    if value is None:
-      self.long_running_tool_ids_json = None
-    else:
-      self.long_running_tool_ids_json = json.dumps(list(value))
-
-  @classmethod
-  def from_event(cls, session: Session, event: Event) -> StorageEvent:
-    storage_event = StorageEvent(
-        id=event.id,
-        invocation_id=event.invocation_id,
-        author=event.author,
-        branch=event.branch,
-        actions=event.actions,
-        session_id=session.id,
-        app_name=session.app_name,
-        user_id=session.user_id,
-        timestamp=datetime.fromtimestamp(event.timestamp),
-        long_running_tool_ids=event.long_running_tool_ids,
-        partial=event.partial,
-        turn_complete=event.turn_complete,
-        error_code=event.error_code,
-        error_message=event.error_message,
-        interrupted=event.interrupted,
-    )
-    if event.content:
-      storage_event.content = event.content.model_dump(
-          exclude_none=True, mode="json"
-      )
-    if event.grounding_metadata:
-      storage_event.grounding_metadata = event.grounding_metadata.model_dump(
-          exclude_none=True, mode="json"
-      )
-    if event.custom_metadata:
-      storage_event.custom_metadata = event.custom_metadata
-    if event.usage_metadata:
-      storage_event.usage_metadata = event.usage_metadata.model_dump(
-          exclude_none=True, mode="json"
-      )
-    if event.citation_metadata:
-      storage_event.citation_metadata = event.citation_metadata.model_dump(
-          exclude_none=True, mode="json"
-      )
-    if event.input_transcription:
-      storage_event.input_transcription = event.input_transcription.model_dump(
-          exclude_none=True, mode="json"
-      )
-    if event.output_transcription:
-      storage_event.output_transcription = (
-          event.output_transcription.model_dump(exclude_none=True, mode="json")
-      )
-    return storage_event
-
-  def to_event(self) -> Event:
-    return Event(
-        id=self.id,
-        invocation_id=self.invocation_id,
-        author=self.author,
-        branch=self.branch,
-        # This is needed as previous ADK version pickled actions might not have
-        # value defined in the current version of the EventActions model.
-        actions=EventActions().model_copy(update=self.actions.model_dump()),
-        timestamp=self.timestamp.timestamp(),
-        long_running_tool_ids=self.long_running_tool_ids,
-        partial=self.partial,
-        turn_complete=self.turn_complete,
-        error_code=self.error_code,
-        error_message=self.error_message,
-        interrupted=self.interrupted,
-        custom_metadata=self.custom_metadata,
-        content=_session_util.decode_model(self.content, types.Content),
-        grounding_metadata=_session_util.decode_model(
-            self.grounding_metadata, types.GroundingMetadata
-        ),
-        usage_metadata=_session_util.decode_model(
-            self.usage_metadata, types.GenerateContentResponseUsageMetadata
-        ),
-        citation_metadata=_session_util.decode_model(
-            self.citation_metadata, types.CitationMetadata
-        ),
-        input_transcription=_session_util.decode_model(
-            self.input_transcription, types.Transcription
-        ),
-        output_transcription=_session_util.decode_model(
-            self.output_transcription, types.Transcription
-        ),
-    )
-
-
-class StorageAppState(Base):
-  """Represents an app state stored in the database."""
-
-  __tablename__ = "app_states"
-
-  app_name: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
-  )
-  state: Mapped[MutableDict[str, Any]] = mapped_column(
-      MutableDict.as_mutable(DynamicJSON), default={}
-  )
-  update_time: Mapped[datetime] = mapped_column(
-      PreciseTimestamp, default=func.now(), onupdate=func.now()
-  )
-
-
-class StorageUserState(Base):
-  """Represents a user state stored in the database."""
-
-  __tablename__ = "user_states"
-
-  app_name: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
-  )
-  user_id: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
-  )
-  state: Mapped[MutableDict[str, Any]] = mapped_column(
-      MutableDict.as_mutable(DynamicJSON), default={}
-  )
-  update_time: Mapped[datetime] = mapped_column(
-      PreciseTimestamp, default=func.now(), onupdate=func.now()
-  )
-
-
-def set_sqlite_pragma(dbapi_connection, connection_record):
+def _set_sqlite_pragma(dbapi_connection, connection_record):
   cursor = dbapi_connection.cursor()
   cursor.execute("PRAGMA foreign_keys=ON")
   cursor.close()
 
 
+def _merge_state(
+    app_state: dict[str, Any],
+    user_state: dict[str, Any],
+    session_state: dict[str, Any],
+) -> dict[str, Any]:
+  """Merge app, user, and session states into a single state dictionary."""
+  merged_state = copy.deepcopy(session_state)
+  for key in app_state.keys():
+    merged_state[State.APP_PREFIX + key] = app_state[key]
+  for key in user_state.keys():
+    merged_state[State.USER_PREFIX + key] = user_state[key]
+  return merged_state
+
+
 class DatabaseSessionService(BaseSessionService):
   """A session service that uses a database for storage."""
 
@@ -440,10 +79,9 @@ def __init__(self, db_url: str, **kwargs: Any):
     # 3. Initialize all properties
     try:
       db_engine = create_async_engine(db_url, **kwargs)
-
       if db_engine.dialect.name == "sqlite":
         # Set sqlite pragma to enable foreign keys constraints
-        event.listen(db_engine.sync_engine, "connect", set_sqlite_pragma)
+        event.listen(db_engine.sync_engine, "connect", _set_sqlite_pragma)
     except Exception as e:
       if isinstance(e, ArgumentError):
@@ -485,8 +123,8 @@ async def _ensure_tables_created(self):
     if not self._tables_created:
       async with self.db_engine.begin() as conn:
         # Uncomment to recreate DB every time
-        # await conn.run_sync(Base.metadata.drop_all)
-        await conn.run_sync(Base.metadata.create_all)
+        # await conn.run_sync(BaseV0.metadata.drop_all)
+        await conn.run_sync(BaseV0.metadata.create_all)
         self._tables_created = True
 
   @override
@@ -505,6 +143,9 @@ async def create_session(
     # 5. Return the session
     await self._ensure_tables_created()
     async with self.database_session_factory() as sql_session:
+      StorageSession = StorageSessionV0
+      StorageAppState = StorageAppStateV0
+      StorageUserState = StorageUserStateV0
 
       if session_id and await sql_session.get(
           StorageSession, (app_name, user_id, session_id)
@@ -573,6 +214,11 @@ async def get_session(
     # 2. Get all the events based on session id and filtering config
     # 3. Convert and return the session
     async with self.database_session_factory() as sql_session:
+      StorageSession = StorageSessionV0
+      StorageEvent = StorageEventV0
+      StorageAppState = StorageAppStateV0
+      StorageUserState = StorageUserStateV0
+
       storage_session = await sql_session.get(
           StorageSession, (app_name, user_id, session_id)
       )
@@ -622,6 +268,10 @@ async def list_sessions(
   ) -> ListSessionsResponse:
     await self._ensure_tables_created()
     async with self.database_session_factory() as sql_session:
+      StorageSession = StorageSessionV0
+      StorageAppState = StorageAppStateV0
+      StorageUserState = StorageUserStateV0
+
       stmt = select(StorageSession).filter(StorageSession.app_name == app_name)
       if user_id is not None:
         stmt = stmt.filter(StorageSession.user_id == user_id)
@@ -664,6 +314,8 @@ async def delete_session(
   ) -> None:
     await self._ensure_tables_created()
     async with self.database_session_factory() as sql_session:
+      StorageSession = StorageSessionV0
+
       stmt = delete(StorageSession).where(
           StorageSession.app_name == app_name,
           StorageSession.user_id == user_id,
@@ -685,6 +337,11 @@ async def append_event(self, session: Session, event: Event) -> Event:
     # 2. Update session attributes based on event config
     # 3. Store event to table
     async with self.database_session_factory() as sql_session:
+      StorageSession = StorageSessionV0
+      StorageEvent = StorageEventV0
+      StorageAppState = StorageAppStateV0
+      StorageUserState = StorageUserStateV0
+
       storage_session = await sql_session.get(
           StorageSession, (session.app_name, session.user_id, session.id)
       )
@@ -738,17 +395,3 @@ async def append_event(self, session: Session, event: Event) -> Event:
     # Also update the in-memory session
     await super().append_event(session=session, event=event)
     return event
-
-
-def _merge_state(
-    app_state: dict[str, Any],
-    user_state: dict[str, Any],
-    session_state: dict[str, Any],
-) -> dict[str, Any]:
-  """Merge app, user, and session states into a single state dictionary."""
-  merged_state = copy.deepcopy(session_state)
-  for key in app_state.keys():
-    merged_state[State.APP_PREFIX + key] = app_state[key]
-  for key in user_state.keys():
-    merged_state[State.USER_PREFIX + key] = user_state[key]
-  return merged_state
diff --git a/src/google/adk/sessions/migration/migrate_from_sqlalchemy_sqlite.py b/src/google/adk/sessions/migration/migrate_from_sqlalchemy_sqlite.py
index 30e77d5048..a0dd3a84a1 100644
--- a/src/google/adk/sessions/migration/migrate_from_sqlalchemy_sqlite.py
+++ b/src/google/adk/sessions/migration/migrate_from_sqlalchemy_sqlite.py
@@ -22,8 +22,8 @@
 import sqlite3
 import sys
 
-from google.adk.sessions import database_session_service as dss
 from google.adk.sessions import sqlite_session_service as sss
+from google.adk.sessions.schemas import v0 as v0_schema
 from sqlalchemy import create_engine
 from sqlalchemy.orm import sessionmaker
 
@@ -35,7 +35,9 @@ def migrate(source_db_url: str, dest_db_path: str):
   logger.info(f"Connecting to source database: {source_db_url}")
   try:
     engine = create_engine(source_db_url)
-    dss.Base.metadata.create_all(engine)  # Ensure tables exist for inspection
+    v0_schema.Base.metadata.create_all(
+        engine
+    )  # Ensure tables exist for inspection
     SourceSession = sessionmaker(bind=engine)
     source_session = SourceSession()
   except Exception as e:
@@ -55,7 +57,7 @@ def migrate(source_db_url: str, dest_db_path: str):
   try:
     # Migrate app_states
     logger.info("Migrating app_states...")
-    app_states = source_session.query(dss.StorageAppState).all()
+    app_states = source_session.query(v0_schema.StorageAppState).all()
     for item in app_states:
       dest_cursor.execute(
           "INSERT INTO app_states (app_name, state, update_time) VALUES (?,"
@@ -70,7 +72,7 @@ def migrate(source_db_url: str, dest_db_path: str):
 
     # Migrate user_states
     logger.info("Migrating user_states...")
-    user_states = source_session.query(dss.StorageUserState).all()
+    user_states = source_session.query(v0_schema.StorageUserState).all()
     for item in user_states:
      dest_cursor.execute(
           "INSERT INTO user_states (app_name, user_id, state, update_time)"
@@ -86,7 +88,7 @@ def migrate(source_db_url: str, dest_db_path: str):
 
     # Migrate sessions
     logger.info("Migrating sessions...")
-    sessions = source_session.query(dss.StorageSession).all()
+    sessions = source_session.query(v0_schema.StorageSession).all()
     for item in sessions:
       dest_cursor.execute(
           "INSERT INTO sessions (app_name, user_id, id, state, create_time,"
@@ -104,7 +106,7 @@ def migrate(source_db_url: str, dest_db_path: str):
 
     # Migrate events
     logger.info("Migrating events...")
-    events = source_session.query(dss.StorageEvent).all()
+    events = source_session.query(v0_schema.StorageEvent).all()
     for item in events:
       try:
         event_obj = item.to_event()
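Note for reviewers: `_merge_state` moves from the bottom of the module to above `DatabaseSessionService`, but its behavior is unchanged. A minimal standalone sketch of the merge semantics, assuming `State.APP_PREFIX == "app:"` and `State.USER_PREFIX == "user:"` (the prefix constants live in `state.py`, which this diff does not show):

import copy
from typing import Any

APP_PREFIX = "app:"  # assumed value of State.APP_PREFIX
USER_PREFIX = "user:"  # assumed value of State.USER_PREFIX


def merge_state(
    app_state: dict[str, Any],
    user_state: dict[str, Any],
    session_state: dict[str, Any],
) -> dict[str, Any]:
  # Session state is copied first; app- and user-scoped keys are then folded
  # in under their prefixes, mirroring _merge_state in the patch above.
  merged = copy.deepcopy(session_state)
  for key, value in app_state.items():
    merged[APP_PREFIX + key] = value
  for key, value in user_state.items():
    merged[USER_PREFIX + key] = value
  return merged


assert merge_state({"theme": "dark"}, {"lang": "en"}, {"step": 3}) == {
    "step": 3,
    "app:theme": "dark",
    "user:lang": "en",
}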
diff --git a/src/google/adk/sessions/schemas/shared.py b/src/google/adk/sessions/schemas/shared.py
new file mode 100644
index 0000000000..37fdf6b8c6
--- /dev/null
+++ b/src/google/adk/sessions/schemas/shared.py
@@ -0,0 +1,67 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import json
+
+from sqlalchemy import Dialect
+from sqlalchemy import Text
+from sqlalchemy.dialects import mysql
+from sqlalchemy.dialects import postgresql
+from sqlalchemy.types import DateTime
+from sqlalchemy.types import TypeDecorator
+
+DEFAULT_MAX_KEY_LENGTH = 128
+DEFAULT_MAX_VARCHAR_LENGTH = 256
+
+
+class DynamicJSON(TypeDecorator):
+  """A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON serialization for other databases."""
+
+  impl = Text  # Default implementation is TEXT
+
+  def load_dialect_impl(self, dialect: Dialect):
+    if dialect.name == "postgresql":
+      return dialect.type_descriptor(postgresql.JSONB)
+    if dialect.name == "mysql":
+      # Use LONGTEXT for MySQL to address the data too long issue
+      return dialect.type_descriptor(mysql.LONGTEXT)
+    return dialect.type_descriptor(Text)  # Default to Text for other dialects
+
+  def process_bind_param(self, value, dialect: Dialect):
+    if value is not None:
+      if dialect.name == "postgresql":
+        return value  # JSONB handles dict directly
+      return json.dumps(value)  # Serialize to JSON string for TEXT
+    return value
+
+  def process_result_value(self, value, dialect: Dialect):
+    if value is not None:
+      if dialect.name == "postgresql":
+        return value  # JSONB returns dict directly
+      else:
+        return json.loads(value)  # Deserialize from JSON string for TEXT
+    return value
+
+
+class PreciseTimestamp(TypeDecorator):
+  """Represents a timestamp precise to the microsecond."""
+
+  impl = DateTime
+  cache_ok = True
+
+  def load_dialect_impl(self, dialect):
+    if dialect.name == "mysql":
+      return dialect.type_descriptor(mysql.DATETIME(fsp=6))
+    return self.impl
diff --git a/src/google/adk/sessions/schemas/v0.py b/src/google/adk/sessions/schemas/v0.py
new file mode 100644
index 0000000000..16a11218d7
--- /dev/null
+++ b/src/google/adk/sessions/schemas/v0.py
@@ -0,0 +1,373 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""V0 database schema for ADK versions from 1.19.0 to 1.21.0.
+
+This module defines SQLAlchemy models for storing session and event data
+in a relational database with the EventActions object using pickle
+serialization. To migrate from the schemas in earlier ADK versions to this
+v0 schema, see
+https://github.com/google/adk-python/blob/main/docs/upgrading_from_1_22_0.md.
+
+The latest schema is defined in `v1.py`. That module uses JSON serialization
+for the EventActions data as well as other fields in the `events` table. See
+https://github.com/google/adk-python/discussions/3605 for more details.
+"""
+
+from __future__ import annotations
+
+from datetime import datetime
+from datetime import timezone
+import json
+import pickle
+from typing import Any
+from typing import Optional
+import uuid
+
+from google.genai import types
+from sqlalchemy import Boolean
+from sqlalchemy import ForeignKeyConstraint
+from sqlalchemy import func
+from sqlalchemy import Text
+from sqlalchemy.dialects import mysql
+from sqlalchemy.ext.mutable import MutableDict
+from sqlalchemy.inspection import inspect
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import relationship
+from sqlalchemy.types import PickleType
+from sqlalchemy.types import String
+from sqlalchemy.types import TypeDecorator
+
+from .. import _session_util
+from ...events.event import Event
+from ...events.event_actions import EventActions
+from ..session import Session
+from .shared import DEFAULT_MAX_KEY_LENGTH
+from .shared import DEFAULT_MAX_VARCHAR_LENGTH
+from .shared import DynamicJSON
+from .shared import PreciseTimestamp
+
+
+class DynamicPickleType(TypeDecorator):
+  """Represents a type that can be pickled."""
+
+  impl = PickleType
+
+  def load_dialect_impl(self, dialect):
+    if dialect.name == "mysql":
+      return dialect.type_descriptor(mysql.LONGBLOB)
+    if dialect.name == "spanner+spanner":
+      from google.cloud.sqlalchemy_spanner.sqlalchemy_spanner import SpannerPickleType
+
+      return dialect.type_descriptor(SpannerPickleType)
+    return self.impl
+
+  def process_bind_param(self, value, dialect):
+    """Ensures the pickled value is a bytes object before passing it to the database dialect."""
+    if value is not None:
+      if dialect.name in ("spanner+spanner", "mysql"):
+        return pickle.dumps(value)
+    return value
+
+  def process_result_value(self, value, dialect):
+    """Ensures the raw bytes from the database are unpickled back into a Python object."""
+    if value is not None:
+      if dialect.name in ("spanner+spanner", "mysql"):
+        return pickle.loads(value)
+    return value
+
+
+class Base(DeclarativeBase):
+  """Base class for v0 database tables."""
+
+  pass
+
+
+class StorageSession(Base):
+  """Represents a session stored in the database."""
+
+  __tablename__ = "sessions"
+
+  app_name: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+  user_id: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+  id: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH),
+      primary_key=True,
+      default=lambda: str(uuid.uuid4()),
+  )
+
+  state: Mapped[MutableDict[str, Any]] = mapped_column(
+      MutableDict.as_mutable(DynamicJSON), default={}
+  )
+
+  create_time: Mapped[datetime] = mapped_column(
+      PreciseTimestamp, default=func.now()
+  )
+  update_time: Mapped[datetime] = mapped_column(
+      PreciseTimestamp, default=func.now(), onupdate=func.now()
+  )
+
+  storage_events: Mapped[list[StorageEvent]] = relationship(
+      "StorageEvent",
+      back_populates="storage_session",
+  )
+
+  def __repr__(self):
+    return f"<StorageSession(id={self.id}, update_time={self.update_time})>"
+
+  @property
+  def _dialect_name(self) -> Optional[str]:
+    session = inspect(self).session
+    return session.bind.dialect.name if session else None
+
+  @property
+  def update_timestamp_tz(self) -> datetime:
+    """Returns the time zone aware update timestamp."""
+    if self._dialect_name == "sqlite":
+      # SQLite does not support timezone. SQLAlchemy returns a naive datetime
+      # object without timezone information. We need to convert it to UTC
+      # manually.
+      return self.update_time.replace(tzinfo=timezone.utc).timestamp()
+    return self.update_time.timestamp()
+
+  def to_session(
+      self,
+      state: dict[str, Any] | None = None,
+      events: list[Event] | None = None,
+  ) -> Session:
+    """Converts the storage session to a session object."""
+    if state is None:
+      state = {}
+    if events is None:
+      events = []
+
+    return Session(
+        app_name=self.app_name,
+        user_id=self.user_id,
+        id=self.id,
+        state=state,
+        events=events,
+        last_update_time=self.update_timestamp_tz,
+    )
+
+
+class StorageEvent(Base):
+  """Represents an event stored in the database."""
+
+  __tablename__ = "events"
+
+  id: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+  app_name: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+  user_id: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+  session_id: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+
+  invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
+  author: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
+  actions: Mapped[MutableDict[str, Any]] = mapped_column(DynamicPickleType)
+  long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(
+      Text, nullable=True
+  )
+  branch: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
+  )
+  timestamp: Mapped[PreciseTimestamp] = mapped_column(
+      PreciseTimestamp, default=func.now()
+  )
+
+  # === Fields from llm_response.py ===
+  content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
+  grounding_metadata: Mapped[dict[str, Any]] = mapped_column(
+      DynamicJSON, nullable=True
+  )
+  custom_metadata: Mapped[dict[str, Any]] = mapped_column(
+      DynamicJSON, nullable=True
+  )
+  usage_metadata: Mapped[dict[str, Any]] = mapped_column(
+      DynamicJSON, nullable=True
+  )
+  citation_metadata: Mapped[dict[str, Any]] = mapped_column(
+      DynamicJSON, nullable=True
+  )
+
+  partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
+  turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
+  error_code: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
+  )
+  error_message: Mapped[str] = mapped_column(Text, nullable=True)
+  interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)
+  input_transcription: Mapped[dict[str, Any]] = mapped_column(
+      DynamicJSON, nullable=True
+  )
+  output_transcription: Mapped[dict[str, Any]] = mapped_column(
+      DynamicJSON, nullable=True
+  )
+
+  storage_session: Mapped[StorageSession] = relationship(
+      "StorageSession",
+      back_populates="storage_events",
+  )
+
+  __table_args__ = (
+      ForeignKeyConstraint(
+          ["app_name", "user_id", "session_id"],
+          ["sessions.app_name", "sessions.user_id", "sessions.id"],
+          ondelete="CASCADE",
+      ),
+  )
+
+  @property
+  def long_running_tool_ids(self) -> set[str]:
+    return (
+        set(json.loads(self.long_running_tool_ids_json))
+        if self.long_running_tool_ids_json
+        else set()
+    )
+
+  @long_running_tool_ids.setter
+  def long_running_tool_ids(self, value: set[str]):
+    if value is None:
+      self.long_running_tool_ids_json = None
+    else:
+      self.long_running_tool_ids_json = json.dumps(list(value))
+
+  @classmethod
+  def from_event(cls, session: Session, event: Event) -> StorageEvent:
+    storage_event = StorageEvent(
+        id=event.id,
+        invocation_id=event.invocation_id,
+        author=event.author,
+        branch=event.branch,
+        actions=event.actions,
+        session_id=session.id,
+        app_name=session.app_name,
+        user_id=session.user_id,
+        timestamp=datetime.fromtimestamp(event.timestamp),
+        long_running_tool_ids=event.long_running_tool_ids,
+        partial=event.partial,
+        turn_complete=event.turn_complete,
+        error_code=event.error_code,
+        error_message=event.error_message,
+        interrupted=event.interrupted,
+    )
+    if event.content:
+      storage_event.content = event.content.model_dump(
+          exclude_none=True, mode="json"
+      )
+    if event.grounding_metadata:
+      storage_event.grounding_metadata = event.grounding_metadata.model_dump(
+          exclude_none=True, mode="json"
+      )
+    if event.custom_metadata:
+      storage_event.custom_metadata = event.custom_metadata
+    if event.usage_metadata:
+      storage_event.usage_metadata = event.usage_metadata.model_dump(
+          exclude_none=True, mode="json"
+      )
+    if event.citation_metadata:
+      storage_event.citation_metadata = event.citation_metadata.model_dump(
+          exclude_none=True, mode="json"
+      )
+    if event.input_transcription:
+      storage_event.input_transcription = event.input_transcription.model_dump(
+          exclude_none=True, mode="json"
+      )
+    if event.output_transcription:
+      storage_event.output_transcription = (
+          event.output_transcription.model_dump(exclude_none=True, mode="json")
+      )
+    return storage_event
+
+  def to_event(self) -> Event:
+    return Event(
+        id=self.id,
+        invocation_id=self.invocation_id,
+        author=self.author,
+        branch=self.branch,
+        # This is needed as previous ADK version pickled actions might not have
+        # value defined in the current version of the EventActions model.
+        actions=EventActions().model_copy(update=self.actions.model_dump()),
+        timestamp=self.timestamp.timestamp(),
+        long_running_tool_ids=self.long_running_tool_ids,
+        partial=self.partial,
+        turn_complete=self.turn_complete,
+        error_code=self.error_code,
+        error_message=self.error_message,
+        interrupted=self.interrupted,
+        custom_metadata=self.custom_metadata,
+        content=_session_util.decode_model(self.content, types.Content),
+        grounding_metadata=_session_util.decode_model(
+            self.grounding_metadata, types.GroundingMetadata
+        ),
+        usage_metadata=_session_util.decode_model(
+            self.usage_metadata, types.GenerateContentResponseUsageMetadata
+        ),
+        citation_metadata=_session_util.decode_model(
+            self.citation_metadata, types.CitationMetadata
+        ),
+        input_transcription=_session_util.decode_model(
+            self.input_transcription, types.Transcription
+        ),
+        output_transcription=_session_util.decode_model(
+            self.output_transcription, types.Transcription
+        ),
+    )
+
+
+class StorageAppState(Base):
+  """Represents an app state stored in the database."""
+
+  __tablename__ = "app_states"
+
+  app_name: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+  state: Mapped[MutableDict[str, Any]] = mapped_column(
+      MutableDict.as_mutable(DynamicJSON), default={}
+  )
+  update_time: Mapped[datetime] = mapped_column(
+      PreciseTimestamp, default=func.now(), onupdate=func.now()
+  )
+
+
+class StorageUserState(Base):
+  """Represents a user state stored in the database."""
+
+  __tablename__ = "user_states"
+
+  app_name: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+  user_id: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+  state: Mapped[MutableDict[str, Any]] = mapped_column(
+      MutableDict.as_mutable(DynamicJSON), default={}
+  )
+  update_time: Mapped[datetime] = mapped_column(
+      PreciseTimestamp, default=func.now(), onupdate=func.now()
+  )
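Before the v1 schema below: the v0 `StorageEvent.to_event` rebuilds actions with `EventActions().model_copy(update=...)` because a pickled object written by an older ADK release may predate fields that exist on the current model. A minimal sketch of that hazard, using hypothetical stand-in models rather than the real `EventActions`:

import pickle

from pydantic import BaseModel


class ActionsOld(BaseModel):
  """Stand-in for an older EventActions revision."""

  skip_summarization: bool = False


class ActionsNew(BaseModel):
  """Stand-in for the current revision, with a field added later."""

  skip_summarization: bool = False
  escalate: bool = False


# A v0 row stores the pickled old-revision object.
blob = pickle.dumps(ActionsOld(skip_summarization=True))
stored = pickle.loads(blob)

# Rebasing the stored payload onto a freshly constructed current model fills
# in defaults for fields the pickled object never had, instead of dropping
# them or failing validation.
restored = ActionsNew().model_copy(update=stored.model_dump())
assert restored.skip_summarization is True
assert restored.escalate is False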
+""" + +from __future__ import annotations + +from datetime import datetime +from datetime import timezone +from typing import Any +from typing import Optional +import uuid + +from sqlalchemy import ForeignKeyConstraint +from sqlalchemy import func +from sqlalchemy.ext.mutable import MutableDict +from sqlalchemy.inspection import inspect +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import relationship +from sqlalchemy.types import String + +from ...events.event import Event +from ..session import Session +from .shared import DEFAULT_MAX_KEY_LENGTH +from .shared import DEFAULT_MAX_VARCHAR_LENGTH +from .shared import DynamicJSON +from .shared import PreciseTimestamp + + +class Base(DeclarativeBase): + """Base class for v1 database tables.""" + + pass + + +class StorageMetadata(Base): + """Represents ADK internal metadata stored in the database. + + This table is used to store internal information like the schema version. + The DatabaseSessionService will populate and utilize this table to manage + database compatibility and migrations. + """ + + __tablename__ = "adk_internal_metadata" + key: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + value: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH)) + + +class StorageSession(Base): + """Represents a session stored in the database.""" + + __tablename__ = "sessions" + + app_name: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + user_id: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + id: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), + primary_key=True, + default=lambda: str(uuid.uuid4()), + ) + + state: Mapped[MutableDict[str, Any]] = mapped_column( + MutableDict.as_mutable(DynamicJSON), default={} + ) + + create_time: Mapped[datetime] = mapped_column( + PreciseTimestamp, default=func.now() + ) + update_time: Mapped[datetime] = mapped_column( + PreciseTimestamp, default=func.now(), onupdate=func.now() + ) + + storage_events: Mapped[list[StorageEvent]] = relationship( + "StorageEvent", + back_populates="storage_session", + # Deleting a session will now automatically delete its associated events + cascade="all, delete-orphan", + ) + + def __repr__(self): + return f"" + + @property + def _dialect_name(self) -> Optional[str]: + session = inspect(self).session + return session.bind.dialect.name if session else None + + @property + def update_timestamp_tz(self) -> datetime: + """Returns the time zone aware update timestamp.""" + if self._dialect_name == "sqlite": + # SQLite does not support timezone. SQLAlchemy returns a naive datetime + # object without timezone information. We need to convert it to UTC + # manually. 
+ return self.update_time.replace(tzinfo=timezone.utc).timestamp() + return self.update_time.timestamp() + + def to_session( + self, + state: dict[str, Any] | None = None, + events: list[Event] | None = None, + ) -> Session: + """Converts the storage session to a session object.""" + if state is None: + state = {} + if events is None: + events = [] + + return Session( + app_name=self.app_name, + user_id=self.user_id, + id=self.id, + state=state, + events=events, + last_update_time=self.update_timestamp_tz, + ) + + +class StorageEvent(Base): + """Represents an event stored in the database.""" + + __tablename__ = "events" + + id: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + app_name: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + user_id: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + session_id: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + + invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH)) + timestamp: Mapped[PreciseTimestamp] = mapped_column( + PreciseTimestamp, default=func.now() + ) + # The event_data uses JSON serialization to store the Event data, replacing + # various fields previously used. + event_data: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True) + + storage_session: Mapped[StorageSession] = relationship( + "StorageSession", + back_populates="storage_events", + ) + + __table_args__ = ( + ForeignKeyConstraint( + ["app_name", "user_id", "session_id"], + ["sessions.app_name", "sessions.user_id", "sessions.id"], + ondelete="CASCADE", + ), + ) + + @classmethod + def from_event(cls, session: Session, event: Event) -> StorageEvent: + """Creates a StorageEvent from an Event.""" + return StorageEvent( + id=event.id, + invocation_id=event.invocation_id, + session_id=session.id, + app_name=session.app_name, + user_id=session.user_id, + timestamp=datetime.fromtimestamp(event.timestamp), + event_data=event.model_dump(exclude_none=True, mode="json"), + ) + + def to_event(self) -> Event: + """Converts the StorageEvent to an Event.""" + return Event.model_validate({ + **self.event_data, + "id": self.id, + "invocation_id": self.invocation_id, + "timestamp": self.timestamp.timestamp(), + }) + + +class StorageAppState(Base): + """Represents an app state stored in the database.""" + + __tablename__ = "app_states" + + app_name: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + state: Mapped[MutableDict[str, Any]] = mapped_column( + MutableDict.as_mutable(DynamicJSON), default={} + ) + update_time: Mapped[datetime] = mapped_column( + PreciseTimestamp, default=func.now(), onupdate=func.now() + ) + + +class StorageUserState(Base): + """Represents a user state stored in the database.""" + + __tablename__ = "user_states" + + app_name: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + user_id: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + state: Mapped[MutableDict[str, Any]] = mapped_column( + MutableDict.as_mutable(DynamicJSON), default={} + ) + update_time: Mapped[datetime] = mapped_column( + PreciseTimestamp, default=func.now(), onupdate=func.now() + ) diff --git a/tests/unittests/sessions/test_dynamic_pickle_type.py b/tests/unittests/sessions/test_dynamic_pickle_type.py index e4eb084f88..5164d665c0 100644 --- a/tests/unittests/sessions/test_dynamic_pickle_type.py +++ 
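The v1 `from_event`/`to_event` pair reduces an event row to a composite primary key, an `invocation_id`, a timestamp, and one JSON document; everything else round-trips through `model_dump`/`model_validate`. A rough sketch of that round-trip (the real code serializes `google.adk.events.event.Event`; `Note` here is a made-up stand-in so the example stays self-contained):

from datetime import datetime
from typing import Optional

from pydantic import BaseModel


class Note(BaseModel):
  """Made-up stand-in for the Event model."""

  id: str = ""
  invocation_id: str = ""
  timestamp: float = 0.0
  author: str = ""
  text: Optional[str] = None


note = Note(
    id="e1", invocation_id="inv-1", timestamp=1700000000.5, author="user",
    text="hi",
)

# from_event: one JSON-safe dict becomes the event_data column, while id,
# invocation_id, and timestamp are also stored as their own columns.
row = {
    "id": note.id,
    "invocation_id": note.invocation_id,
    "timestamp": datetime.fromtimestamp(note.timestamp),
    "event_data": note.model_dump(exclude_none=True, mode="json"),
}

# to_event: the relational columns take precedence over whatever is stored
# inside event_data.
restored = Note.model_validate({
    **row["event_data"],
    "id": row["id"],
    "invocation_id": row["invocation_id"],
    "timestamp": row["timestamp"].timestamp(),
})
assert restored == note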
diff --git a/tests/unittests/sessions/test_dynamic_pickle_type.py b/tests/unittests/sessions/test_dynamic_pickle_type.py
index e4eb084f88..5164d665c0 100644
--- a/tests/unittests/sessions/test_dynamic_pickle_type.py
+++ b/tests/unittests/sessions/test_dynamic_pickle_type.py
@@ -17,7 +17,7 @@
 import pickle
 from unittest import mock
 
-from google.adk.sessions.database_session_service import DynamicPickleType
+from google.adk.sessions.schemas.v0 import DynamicPickleType
 import pytest
 from sqlalchemy import create_engine
 from sqlalchemy.dialects import mysql
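End to end, the public service API is untouched by this refactor; only the module that defines the tables moves. A quick smoke test of the kind the existing unit tests imply, assuming the `aiosqlite` driver is installed and that `create_session`/`get_session` keep their current keyword-only signatures:

import asyncio

from google.adk.sessions.database_session_service import DatabaseSessionService


async def main() -> None:
  # Tables (still the v0 schema, now imported from schemas/v0.py) are
  # created lazily on first use.
  service = DatabaseSessionService(db_url="sqlite+aiosqlite:///./sessions.db")
  session = await service.create_session(app_name="my_app", user_id="u1")
  fetched = await service.get_session(
      app_name="my_app", user_id="u1", session_id=session.id
  )
  assert fetched is not None and fetched.id == session.id


asyncio.run(main())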