From d47c3a328578eda672812184c60ffcebae261aa6 Mon Sep 17 00:00:00 2001 From: Daniil Yarmalkevich Date: Fri, 14 Nov 2025 18:27:00 +0200 Subject: [PATCH 1/6] code drop 2025-11-14 --- .env.template | 3 + CODE_STYLE.md | 22 +- Makefile | 3 + poetry.lock | 8 +- pyproject.toml | 2 +- ..._add_data_hashes_for_a_channel_dataset_.py | 64 +++++ src/admin_portal/app.py | 6 +- src/admin_portal/auth/oidc.py | 5 +- src/admin_portal/routers/channel.py | 12 + src/admin_portal/services/background_tasks.py | 41 ++- src/admin_portal/services/dataset.py | 261 ++++++++++++++++-- src/admin_portal/settings/exim.py | 8 + src/common/config/versions.py | 2 +- src/common/data/base/dataset.py | 18 +- src/common/data/base/datasource.py | 18 ++ src/common/data/base/dimension.py | 10 + src/common/data/quanthub/v21/dataset.py | 4 +- src/common/data/quanthub/v21/datasource.py | 24 +- src/common/data/sdmx/common/codelist.py | 34 ++- src/common/data/sdmx/common/dimension.py | 5 +- src/common/data/sdmx/v21/dataflow_loader.py | 9 +- src/common/data/sdmx/v21/dataset.py | 192 ++++++++----- src/common/data/sdmx/v21/datasource.py | 108 +++++++- src/common/hybrid_indexer/indexer.py | 2 +- src/common/models/models.py | 11 +- src/common/schemas/__init__.py | 9 +- src/common/schemas/channel_dataset.py | 33 ++- src/common/schemas/dial.py | 67 +---- src/common/utils/__init__.py | 3 + src/common/utils/elastic.py | 24 +- src/common/utils/misc.py | 46 ++- src/common/utils/timer.py | 20 +- src/statgpt/app.py | 3 +- src/statgpt/application/app_factory.py | 2 +- .../query_builder/query/finalize_query.py | 4 +- src/statgpt/chains/main.py | 10 +- src/statgpt/chains/supreme_agent.py | 8 +- .../chains/web_search/response_producer.py | 2 +- src/statgpt/schemas/__init__.py | 2 +- src/statgpt/schemas/dial_app_configuration.py | 6 +- src/statgpt/schemas/service.py | 10 + src/statgpt/services/chat_facade.py | 25 +- src/statgpt/services/onboarding.py | 40 ++- src/statgpt/utils/message_history.py | 4 +- .../utils/message_interceptors/base.py | 2 +- .../commands_interceptor.py | 2 +- .../system_msg_interceptor.py | 2 +- 47 files changed, 906 insertions(+), 290 deletions(-) create mode 100644 src/admin_portal/alembic/versions/2025_11_14_1400-d528d881ece8_add_data_hashes_for_a_channel_dataset_.py diff --git a/.env.template b/.env.template index 80f52ed..abdd7f4 100644 --- a/.env.template +++ b/.env.template @@ -1,6 +1,9 @@ # Poetry POETRY_PYTHON=python3 +# AI DIAL SDK +PYDANTIC_V2=True + # DataBase PGVECTOR_HOST=localhost PGVECTOR_PORT=5432 diff --git a/CODE_STYLE.md b/CODE_STYLE.md index 46a7013..cc69cb2 100644 --- a/CODE_STYLE.md +++ b/CODE_STYLE.md @@ -59,20 +59,22 @@ readability, and maintainability across the codebase. - **Use the correct Built-In Generics**: - It’s recommended to use built-in generic types instead of `typing.List`, `typing.Dict`, `typing.Tuple`, etc. For example: - - Use `list[str]` instead of `typing.List[str]`. - - Use `dict[str, int]` instead of `typing.Dict[str, int]`. - - Use `tuple[str, float]` instead of `typing.Tuple[str, float]`. + - Use `list[str]` instead of `typing.List[str]`. + - Use `dict[str, int]` instead of `typing.Dict[str, int]`. + - Use `tuple[str, float]` instead of `typing.Tuple[str, float]`. - Some other generics have been moved from the `typing` module to `collections.abc.` It is recommended to import them from the new module. For, example: - - Use `collections.abc.Iterable` instead of `typing.Iterable`. - - Use `collections.abc.Iterator` instead of `typing.Iterator`. 
- - Use `collections.abc.Callable` instead of `typing.Callable`. + - Use `collections.abc.Iterable` instead of `typing.Iterable`. + - Use `collections.abc.Iterator` instead of `typing.Iterator`. + - Use `collections.abc.Callable` instead of `typing.Callable`. - **Annotations for Readability**: - - Write function signatures with type hints, e.g., `def my_func(name: str) -> None: ...`. - - This improves code readability and assists with IDE-based autocompletion. + - Write function signatures with type hints, e.g., `def my_func(name: str) -> None: ...`. + - This improves code readability and assists with IDE-based autocompletion. - **Union Types**: - - In Python 3.10+, you can use the “pipe” (`|`) symbol to indicate union types. For example, `str | None` instead of - `Optional[str]`. + - In Python 3.10+, you can use the “pipe” (`|`) symbol to indicate union types. For example, `str | None` instead of `Optional[str]` +- **Factory Methods** + - For a `@classmethod` that acts as a factory method, use `typing.Self` as return type. + [Reference](https://docs.python.org/3/library/typing.html#typing.Self) ## 8. Code Organization diff --git a/Makefile b/Makefile index 5284310..290fba0 100644 --- a/Makefile +++ b/Makefile @@ -5,6 +5,9 @@ MYPY_DIRS = src/common src/admin_portal src/statgpt -include .env export +# AI DIAL SDK: pydantic v2 mode +export PYDANTIC_V2=True + remove_venv: poetry env remove --all || true $(POETRY_PYTHON) -m venv .venv diff --git a/poetry.lock b/poetry.lock index ab23c3d..80aa440 100644 --- a/poetry.lock +++ b/poetry.lock @@ -28,14 +28,14 @@ s3fs = ["s3fs (>=2024.3.1,<2025.0.0)"] [[package]] name = "aidial-sdk" -version = "0.25.1" +version = "0.27.0" description = "Framework to create applications and model adapters for AI DIAL" optional = false python-versions = "<4.0,>=3.9" groups = ["main"] files = [ - {file = "aidial_sdk-0.25.1-py3-none-any.whl", hash = "sha256:e1db54564991456765f4031cb8bce0829255e3c360748216742198805a87e1fb"}, - {file = "aidial_sdk-0.25.1.tar.gz", hash = "sha256:f7110bc42d85fa0a40dc113eca3147d0b6de10d3ac578966bbb23db6d8dddc51"}, + {file = "aidial_sdk-0.27.0-py3-none-any.whl", hash = "sha256:efa5be2c81432df2eb8989792dd3fc8fd2cb351c81f9881717a795556d1c7608"}, + {file = "aidial_sdk-0.27.0.tar.gz", hash = "sha256:fca4ea9f1085a02ee6a517d5314074973d6a17d3f9dbb2b5ecfac4f48ba2e48c"}, ] [package.dependencies] @@ -7949,4 +7949,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.1" python-versions = ">=3.11,<3.12" -content-hash = "2dc04847af2ba8fd284461409382b9f270abf4728c0d08bca4a0c860240a70b9" +content-hash = "73e2796d82a9bc9fd3d9300a52686fbe065e1b556f0125d950f68ec10a918e64" diff --git a/pyproject.toml b/pyproject.toml index 1cbaee2..f1fdc03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ license = "MIT" license-files = ["LICENSE"] dependencies = [ # Core framework & API - 'aidial-sdk[telemetry] (>=0.25.1,<0.26.0)', # DIAL integration SDK + 'aidial-sdk[telemetry] (>=0.27.0,<0.28.0)', # DIAL integration SDK 'fastapi (>=0.121.0,<0.122.0)', # Web framework 'pydantic (>=2.11.7,<3.0.0)', # Data validation 'pydantic-core (>=2.33.2,<3.0.0)', # Pydantic core functionality diff --git a/src/admin_portal/alembic/versions/2025_11_14_1400-d528d881ece8_add_data_hashes_for_a_channel_dataset_.py b/src/admin_portal/alembic/versions/2025_11_14_1400-d528d881ece8_add_data_hashes_for_a_channel_dataset_.py new file mode 100644 index 0000000..71c760a --- /dev/null +++ 
b/src/admin_portal/alembic/versions/2025_11_14_1400-d528d881ece8_add_data_hashes_for_a_channel_dataset_.py @@ -0,0 +1,64 @@ +"""Add data hashes for a channel dataset version + +Revision ID: d528d881ece8 +Revises: 10fe795dc09d +Create Date: 2025-11-14 14:00:16.313951 + +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = 'd528d881ece8' +down_revision: str | None = '10fe795dc09d' +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.add_column( + 'channel_dataset_versions', + sa.Column( + 'structure_metadata', + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + server_default=None, + ), + ) + op.add_column( + 'channel_dataset_versions', + sa.Column('structure_hash', sa.String(length=10), nullable=True, server_default=None), + ) + op.add_column( + 'channel_dataset_versions', + sa.Column( + 'indicator_dimensions_hash', sa.String(length=10), nullable=True, server_default=None + ), + ) + op.add_column( + 'channel_dataset_versions', + sa.Column( + 'non_indicator_dimensions_hash', + sa.String(length=10), + nullable=True, + server_default=None, + ), + ) + op.add_column( + 'channel_dataset_versions', + sa.Column( + 'special_dimensions_hash', sa.String(length=10), nullable=True, server_default=None + ), + ) + + +def downgrade() -> None: + op.drop_column('channel_dataset_versions', 'special_dimensions_hash') + op.drop_column('channel_dataset_versions', 'non_indicator_dimensions_hash') + op.drop_column('channel_dataset_versions', 'indicator_dimensions_hash') + op.drop_column('channel_dataset_versions', 'structure_hash') + op.drop_column('channel_dataset_versions', 'structure_metadata') diff --git a/src/admin_portal/app.py b/src/admin_portal/app.py index 3dbe366..8e4864e 100644 --- a/src/admin_portal/app.py +++ b/src/admin_portal/app.py @@ -7,7 +7,7 @@ import dotenv from aidial_sdk.telemetry.init import init_telemetry from aidial_sdk.telemetry.types import MetricsConfig, TelemetryConfig, TracingConfig -from fastapi import FastAPI +from fastapi import FastAPI, status module_path = Path(__file__).parent.parent.absolute() sys.path.append(str(module_path)) @@ -58,8 +58,8 @@ async def lifespan(app_: FastAPI): app.include_router(router) -@app.get("/health") -def health(): +@app.get("/health", status_code=status.HTTP_200_OK) +async def health(): return {"status": "ok"} diff --git a/src/admin_portal/auth/oidc.py b/src/admin_portal/auth/oidc.py index f02650a..e372aa4 100644 --- a/src/admin_portal/auth/oidc.py +++ b/src/admin_portal/auth/oidc.py @@ -1,4 +1,5 @@ import typing as t +from collections.abc import Iterable from dataclasses import dataclass import jwt @@ -84,7 +85,7 @@ def __str__(self): class AdminGroupsClaimValidator(TokenPayloadValidator): - def __init__(self, admin_group_claim: str, admin_groups_values: t.Iterable[str]): + def __init__(self, admin_group_claim: str, admin_groups_values: Iterable[str]): self.admin_group_claim = admin_group_claim self.admin_groups_values = admin_groups_values @@ -123,7 +124,7 @@ def validate(self, token_payload: dict): class TokenValidator(TokenPayloadValidator): - def __init__(self, validators: t.Iterable[TokenPayloadValidator]): + def __init__(self, validators: Iterable[TokenPayloadValidator]): self.validators = validators def validate(self, token_payload: dict): diff --git a/src/admin_portal/routers/channel.py 
b/src/admin_portal/routers/channel.py index 601feea..c0d3d61 100644 --- a/src/admin_portal/routers/channel.py +++ b/src/admin_portal/routers/channel.py @@ -409,6 +409,18 @@ async def get_channel_dataset_versions( ) +@router.get(path="/{channel_id}/datasets/{dataset_id}/versions/check-latest-up-to-date") +async def is_channel_dataset_latest_version_up_to_date( + channel_id: int, + dataset_id: int, + session: AsyncSession = Depends(models.get_session), +) -> schemas.ChangesBetweenVersionAndActualData: + """Check if the latest completed version of the specified channel dataset is up to date.""" + return await DataSetService(session).is_channel_dataset_latest_version_up_to_date( + channel_id=channel_id, dataset_id=dataset_id, auth_context=SystemUserAuthContext() + ) + + @router.post(path="/{channel_id}/datasets/{dataset_id}/versions/rollback") async def rollback_channel_dataset_to_previous_version( channel_id: int, diff --git a/src/admin_portal/services/background_tasks.py b/src/admin_portal/services/background_tasks.py index 6f6ea4b..7be67aa 100644 --- a/src/admin_portal/services/background_tasks.py +++ b/src/admin_portal/services/background_tasks.py @@ -1,5 +1,6 @@ import asyncio import functools +import logging from collections.abc import Awaitable, Callable from typing import ParamSpec, TypeVar @@ -8,8 +9,11 @@ Param = ParamSpec("Param") RetType = TypeVar("RetType") +_log = logging.getLogger(__name__) + _SETTINGS = BackgroundTasksSettings() _MAX_BACKGROUND_TASKS_SEMAPHORE = asyncio.Semaphore(_SETTINGS.max_concurrent) +_task_counter = 0 def background_task( @@ -19,7 +23,40 @@ def background_task( @functools.wraps(func) async def wrapper(*args: Param.args, **kwargs: Param.kwargs) -> RetType: - async with _MAX_BACKGROUND_TASKS_SEMAPHORE: - return await func(*args, **kwargs) + global _task_counter + _task_counter += 1 + task_id = f"{func.__name__}_{_task_counter}" + + # Log semaphore state before acquisition + available_before = _MAX_BACKGROUND_TASKS_SEMAPHORE._value + _log.debug( + f"[{task_id}] Attempting to acquire semaphore. " + f"Available slots: {available_before}/{_SETTINGS.max_concurrent}" + ) + + try: + async with _MAX_BACKGROUND_TASKS_SEMAPHORE: + available_after_acquire = _MAX_BACKGROUND_TASKS_SEMAPHORE._value + _log.info( + f"[{task_id}] Acquired semaphore. " + f"Available slots: {available_after_acquire}/{_SETTINGS.max_concurrent}" + ) + + try: + result = await func(*args, **kwargs) + _log.info(f"[{task_id}] Completed successfully") + return result + except asyncio.CancelledError: + _log.warning(f"[{task_id}] Was cancelled") + raise + except Exception as e: + _log.error(f"[{task_id}] Failed with exception: {e}") + raise + finally: + available_final = _MAX_BACKGROUND_TASKS_SEMAPHORE._value + _log.info( + f"[{task_id}] Released semaphore. 
" + f"Available slots: {available_final}/{_SETTINGS.max_concurrent}" + ) return wrapper diff --git a/src/admin_portal/services/dataset.py b/src/admin_portal/services/dataset.py index 1d94d5a..df4db7e 100644 --- a/src/admin_portal/services/dataset.py +++ b/src/admin_portal/services/dataset.py @@ -3,8 +3,8 @@ import os.path import uuid import zipfile -from collections.abc import Iterable -from typing import Any +from collections.abc import Generator, Iterable +from typing import Any, NamedTuple import yaml from fastapi import BackgroundTasks, HTTPException, status @@ -30,7 +30,7 @@ IndicatorDocumentMetadataFields, SpecialDimensionValueDocumentMetadataFields, ) -from common.utils import async_utils +from common.utils import async_utils, crc32_hash_incremental_async from common.utils.elastic import ElasticIndex, ElasticSearchFactory, SearchResult from common.vectorstore import VectorStore, VectorStoreFactory @@ -42,6 +42,12 @@ _log = logging.getLogger(__name__) +class _DataHashes(NamedTuple): + indicator_dimensions_hash: str + non_indicator_dimensions_hash: str + special_dimensions_hash: str | None + + class AdminPortalDataSetService(DataSetService): def __init__(self, session: AsyncSession) -> None: @@ -172,20 +178,11 @@ async def _export_elastic_data( ) _log.info("Finished exporting elastic data") - async def export_datasets( - self, channel: models.Channel, res_dir: str, auth_context: AuthContext - ) -> None: - channel_config = schemas.ChannelConfig.model_validate(channel.details) - - datasets = await self.get_datasets_schemas( - limit=None, - offset=0, - channel_id=channel.id, - auth_context=auth_context, - allow_offline=True, - ) + @staticmethod + def _export_datasets_config( + datasets: list[schemas.DataSet], res_dir: str + ) -> dict[int, schemas.DataSource]: data_sources = {} - data = [] for dataset in datasets: dataset_json = dataset.model_dump(mode='json', include=JobsConfig.DATASET_FIELDS) @@ -202,6 +199,34 @@ async def export_datasets( datasets_file = os.path.join(res_dir, JobsConfig.DATASETS_FILE) utils.write_yaml({'dataSets': data}, datasets_file) + return data_sources + + @staticmethod + def _export_versions( + versions: dict[uuid.UUID, schemas.ChannelDatasetVersion], res_dir: str + ) -> None: + data = { + str(ds_id): version.model_dump(mode='json', include=JobsConfig.VERSIONS_FIELDS) + for ds_id, version in versions.items() + } + datasets_file = os.path.join(res_dir, JobsConfig.VERSIONS_FILE) + utils.write_yaml({'data': data}, datasets_file) + + async def export_datasets( + self, channel: models.Channel, res_dir: str, auth_context: AuthContext + ) -> None: + channel_config = schemas.ChannelConfig.model_validate(channel.details) + + datasets = await self.get_datasets_schemas( + limit=None, + offset=0, + channel_id=channel.id, + auth_context=auth_context, + allow_offline=True, + ) + + data_sources = self._export_datasets_config(datasets, res_dir) + await DataSourceService.export_data_sources(data_sources.values(), res_dir) channel_datasets = await self.get_channel_dataset_models( @@ -210,6 +235,12 @@ async def export_datasets( latest_completed_versions = await self._get_latest_successful_dataset_version( channel_dataset_ids=[cd.id for cd in channel_datasets] ) + versions = { + next(d.id_ for d in datasets if d.id == ds_id): version.last_completed_version + for ds_id, version in latest_completed_versions.items() + if version.last_completed_version is not None + } + self._export_versions(versions, res_dir) await self._export_vector_store_data( channel, res_dir, auth_context, 
latest_completed_versions @@ -288,19 +319,33 @@ async def _add_datasets_to_channel( self._session.add_all(items) await self._session.commit() - async def _create_datasets_versions( - self, channel_id: int, preprocessing_status: StatusEnum + async def _import_datasets_versions( + self, zip_file: zipfile.ZipFile, datasets: list[schemas.DataSet], channel_id: int ) -> dict[int, models.ChannelDatasetVersion]: + with zip_file.open(JobsConfig.VERSIONS_FILE) as file: + versions_json = yaml.safe_load(file) + + datasets_dict = {ds.id: ds for ds in datasets} + channel_datasets = await self.get_channel_dataset_models( limit=None, offset=0, channel_id=channel_id ) versions = {} for ch_ds in channel_datasets: + dataset = datasets_dict[ch_ds.dataset_id] + + other = {} + if v := versions_json['data'].get(str(dataset.id_)): + other['creation_reason'] = "Imported from zip" + other.update(v) + else: + _log.warning(f"No version data found for dataset {dataset.title!r}") + other['creation_reason'] = "Imported from zip without version data" version = models.ChannelDatasetVersion( channel_dataset_id=ch_ds.id, # `version` will be set by the DB trigger automatically - preprocessing_status=preprocessing_status, - creation_reason="Imported from zip", + preprocessing_status=StatusEnum.IN_PROGRESS, + **other, ) versions[ch_ds.dataset_id] = version self._session.add_all(versions.values()) @@ -415,8 +460,8 @@ async def import_datasets_and_data_sources_from_zip( zip_file, data_sources, update_datasets, auth_context=auth_context # type: ignore ) await self._add_datasets_to_channel(channel_id=channel_db.id, datasets=datasets) - versions = await self._create_datasets_versions( - channel_id=channel_db.id, preprocessing_status=StatusEnum.IN_PROGRESS + versions = await self._import_datasets_versions( + zip_file, datasets, channel_id=channel_db.id ) await self._import_vector_store_tables( @@ -685,6 +730,24 @@ async def _update_channel_dataset_version_status( await self._session.commit() await self._session.refresh(item) + async def _set_version_hashes_and_metadata( + self, + item: models.ChannelDatasetVersion, + structure_hash: str, + structure_metadata: dict, + data_hashes: _DataHashes, + ) -> None: + """Sets the structure and data hashes for the given channel dataset version.""" + item.structure_metadata = structure_metadata + item.structure_hash = structure_hash + item.indicator_dimensions_hash = data_hashes.indicator_dimensions_hash + item.non_indicator_dimensions_hash = data_hashes.non_indicator_dimensions_hash + item.special_dimensions_hash = data_hashes.special_dimensions_hash + item.updated_at = func.now() + + await self._session.commit() + await self._session.refresh(item) + async def rollback_channel_dataset_to_previous_version( self, channel_id: int, dataset_id: int ) -> schemas.ChannelDatasetVersion: @@ -719,12 +782,98 @@ async def rollback_channel_dataset_to_previous_version( preprocessing_status=StatusEnum.COMPLETED, pointer_to=previous_version.version_data_id, creation_reason=f"Rolled back to previous version={previous_version.version}", + **{f: getattr(previous_version, f) for f in JobsConfig.VERSIONS_FIELDS}, ) self._session.add(new_item) await self._session.commit() await self._session.refresh(new_item) return schemas.ChannelDatasetVersion.model_validate(new_item, from_attributes=True) + async def is_channel_dataset_latest_version_up_to_date( + self, channel_id: int, dataset_id: int, auth_context: AuthContext + ) -> schemas.ChangesBetweenVersionAndActualData: + channel: models.Channel = await 
ChannelService(self._session).get_model_by_id(channel_id) + dataset_db: models.DataSet = await self.get_model_by_id(dataset_id) + channel_dataset = await self.get_channel_dataset_model_or_raise( + channel_id=channel.id, dataset_id=dataset_db.id + ) + + latest_completed_version = await self._get_latest_successful_dataset_version( + channel_dataset_ids=[channel_dataset.id] + ) + version = latest_completed_version[dataset_db.id].last_completed_version + + if not version: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="No completed versions found." + ) + + handler = await self._get_handler(dataset_db.source_id) + + structure_hash, meta = await handler.get_structure_hash_and_metadata( + dataset_config=dataset_db.details, auth_context=auth_context + ) + if version.structure_hash == structure_hash: + structure_change = None + else: + details = handler.get_structure_metadata_diff(version.structure_metadata, meta) + structure_change = schemas.StructureChange( + message="The dataset structure has changed.", + last_version_hash=version.structure_hash, + actual_hash=structure_hash, + details=details, + ) + + if structure_change: + data_changes = [ + schemas.DataChange( + message="The dataset structure has changed, so all data is considered changed.", + last_version_hash='N/A', + actual_hash='N/A', + ) + ] + else: + dataset = await handler.get_dataset( + entity_id=dataset_db.id_, + title=dataset_db.title, + config=dataset_db.details, + auth_context=auth_context, + allow_offline=False, + ) + data_hashes = await self._get_data_hashes( + dataset, auth_context=auth_context, allow_cached=False + ) + data_changes = self._get_data_changes(version, data_hashes) + + return schemas.ChangesBetweenVersionAndActualData( + has_changes=structure_change is not None or len(data_changes) != 0, + data_changes=data_changes, + structure_change=structure_change, + ) + + @staticmethod + def _get_data_changes( + version: schemas.ChannelDatasetVersion, data_hashes: _DataHashes + ) -> list[schemas.DataChange]: + iterable = [ + ('Indicator', version.indicator_dimensions_hash, data_hashes.indicator_dimensions_hash), + ('Special', version.special_dimensions_hash, data_hashes.special_dimensions_hash), + ( + 'Non-indicator', + version.non_indicator_dimensions_hash, + data_hashes.non_indicator_dimensions_hash, + ), + ] + return [ + schemas.DataChange( + message=f"Available data for the {name} dimensions has changed.", + last_version_hash=old_hash, + actual_hash=new_hash, + ) + for name, old_hash, new_hash in iterable + if old_hash != new_hash + ] + async def clear_channel_dataset_versions_data( self, channel_id: int, dataset_id: int, auth_context: AuthContext ): @@ -900,6 +1049,63 @@ async def reload_indicators( def _is_harmonization_supported(cls, channel: models.Channel) -> bool: return ChannelService.is_channel_hybrid(channel) + @staticmethod + async def _get_indicators_hash( + dataset: base.DataSet, auth_context: AuthContext, allow_cached: bool + ) -> str: + indicators = await dataset.get_indicators( + auth_context=auth_context, allow_cached=allow_cached + ) + indicators_values = sorted(f"{i.query_id} {i.name}" for i in indicators) + hash_value = await crc32_hash_incremental_async(indicators_values) + return str(hash_value) + + @staticmethod + async def _get_non_indicators_hash(dataset: base.DataSet) -> str: + dimensions: Generator[base.CategoricalDimension] = ( + dim + for dim in dataset.non_indicator_dimensions() + if isinstance(dim, base.CategoricalDimension) + ) + dimensions_values = sorted( + 
f"{category_value.query_id} {category_value.name}" + for dim in dimensions + for category_value in dim.available_values + ) + hash_value = await crc32_hash_incremental_async(dimensions_values) + return str(hash_value) + + @staticmethod + async def _get_special_dimensions_hash(dataset: base.DataSet) -> str | None: + if not dataset.special_dimensions(): + return None + dimensions: Generator[base.CategoricalDimension] = ( + dim + for dim in dataset.special_dimensions().values() + if isinstance(dim, base.CategoricalDimension) + ) + dimensions_values = sorted( + f"{category_value.query_id} {category_value.name}" + for dim in dimensions + for category_value in dim.available_values + ) + hash_value = await crc32_hash_incremental_async(dimensions_values) + return str(hash_value) + + async def _get_data_hashes( + self, dataset: base.DataSet, auth_context: AuthContext, allow_cached: bool + ) -> _DataHashes: + indicator_dimensions_hash = await self._get_indicators_hash( + dataset, auth_context=auth_context, allow_cached=allow_cached + ) + non_indicator_dimensions_hash = await self._get_non_indicators_hash(dataset) + special_dimensions_hash = await self._get_special_dimensions_hash(dataset) + return _DataHashes( + indicator_dimensions_hash=indicator_dimensions_hash, + non_indicator_dimensions_hash=non_indicator_dimensions_hash, + special_dimensions_hash=special_dimensions_hash, + ) + @staticmethod async def _run_semantic_indexer( dataset: base.DataSet, @@ -909,7 +1115,7 @@ async def _run_semantic_indexer( max_n_embeddings: int | None, auth_context: AuthContext, ): - indicators = await dataset.get_indicators(auth_context=auth_context) + indicators = await dataset.get_indicators(auth_context=auth_context, allow_cached=True) _log.info(f"Loaded {len(indicators)} indicators.") if max_n_embeddings: indicators = indicators[:max_n_embeddings] # for debug @@ -1139,6 +1345,15 @@ async def reload_channel_dataset_in_background( allow_offline=False, # Unable to reindex offline dataset ) + if reindex_dimensions or (reindex_indicators and not harmonize_indicator): + structure_hash, meta = await handler.get_structure_hash_and_metadata( + dataset_config=db_dataset.details, auth_context=auth_context + ) + data_hashes = await self._get_data_hashes(dataset, auth_context, allow_cached=True) + await self._set_version_hashes_and_metadata( + version, structure_hash, meta, data_hashes + ) + vector_store_factory = VectorStoreFactory(session=self._session) if reindex_dimensions: diff --git a/src/admin_portal/settings/exim.py b/src/admin_portal/settings/exim.py index dec96af..1cb00ef 100644 --- a/src/admin_portal/settings/exim.py +++ b/src/admin_portal/settings/exim.py @@ -35,6 +35,7 @@ class JobsConfig: CHANNEL_FILE = 'channel.yaml' DATASETS_FILE = 'datasets.yaml' DATA_SOURCES_FILE = 'data_sources.yaml' + VERSIONS_FILE = 'versions.yaml' GLOSSARY_TERMS_FILE = 'glossary_terms.csv' DIAL_FILES_FOLDER = 'dial_files' @@ -46,3 +47,10 @@ class JobsConfig: CHANNEL_FIELDS = {"deployment_id", "title", "description", "llm_model", "details"} DATASET_FIELDS = {"id_", "title", "details"} # and dynamic 'dataSource' field DATA_SOURCE_FIELDS = {"title", "description", "type_id", "details"} + VERSIONS_FIELDS = { + 'structure_metadata', + 'structure_hash', + 'indicator_dimensions_hash', + 'non_indicator_dimensions_hash', + 'special_dimensions_hash', + } diff --git a/src/common/config/versions.py b/src/common/config/versions.py index 5b2a33d..e612e44 100644 --- a/src/common/config/versions.py +++ b/src/common/config/versions.py @@ -7,4 +7,4 @@ 
class Versions: # Please update this version when you create a new alembic revision. # Needed because alembic folder exist only in the admin_portal package. # (statgpt Dockerfile doesn't copy admin_portal package to the container) - ALEMBIC_TARGET_VERSION = '10fe795dc09d' + ALEMBIC_TARGET_VERSION = 'd528d881ece8' diff --git a/src/common/data/base/dataset.py b/src/common/data/base/dataset.py index d823c36..10fbdfc 100644 --- a/src/common/data/base/dataset.py +++ b/src/common/data/base/dataset.py @@ -3,11 +3,12 @@ import typing as t import uuid from abc import ABC, abstractmethod +from collections.abc import Sequence from datetime import datetime import pandas as pd import plotly.graph_objects as go -from pydantic import BaseModel, ConfigDict, Field, StrictStr, alias_generators +from pydantic import BaseModel, ConfigDict, Field, StrictStr, alias_generators, model_validator from common.auth.auth_context import AuthContext from common.config.utils import replace_env @@ -90,6 +91,13 @@ class DataSetConfig(BaseModel, ABC): description="Column names and order to pin in the data in grid", default_factory=list ) + @model_validator(mode='after') + def _no_dups_in_special_dims_processor(self): + processor_ids = [sd.processor_id for sd in self.special_dimensions] + if len(processor_ids) != len(set(processor_ids)): + raise ValueError("Duplicate processor_id found in special_dimensions") + return self + @abstractmethod def get_source_id(self) -> str: pass @@ -328,7 +336,9 @@ def indicator_dimensions_required_for_query(self) -> list[str]: pass @abstractmethod - async def get_indicators(self, auth_context: AuthContext) -> t.Sequence[BaseIndicator]: + async def get_indicators( + self, auth_context: AuthContext, allow_cached: bool + ) -> Sequence[BaseIndicator]: pass @abstractmethod @@ -404,7 +414,9 @@ def virtual_indicator_dimensions(self) -> list[Dimension]: def indicator_dimensions_required_for_query(self) -> list[str]: return [] - async def get_indicators(self, auth_context: AuthContext) -> t.Sequence[BaseIndicator]: + async def get_indicators( + self, auth_context: AuthContext, allow_cached: bool + ) -> Sequence[BaseIndicator]: return [] async def availability_query( diff --git a/src/common/data/base/datasource.py b/src/common/data/base/datasource.py index 46c6615..d34361f 100644 --- a/src/common/data/base/datasource.py +++ b/src/common/data/base/datasource.py @@ -98,6 +98,24 @@ async def get_dataset( async def close(self): pass + @abstractmethod + async def get_structure_hash_and_metadata( + self, dataset_config: dict, auth_context: AuthContext + ) -> tuple[str, dict]: + """Get a hash calculated based on the part of the dataset structure that is important for indexing. + If this hash changes, the dataset should be re-indexed. + + Additionally, return metadata about the structure that can be used to understand what has changed. + The metadata is a JSON-serializable dictionary. + """ + + @abstractmethod + def get_structure_metadata_diff(self, old_metadata: dict | None, new_metadata: dict) -> dict: + """Get the difference between the old and new structure metadata of a dataset. + + Return a JSON-serializable dictionary describing what has changed. 
+ """ + @abstractmethod async def get_indicator_from_document(self, documents: Document) -> BaseIndicator: pass diff --git a/src/common/data/base/dimension.py b/src/common/data/base/dimension.py index 183c202..9a1946e 100644 --- a/src/common/data/base/dimension.py +++ b/src/common/data/base/dimension.py @@ -1,6 +1,7 @@ import json import typing as t from abc import ABC, abstractmethod +from enum import StrEnum from pydantic import BaseModel, Field @@ -182,3 +183,12 @@ def available_operators(self) -> t.List[QueryOperator]: QueryOperator.LESS_THAN_OR_EQUALS, QueryOperator.BETWEEN, ] + + +class DimensionProcessingType(StrEnum): + """Dimension classification for processing purposes""" + + INDICATOR = 'INDICATOR' + NONINDICATOR = 'NONINDICATOR' + SPECIAL = 'SPECIAL' + TIME_PERIOD = 'TIME_PERIOD' diff --git a/src/common/data/quanthub/v21/dataset.py b/src/common/data/quanthub/v21/dataset.py index a877f0c..88be666 100644 --- a/src/common/data/quanthub/v21/dataset.py +++ b/src/common/data/quanthub/v21/dataset.py @@ -198,8 +198,8 @@ async def query( url = self._get_query_url(data_msg.response) # type: ignore try: - sdmx_pandas = self._data_msg_to_dataframe(data_msg) - sdmx_pandas = self._include_attributes(sdmx_pandas) + sdmx_pandas = await self._data_msg_to_dataframe(data_msg) + sdmx_pandas = await self._include_attributes(sdmx_pandas) except Exception as e: _log.exception(e) return Sdmx21DataResponse( diff --git a/src/common/data/quanthub/v21/datasource.py b/src/common/data/quanthub/v21/datasource.py index 2891e24..92e6e50 100644 --- a/src/common/data/quanthub/v21/datasource.py +++ b/src/common/data/quanthub/v21/datasource.py @@ -17,6 +17,7 @@ from common.data.sdmx.v21.schemas import Urn from common.settings.sdmx import quanthub_settings from common.utils import Cache +from common.utils.timer import debug_timer from .qh_sdmx_client import AsyncQuanthubClient @@ -67,7 +68,7 @@ async def is_dataset_available(self, config: dict, auth_context: AuthContext) -> else: raise e - async def get_dataset( + async def _get_dataset( self, entity_id: uuid.UUID, title: str, @@ -110,7 +111,7 @@ async def get_dataset( try: dataflow_loader = DataflowLoader(sdmx_client) - structure_message = await dataflow_loader.load_structure_message(urn) + structure_message = await dataflow_loader.load_structure_message(urn, mode="full") except Exception as e: if allow_offline: msg = f"Failed to load the dataflow or its associated structures. 
{urn=}" @@ -210,6 +211,25 @@ async def get_dataset( return res + async def get_dataset( + self, + entity_id: uuid.UUID, + title: str, + config: dict, + auth_context: AuthContext, + allow_offline: bool = False, + allow_cached: bool = False, + ) -> QuanthubSdmx21DataSet | SdmxOfflineDataSet: + with debug_timer(f"QuanthubSdmx21DataSourceHandler.get_dataset: {title}"): + return await self._get_dataset( + entity_id, + title, + config, + auth_context, + allow_offline=allow_offline, + allow_cached=allow_cached, + ) + @staticmethod def data_source_type() -> DataSourceType: return DataSourceType( diff --git a/src/common/data/sdmx/common/codelist.py b/src/common/data/sdmx/common/codelist.py index 293f7f6..90bd9af 100644 --- a/src/common/data/sdmx/common/codelist.py +++ b/src/common/data/sdmx/common/codelist.py @@ -32,7 +32,7 @@ def __contains__(self, item: str) -> bool: pass @abstractmethod - def __getitem__(self, item: str) -> CodeCategory | None: + def __getitem__(self, item: str) -> CodeCategory: pass @@ -40,24 +40,36 @@ class InMemoryCodeList(BaseSdmxCodeList): _code_list: common.Codelist _codes: t.Dict[str, CodeCategory] - def __init__( - self, - code_list: common.Codelist, - locale: str, - ): + def __init__(self, code_list: common.Codelist, locale: str): super().__init__(code_list, locale) self._code_list = code_list - self._codes = {code.id: CodeCategory(code, locale) for code in code_list.items.values()} + self._codes = {} + + def _get_item_and_cache(self, item: str) -> CodeCategory | None: + if item not in self._codes: + code = self._code_list[item] + if code is None: + return None + self._codes[item] = CodeCategory(code, self._locale) + return self._codes[item] + + def _get_item_and_cache_or_raise(self, item: str) -> CodeCategory: + code = self._get_item_and_cache(item) + if code is None: + raise KeyError(f"Code '{item}' not found in codelist '{self.code_list.id}'") + return code @property def code_list(self) -> common.Codelist: return self._code_list def codes(self) -> t.Sequence[CodeCategory]: - return list(self._codes.values()) + if len(self._codes) == len(self._code_list.items): + return list(self._codes.values()) + return [self._get_item_and_cache_or_raise(code) for code in self._code_list.items.values()] - def __getitem__(self, item: str) -> CodeCategory | None: - return self._codes.get(item) + def __getitem__(self, item: str) -> CodeCategory: + return self._get_item_and_cache_or_raise(item) def __contains__(self, item: str) -> bool: - return item in self._codes + return item in self._code_list diff --git a/src/common/data/sdmx/common/dimension.py b/src/common/data/sdmx/common/dimension.py index 8c56a0f..4c54674 100644 --- a/src/common/data/sdmx/common/dimension.py +++ b/src/common/data/sdmx/common/dimension.py @@ -1,5 +1,6 @@ import typing as t from abc import ABC +from collections.abc import Iterable from sdmx.model import common @@ -62,7 +63,7 @@ def __init__( name: str, description: t.Optional[str], code_list: BaseSdmxCodeList, - available_codes: t.Iterable[str] | None = None, + available_codes: Iterable[str] | None = None, alias: str | None = None, ): SdmxDimension.__init__(self, dimension, name, description, alias=alias) @@ -73,8 +74,6 @@ def __init__( def available_values(self) -> t.Sequence[DimensionCodeCategory]: if not self._available_codes: return self.values - # TODO: could probably use a simpler version - # return [item for item in self.values if item.query_id in self._available_codes] return [ DimensionCodeCategory.from_code_category(item, self.entity_id, self._name, 
self._alias) for item in self.values diff --git a/src/common/data/sdmx/v21/dataflow_loader.py b/src/common/data/sdmx/v21/dataflow_loader.py index 41f702f..f39210c 100644 --- a/src/common/data/sdmx/v21/dataflow_loader.py +++ b/src/common/data/sdmx/v21/dataflow_loader.py @@ -1,3 +1,5 @@ +from typing import Literal + from sdmx.message import StructureMessage from sdmx.model.v21 import DataStructureDefinition @@ -15,7 +17,9 @@ class DataflowLoader: def __init__(self, client: AsyncSdmxClient): self._client: AsyncSdmxClient = client - async def load_structure_message(self, urn: Urn) -> StructureMessage21: + async def load_structure_message( + self, urn: Urn, mode: Literal['full', 'shallow'] + ) -> StructureMessage21: dataflow_msg = await self._load_dataflow(urn) result_message = StructureMessage21.from_sdmx1(dataflow_msg) @@ -23,6 +27,9 @@ async def load_structure_message(self, urn: Urn) -> StructureMessage21: for scheme_msg in schemes: result_message.add_concept_schemes(scheme_msg.concept_scheme.values()) + if mode == 'shallow': + return result_message + code_lists = await self._load_code_lists(result_message, urn) for code_list_msg in code_lists: result_message.add_codelists(code_list_msg.codelist.values()) diff --git a/src/common/data/sdmx/v21/dataset.py b/src/common/data/sdmx/v21/dataset.py index a52a932..71241a8 100644 --- a/src/common/data/sdmx/v21/dataset.py +++ b/src/common/data/sdmx/v21/dataset.py @@ -1,3 +1,4 @@ +import asyncio import collections import json import logging @@ -6,7 +7,7 @@ import time import typing as t import uuid -from collections.abc import Iterable +from collections.abc import Iterable, Sequence from datetime import datetime from functools import cached_property @@ -241,8 +242,8 @@ def _enrich_df_with_names(self, df: pd.DataFrame) -> pd.DataFrame: df = df.reset_index() sorted_columns = [] for column in df.columns: - id2name_mapping = self.dataset.map_dim_values_id_2_name( - value_ids=df[column].to_list(), dimension_name=column + id2name_mapping = self.dataset.map_component_values_id_2_name( + value_ids=df[column].to_list(), component_id=column ) if id2name_mapping is None: continue @@ -345,16 +346,14 @@ def _to_component_filters(cls, sdmx_query: SdmxDataSetQuery) -> list[JsonCompone class Sdmx21DataSet( DataSet[SdmxDataSetConfig, 'Sdmx21DataSourceHandler'], BaseNameableArtefact[DataFlow] ): - _dimensions: t.Dict[str, SdmxDimension | VirtualDimension] - _attributes: t.Dict[str, Sdmx21Attribute] - _virtual_dimensions: t.Dict[str, VirtualDimension] - _indicator_dimensions: t.Dict[str, SdmxCodeListDimension | VirtualDimension] + _dimensions: dict[str, SdmxDimension | VirtualDimension] + _attributes: dict[str, Sdmx21Attribute] + _virtual_dimensions: dict[str, VirtualDimension] + _indicator_dimensions: dict[str, SdmxCodeListDimension | VirtualDimension] _indicator_dimensions_required_for_query: list[str] _country_dimension: SdmxCodeListDimension | VirtualDimension | None _fixed_indicator: FixedItem | None - # dimension_id -> {code_id -> code_name} _dim_values_id_2_name: dict[str, dict[str, str]] | None - _attrib_values_id_2_name: dict[str, dict[str, str]] | None def __init__( self, @@ -375,10 +374,8 @@ def __init__( self._indicator_dimensions_required_for_query = [] self._fixed_indicator = config.fixed_indicator self._virtual_dimensions = {} - self._dim_values_id_2_name = None - self._attributes = {attribute.entity_id: attribute for attribute in attributes} - self._attrib_values_id_2_name = None + self._dim_values_id_2_name = None # virtual dimensions for 
virtual_dimension_config in config.virtual_dimensions: @@ -472,7 +469,7 @@ def dataset_url(self) -> str | None: return self.config.citation.get_url() return None - def _indicators_from_fixed_indicator(self) -> t.Sequence[ComplexIndicator]: + def _indicators_from_fixed_indicator(self) -> list[ComplexIndicator]: if self._fixed_indicator is None: raise ValueError("fixed_indicator is None") @@ -573,6 +570,25 @@ async def _get_available_series( return series, queries_count + def _save_indicator_combinations( + self, + file_path: str, + series: list[dict[str, str]], + order: list[str], + virtual_indicator_dimensions: Sequence[VirtualDimension], + ) -> int: + series_df = pd.DataFrame(series) + series_df.sort_values(order, inplace=True) + + for dim in virtual_indicator_dimensions: + series_df[dim.entity_id] = dim.value.entity_id + + series_df.to_csv(file_path, index=False) + _log.info(f"{self.source_id}. Saved indicator combinations to '{file_path}'") + duplicates = series_df.duplicated().sum() + + return duplicates + async def _load_indicator_combinations_to( self, file_path: str, auth_context: AuthContext ) -> None: @@ -591,7 +607,7 @@ async def _load_indicator_combinations_to( } dim_2_avail_values_cnt_sorted = sorted( - {k: len(v) for k, v in avail_values.items()}.items(), key=lambda x: x[1] + ((k, len(v)) for k, v in avail_values.items()), key=lambda x: x[1] ) _log.info(f"{self.source_id} {dim_2_avail_values_cnt_sorted=}") order = [x[0] for x in dim_2_avail_values_cnt_sorted] @@ -609,43 +625,49 @@ async def _load_indicator_combinations_to( f"{self.source_id}. Number of series extracted: {len(series)}. Number of queries sent: {queries_count}" ) - elsapsed_time = time.time() - time_start - _log.info(f'{self.source_id}. elapsed time: {elsapsed_time :.3f} sec') + elapsed_time = time.time() - time_start + _log.info(f'{self.source_id}. elapsed time: {elapsed_time :.3f} sec') virtual_indicator_dimensions = self.virtual_indicator_dimensions() - series_df = pd.DataFrame(series) - series_df.sort_values(order, inplace=True) - - for dim in virtual_indicator_dimensions: - series_df[dim.entity_id] = dim.value.entity_id - - series_df.to_csv(file_path, index=False) - _log.info(f"{self.source_id}. Saved indicator combinations to '{file_path}'") - - duplicates = series_df.duplicated().sum() + duplicates = await asyncio.to_thread( + self._save_indicator_combinations, + file_path, + series, + order, + virtual_indicator_dimensions, + ) if duplicates: raise ValueError( f"{self.source_id}. Found {duplicates} duplicates in the indicator combinations" ) _log.info(f"{self.source_id}. No duplicates found in the indicator combinations") - async def _get_or_load_indicator_combinations(self, auth_context: AuthContext) -> pd.DataFrame: + def _read_indicator_combinations_from_file(self, file_path: str) -> pd.DataFrame: + _log.debug(f"{self.source_id}. Reading indicator combinations from '{file_path}'") + df = pd.read_csv(file_path, dtype=str, keep_default_na=False) + _log.debug(f"{self.source_id}. 
Read {len(df)} indicator combinations from '{file_path}'") + return df + + async def _get_or_load_indicator_combinations( + self, auth_context: AuthContext, allow_cached: bool + ) -> pd.DataFrame: file_name = escape_invalid_filename_chars(f"{self.source_id}.csv") if cache_dir := sdmx_settings.cache_dir: dir_name = os.path.join(cache_dir, sdmx_settings.indicator_combinations_subdir) file_path = os.path.join(str(dir_name), file_name) - if not os.path.exists(file_path): - os.makedirs(dir_name, exist_ok=True) + if not allow_cached or not os.path.exists(file_path): + status = 'not found' if allow_cached else 'not allowed' + _log.info(f"{self.source_id}. Indicator combinations cache {status}.") # Create cache of available indicator combinations: - _log.info(f"{self.source_id}. Indicator combinations cache not found.") + os.makedirs(dir_name, exist_ok=True) await self._load_indicator_combinations_to(file_path, auth_context=auth_context) else: _log.info(f"{self.source_id}. Getting indicator combinations from cache.") - return pd.read_csv(file_path, dtype=str, keep_default_na=False) + return await asyncio.to_thread(self._read_indicator_combinations_from_file, file_path) else: _log.info( f"{self.source_id}. Indicator combinations cache disabled. Loading to temp dir..." @@ -653,20 +675,21 @@ async def _get_or_load_indicator_combinations(self, auth_context: AuthContext) - with tempfile.TemporaryDirectory() as tmp_dir: file_path = os.path.join(tmp_dir, file_name) await self._load_indicator_combinations_to(file_path, auth_context) - return pd.read_csv(file_path, dtype=str, keep_default_na=False) + return await asyncio.to_thread( + self._read_indicator_combinations_from_file, file_path + ) async def _indicators_from_dimensions( - self, auth_context: AuthContext - ) -> t.Sequence[ComplexIndicator]: + self, auth_context: AuthContext, allow_cached: bool + ) -> list[ComplexIndicator]: if not self._indicator_dimensions: raise ValueError("No indicator dimensions") - df_avail_dim_combinations = await self._get_or_load_indicator_combinations(auth_context) + df_avail_dim_combinations = await self._get_or_load_indicator_combinations( + auth_context=auth_context, allow_cached=allow_cached + ) - # use the same columns order as in the list of indicator dimensions from dataset config. - # NOTE: here we rely on the order of items in the dict, which is generally a bad practice. - # however, it seems to work here, probably becase we don't modify the dict - # after the pydantic model is created. and it seems to preser the insert order of keys. + # Use the same columns order as in the list of indicator dimensions from dataset config. 
df_avail_dim_combinations = df_avail_dim_combinations[self._indicator_dimensions.keys()] dim_cat_id_2_model: dict[str, dict[str, CodeCategory | VirtualDimensionCategory]] = ( @@ -679,15 +702,16 @@ async def _indicators_from_dimensions( elif isinstance(dimension, VirtualDimension): dim_cat_id_2_model[dimension.entity_id][dimension.value.entity_id] = dimension.value - df_indicators = df_avail_dim_combinations.apply( # type: ignore - lambda row: ComplexIndicator( + def _create_indicator(row): + return ComplexIndicator( CodeIndicator(dim_cat_id_2_model[dim_id][row[dim_id]]) for dim_id in df_avail_dim_combinations.columns - ), - axis=1, + ) + + df_indicators = await asyncio.to_thread( + df_avail_dim_combinations.apply, _create_indicator, axis=1 # type: ignore ) indicators = df_indicators.to_list() - return indicators @staticmethod @@ -883,18 +907,22 @@ def virtual_indicator_dimensions(self) -> t.Sequence[VirtualDimension]: def indicator_dimensions_required_for_query(self) -> list[str]: return self._indicator_dimensions_required_for_query - async def get_indicators(self, auth_context: AuthContext) -> t.Sequence[ComplexIndicator]: + async def get_indicators( + self, auth_context: AuthContext, allow_cached: bool + ) -> Sequence[ComplexIndicator]: if self._fixed_indicator: return self._indicators_from_fixed_indicator() elif self._indicator_dimensions: - return await self._indicators_from_dimensions(auth_context=auth_context) + return await self._indicators_from_dimensions( + auth_context=auth_context, allow_cached=allow_cached + ) else: raise ValueError("No indicators") def country_dimension(self) -> CategoricalDimension | None: return self._country_dimension - def get_dim_values_id_2_name_mapping(self) -> dict[str, dict[str, str]]: + def _get_dim_values_id_2_name_mapping(self) -> dict[str, dict[str, str]]: if self._dim_values_id_2_name is not None: return self._dim_values_id_2_name @@ -908,32 +936,42 @@ def get_dim_values_id_2_name_mapping(self) -> dict[str, dict[str, str]]: return self._dim_values_id_2_name - def get_attrib_values_id_2_name_mapping(self) -> dict[str, dict[str, str]]: - if self._attrib_values_id_2_name is not None: - return self._attrib_values_id_2_name - - self._attrib_values_id_2_name = {} - for attrib in self.attributes(): - if not isinstance(attrib, Sdmx21CodeListAttribute): - continue - self._attrib_values_id_2_name[attrib.entity_id] = { - code.query_id: code.name for code in attrib.code_list.codes() - } - - return self._attrib_values_id_2_name - - def map_dim_values_id_2_name( - self, value_ids: t.Iterable[str], dimension_name: str + def map_component_values_id_2_name( + self, value_ids: Iterable[str], component_id: str ) -> dict[str, str] | None: """Map dimension or attribute ids to their corresponding names.""" - id2name = ( - self.get_dim_values_id_2_name_mapping() | self.get_attrib_values_id_2_name_mapping() - ) - cur_dim_id2name = id2name.get(dimension_name) - if cur_dim_id2name is None: + component: SdmxCodeListDimension | Sdmx21CodeListAttribute + if component_id in self._dimensions: + dimension = self._dimensions[component_id] + if not isinstance(dimension, SdmxCodeListDimension): + _log.debug( + "Dimension %s of dataset %s is not a code list dimension", + component_id, + self.short_urn, + ) + return None + component = dimension + elif component_id in self._attributes: + attribute = self._attributes[component_id] + if not isinstance(attribute, Sdmx21CodeListAttribute): + _log.debug( + "Attribute %s of dataset %s is not a code list attribute", + component_id, + 
self.short_urn, + ) + return None + component = attribute + else: + _log.debug("Component %s not found in dataset %s", component_id, self.short_urn) return None - res = {_id: cur_dim_id2name.get(_id, '') for _id in value_ids} + + code_list = component.code_list + res = { + value_id: code_list[value_id].name + for value_id in value_ids + if isinstance(value_id, str) + } return res def map_dim_queries_2_names(self, queries: dict[str, list[str]]): @@ -942,7 +980,7 @@ def map_dim_queries_2_names(self, queries: dict[str, list[str]]): """ res = {} for dim_id, value_ids in queries.items(): - id2name = self.map_dim_values_id_2_name(value_ids, dim_id) + id2name = self.map_component_values_id_2_name(value_ids, dim_id) if id2name is None: raise ValueError(f'Unexpected dimension id: "{dim_id}"') res[dim_id] = id2name @@ -1052,7 +1090,7 @@ async def availability_query( result = self._availability_result_to_query(availability_result) return result - def _data_msg_to_dataframe(self, data_msg: DataMessage) -> pd.DataFrame: + def _data_msg_to_dataframe_sync(self, data_msg: DataMessage) -> pd.DataFrame: """Convert SDMX data message to Pandas DataFrame.""" kwargs = {} @@ -1067,7 +1105,10 @@ def _data_msg_to_dataframe(self, data_msg: DataMessage) -> pd.DataFrame: return sdmx_pandas - def _include_attributes(self, df: pd.DataFrame) -> pd.DataFrame: + async def _data_msg_to_dataframe(self, data_msg: DataMessage) -> pd.DataFrame: + return await asyncio.to_thread(self._data_msg_to_dataframe_sync, data_msg) + + def _include_attributes_sync(self, df: pd.DataFrame) -> pd.DataFrame: if df.empty or not self.config.include_attributes: return df @@ -1094,6 +1135,9 @@ def _extract_attribute_value(x: t.Any) -> t.Any: res_df = res_df.set_index(keys=attributes, append=True) return res_df + async def _include_attributes(self, df: pd.DataFrame) -> pd.DataFrame: + return await asyncio.to_thread(self._include_attributes_sync, df) + async def _query_sdmx_data( self, sdmx_query: SdmxDataSetQuery, auth_context: AuthContext ) -> DataMessage: @@ -1149,8 +1193,8 @@ async def query( url = self._get_query_url(data_msg.response) # type: ignore try: - sdmx_pandas = self._data_msg_to_dataframe(data_msg) - sdmx_pandas = self._include_attributes(sdmx_pandas) + sdmx_pandas = await self._data_msg_to_dataframe(data_msg) + sdmx_pandas = await self._include_attributes(sdmx_pandas) except Exception as e: _log.exception(e) return Sdmx21DataResponse( diff --git a/src/common/data/sdmx/v21/datasource.py b/src/common/data/sdmx/v21/datasource.py index f8b81e6..9a2e285 100644 --- a/src/common/data/sdmx/v21/datasource.py +++ b/src/common/data/sdmx/v21/datasource.py @@ -1,6 +1,7 @@ import typing as t import uuid from abc import ABC +from operator import itemgetter from langchain_core.documents import Document from sdmx.message import StructureMessage @@ -28,10 +29,12 @@ from common.data.sdmx.v21.dataset import Sdmx21DataSet, SdmxOfflineDataSet from common.data.sdmx.v21.dimensions_creator import DimensionsCreator from common.data.sdmx.v21.sdmx_client import AsyncSdmxClient +from common.utils import crc32_hash +from common.utils.timer import debug_timer from .dataset_hierarchy import CategorySchemaDataSetHierarchyCreator from .ratelimiter import SdmxRateLimiterFactory -from .schemas import Urn +from .schemas import StructureMessage21, Urn class Sdmx21DataSourceHandler( @@ -107,14 +110,13 @@ async def list_datasets(self, auth_context: AuthContext) -> list[DataSetDescript def entity_id(self) -> str: return self._config.get_id() - async def get_dataset( + 
async def _get_dataset( self, entity_id: uuid.UUID, title: str, config: dict, auth_context: AuthContext, allow_offline: bool = False, - allow_cached: bool = False, ) -> Sdmx21DataSet | SdmxOfflineDataSet: dataset_config = self.parse_data_set_config(config) @@ -140,7 +142,7 @@ async def get_dataset( try: dataflow_loader = DataflowLoader(sdmx_client) - structure_message = await dataflow_loader.load_structure_message(urn) + structure_message = await dataflow_loader.load_structure_message(urn, mode="full") except Exception as e: if allow_offline: msg = f"Failed to load the dataflow or its associated structures. {urn=}" @@ -192,6 +194,104 @@ async def get_dataset( else: raise e + async def get_dataset( + self, + entity_id: uuid.UUID, + title: str, + config: dict, + auth_context: AuthContext, + allow_offline: bool = False, + allow_cached: bool = False, + ) -> Sdmx21DataSet | SdmxOfflineDataSet: + with debug_timer(f"Sdmx21DataSourceHandler.get_dataset: {title}"): + return await self._get_dataset( + entity_id, + title, + config, + auth_context, + allow_offline=allow_offline, + ) + + async def get_structure_hash_and_metadata( + self, dataset_config: dict, auth_context: AuthContext + ) -> tuple[str, dict]: + config = self.parse_data_set_config(dataset_config) + _urn = self._urn_parser.parse(config.urn) + urn = Urn( + agency_id=_urn.agency_id, + resource_id=_urn.resource_id, + version=_urn.version if _urn.version else "latest", + ) + + sdmx_client = await self.create_sdmx_client(auth_context) + + dataflow_loader = DataflowLoader(sdmx_client) + structure_message = await dataflow_loader.load_structure_message(urn, mode="shallow") + + meta_json = { + "dimensions": self._get_dimensions_from(structure_message, urn), + } + return str(crc32_hash(str(meta_json))), meta_json + + def _get_dimensions_from( + self, structure_message: StructureMessage21, dataflow_urn: Urn + ) -> list[dict[str, str]]: + locale = self.config.locale + dsd = structure_message.dataflow[dataflow_urn].structure + + res: list[dict[str, str]] = [] + for dimension in dsd.dimensions.components: + if dimension.concept_identity is None: + continue + scheme_urn = Urn.for_artifact(dimension.concept_identity.parent) # type: ignore[arg-type] + scheme = structure_message.concept_scheme[scheme_urn] + concept_id = dimension.concept_identity.id + dimension_name = scheme.items[concept_id].name.localized_default(locale) + res.append({"entity_id": dimension.id, "name": dimension_name}) + res = sorted(res, key=itemgetter("entity_id")) + return res + + def get_structure_metadata_diff(self, old_metadata: dict | None, new_metadata: dict) -> dict: + if old_metadata is None: + return {'message': 'No previous metadata to compare.'} + + try: + old_dimensions = {dim['entity_id']: dim for dim in old_metadata.get('dimensions', [])} + new_dimensions = {dim['entity_id']: dim for dim in new_metadata['dimensions']} + return self._compare_dimension_meta(old_dimensions, new_dimensions) + except Exception as e: + logger.warning(f"Cannot compute structure metadata diff: {e}", exc_info=True) + return {'message': 'Could not compute diff due to error.'} + + @staticmethod + def _compare_dimension_meta(old: dict[str, dict], current: dict[str, dict]) -> dict: + result: dict[str, t.Any] = {} + + new = set(current.keys()).difference(old.keys()) + if new: + result['new_dimensions'] = [current[dim_id] for dim_id in new] + + removed = set(old.keys()).difference(current.keys()) + if removed: + result['removed_dimensions'] = [old[dim_id] for dim_id in removed] + + modified = {} + for 
dim_id in set(old.keys()).intersection(current.keys()):
+            old_dim = old[dim_id]
+            current_dim = current[dim_id]
+            changes = {}
+            for field in ['name']:
+                old_value = old_dim.get(field)
+                current_value = current_dim.get(field)
+                if old_value != current_value:
+                    changes[field] = {'old': old_value, 'new': current_value}
+            if changes:
+                modified[dim_id] = changes
+        if modified:
+            result['modified_dimensions'] = modified
+
+        return result
+
     async def close(self):
         # do nothing
         pass
diff --git a/src/common/hybrid_indexer/indexer.py b/src/common/hybrid_indexer/indexer.py
index 3a64e7b..50b68be 100644
--- a/src/common/hybrid_indexer/indexer.py
+++ b/src/common/hybrid_indexer/indexer.py
@@ -134,7 +134,7 @@ async def _normalize(
         max_n_indicators: int | None,
         auth_context: AuthContext,
     ) -> None:
-        indicators = await dataset.get_indicators(auth_context=auth_context)
+        indicators = await dataset.get_indicators(auth_context=auth_context, allow_cached=True)
         if max_n_indicators is not None:
             indicators = indicators[:max_n_indicators]

diff --git a/src/common/models/models.py b/src/common/models/models.py
index 00b2e5a..562d184 100644
--- a/src/common/models/models.py
+++ b/src/common/models/models.py
@@ -1,7 +1,7 @@
 import uuid
 from typing import Any

-from sqlalchemy import ForeignKey, UniqueConstraint
+from sqlalchemy import ForeignKey, String, UniqueConstraint
 from sqlalchemy.dialects import postgresql
 from sqlalchemy.orm import Mapped, mapped_column, relationship

@@ -152,6 +152,15 @@ class ChannelDatasetVersion(DefaultBase):
     creation_reason: Mapped[str]
     reason_for_failure: Mapped[str | None] = mapped_column(default=None)

+    structure_metadata: Mapped[dict | None] = mapped_column(type_=postgresql.JSONB, default=None)
+    structure_hash: Mapped[str | None] = mapped_column(type_=String(10), default=None)
+    # Data hashes:
+    indicator_dimensions_hash: Mapped[str | None] = mapped_column(type_=String(10), default=None)
+    non_indicator_dimensions_hash: Mapped[str | None] = mapped_column(
+        type_=String(10), default=None
+    )
+    special_dimensions_hash: Mapped[str | None] = mapped_column(type_=String(10), default=None)
+
     # relationships
     channel_dataset: Mapped[ChannelDataset] = relationship(back_populates="versions")
     pointer = relationship(
diff --git a/src/common/schemas/__init__.py b/src/common/schemas/__init__.py
index 4a3eb01..b39a4aa 100644
--- a/src/common/schemas/__init__.py
+++ b/src/common/schemas/__init__.py
@@ -1,6 +1,13 @@
 from .base import ListResponse
 from .channel import Channel, ChannelBase, ChannelConfig, ChannelUpdate, SupremeAgentConfig
-from .channel_dataset import ChannelDatasetBase, ChannelDatasetExpanded, ChannelDatasetVersion
+from .channel_dataset import (
+    ChangesBetweenVersionAndActualData,
+    ChannelDatasetBase,
+    ChannelDatasetExpanded,
+    ChannelDatasetVersion,
+    DataChange,
+    StructureChange,
+)
 from .data_query_tool import DataQueryDetails, HybridSearchConfig
 from .data_source import DataSource, DataSourceBase, DataSourceType, DataSourceUpdate
 from .dataset import DataSet, DataSetBase, DataSetDescriptor, DataSetUpdate
diff --git a/src/common/schemas/channel_dataset.py b/src/common/schemas/channel_dataset.py
index 19f03ec..da942f0 100644
--- a/src/common/schemas/channel_dataset.py
+++ b/src/common/schemas/channel_dataset.py
@@ -1,4 +1,4 @@
-from pydantic import Field
+from pydantic import BaseModel, Field

 from .base import DbDefaultBase
 from .dataset import DataSet
@@ -28,6 +28,11 @@ class ChannelDatasetVersion(DbDefaultBase):
             " if the rollback version was also a
rollback to a previous version." ) ) + structure_metadata: dict | None + structure_hash: str | None + indicator_dimensions_hash: str | None + non_indicator_dimensions_hash: str | None + special_dimensions_hash: str | None @property def version_data_id(self) -> int: @@ -49,3 +54,29 @@ class ChannelDatasetExpanded(ChannelDatasetBase): ) ) latest_version: ChannelDatasetVersion | None + + +class BaseChange(BaseModel): + message: str + last_version_hash: str | None + actual_hash: str | None + + +class DataChange(BaseChange): + pass + + +class StructureChange(BaseChange): + details: dict + + +class ChangesBetweenVersionAndActualData(BaseModel): + data_changes: list[DataChange] = Field( + description="List of changes between the indexed data of the latest completed version and the actual data." + ) + has_changes: bool = Field( + description="Indicates whether there are any changes between the version data and the actual data." + ) + structure_change: StructureChange | None = Field( + description="Details about the structure change, if any." + ) diff --git a/src/common/schemas/dial.py b/src/common/schemas/dial.py index e0d14b4..456b8b4 100644 --- a/src/common/schemas/dial.py +++ b/src/common/schemas/dial.py @@ -1,69 +1,4 @@ -"""A copy of the `aidial_sdk` schemas, but we inherit them from pydantic v2 base model""" - -from typing import Any, Literal - -from aidial_sdk.chat_completion import Role, Status -from pydantic import BaseModel, ConfigDict, Field, StrictStr - - -class ExtraAllowModel(BaseModel): - model_config = ConfigDict(populate_by_name=True, extra='allow') - - -class Attachment(ExtraAllowModel): - type: StrictStr | None = Field(default="text/markdown") - title: StrictStr | None = Field(default=None) - data: StrictStr | None = Field(default=None) - url: StrictStr | None = Field(default=None) - reference_type: StrictStr | None = Field(default=None) - reference_url: StrictStr | None = Field(default=None) - - -class Stage(ExtraAllowModel): - name: StrictStr - status: Status | None - content: StrictStr | None = Field(default=None) - attachments: list[Attachment] | None = Field(default=None) - - -class CustomContent(ExtraAllowModel): - stages: list[Stage] | None = Field(default=None) - attachments: list[Attachment] | None = Field(default=None) - state: Any | None = Field(default=None) - form_value: Any | None = None - form_schema: Any | None = None - - -class CacheBreakpoint(ExtraAllowModel): - expire_at: StrictStr | None = None - - -class MessageCustomFields(ExtraAllowModel): - cache_breakpoint: CacheBreakpoint | None = None - - -class FunctionCall(ExtraAllowModel): - name: str - arguments: str - - -class ToolCall(ExtraAllowModel): - # OpenAI API doesn't strictly specify existence of the index field - index: int | None - id: StrictStr - type: Literal["function"] - function: FunctionCall - - -class Message(ExtraAllowModel): - role: Role - content: StrictStr | None = Field(default=None) - custom_content: CustomContent | None = Field(default=None) - custom_fields: MessageCustomFields | None = None - name: StrictStr | None = Field(default=None) - tool_calls: list[ToolCall] | None = Field(default=None) - tool_call_id: StrictStr | None = Field(default=None) - function_call: FunctionCall | None = Field(default=None) +from pydantic import BaseModel class Pricing(BaseModel): diff --git a/src/common/utils/__init__.py b/src/common/utils/__init__.py index bf7650a..966d102 100644 --- a/src/common/utils/__init__.py +++ b/src/common/utils/__init__.py @@ -31,6 +31,9 @@ from .media_types import MediaTypes from 
.misc import ( batched, + crc32_hash, + crc32_hash_incremental, + crc32_hash_incremental_async, create_base64_uuid, get_last_commit_hash_for, secret_2_safe_str, diff --git a/src/common/utils/elastic.py b/src/common/utils/elastic.py index 342e9b1..143d3d6 100644 --- a/src/common/utils/elastic.py +++ b/src/common/utils/elastic.py @@ -2,6 +2,7 @@ from typing import Any from elasticsearch import AsyncElasticsearch, helpers +from elasticsearch.exceptions import BadRequestError from pydantic import BaseModel, ConfigDict, Field from common.settings.elastic import ElasticSearchSettings @@ -102,11 +103,22 @@ def name(self) -> str: async def exists(self) -> bool: return bool(await self._client.indices.exists(index=self.name)) - async def create(self): - await self._client.indices.create( - index=self._name, - settings=self._settings.index_settings, - ) + async def create(self, ignore_if_exists: bool = True) -> None: + """Creates the index with the defined settings. + + Args: + ignore_if_exists: If True, does not raise an error if the index already exists. + """ + try: + await self._client.indices.create( + index=self._name, + settings=self._settings.index_settings, + ) + except BadRequestError as ex: + if ignore_if_exists and ex.error == 'resource_already_exists_exception': + pass + else: + raise async def add(self, document: dict[str, str]) -> None: """Adds a JSON document to the index and makes it searchable. @@ -189,7 +201,7 @@ async def get_index(cls, name: str, allow_creation: bool = False) -> ElasticInde if not await index.exists(): if allow_creation: - await index.create() + await index.create(ignore_if_exists=True) else: raise RuntimeError(f"Index '{name}' does not exist.") diff --git a/src/common/utils/misc.py b/src/common/utils/misc.py index a65ffad..6fb47e9 100644 --- a/src/common/utils/misc.py +++ b/src/common/utils/misc.py @@ -1,12 +1,15 @@ +import asyncio import base64 import hashlib import itertools import subprocess import typing as t import uuid +import zlib +from collections.abc import Iterable -def batched(iterable: t.Iterable, n: int): +def batched(iterable: Iterable, n: int): """Batch data from the iterable into tuples of length n. The last batch may be shorter than n. In Python 3.12 and later, use the built-in `itertools.batched` function. @@ -28,6 +31,47 @@ def get_last_commit_hash_for(path: str) -> str: return commit_hash if (commit_hash := proc.stdout) is not None else "" +def crc32_hash(data: str) -> int: + """Compute CRC32 hash of a string and return it as a positive integer.""" + return zlib.crc32(data.encode("utf-8")) & 0xFFFFFFFF + + +def crc32_hash_incremental(values: list[str]) -> int: + """ + Compute CRC32 hash incrementally from a list of strings. + + This avoids creating a large intermediate string, reducing memory usage + and making the operation more efficient for large lists. + + Args: + values: Sorted list of strings to hash + + Returns: + CRC32 hash as a positive integer + """ + crc = 0 + for value in values: + # Hash each value with newline separator + crc = zlib.crc32(f"{value}\n".encode("utf-8"), crc) + return crc & 0xFFFFFFFF + + +async def crc32_hash_incremental_async(values: list[str]) -> int: + """ + Async version of crc32_hash_incremental. + + Offloads the blocking hash computation to a thread pool to avoid + blocking the asyncio event loop during large dataset processing. 
+ + Args: + values: Sorted list of strings to hash + + Returns: + CRC32 hash as a positive integer + """ + return await asyncio.to_thread(crc32_hash_incremental, values) + + def str2bool(var: str) -> bool: return var.strip().lower() == "true" diff --git a/src/common/utils/timer.py b/src/common/utils/timer.py index af34c33..697f6b7 100644 --- a/src/common/utils/timer.py +++ b/src/common/utils/timer.py @@ -3,35 +3,47 @@ from collections.abc import Callable -class Timer: +class OptionalTimer: start: float format: str printer: Callable[[str], None] + enabled: bool def __init__( self, format: str = "Elapsed time: {time}", printer: Callable[[str], None] = print, + enabled: bool = True, ): self.start = time.perf_counter() self.format = format self.printer = printer + self.enabled = enabled def stop(self) -> float: + if not self.enabled: + return 0.0 return time.perf_counter() - self.start def __str__(self) -> str: + if not self.enabled: + return "disabled" return f"{self.stop():.3f}s" def __enter__(self): return def __exit__(self, type, value, traceback): - self.printer(self.format.format(time=self)) + if self.enabled: + self.printer(self.format.format(time=self)) _log = logging.getLogger(__name__) -def debug_timer(title: str) -> Timer: - return Timer(format="timer." + title + ": {time}", printer=_log.debug) +def debug_timer(title: str) -> OptionalTimer: + return OptionalTimer( + format="timer." + title + ": {time}", + printer=_log.debug, + enabled=_log.isEnabledFor(logging.DEBUG), + ) diff --git a/src/statgpt/app.py b/src/statgpt/app.py index a4de3dc..23725e3 100644 --- a/src/statgpt/app.py +++ b/src/statgpt/app.py @@ -7,7 +7,6 @@ sys.path.append(str(statgpt_path)) import dotenv -from aidial_sdk import DIALApp dotenv_path = os.path.join(os.getcwd(), ".env") @@ -21,6 +20,8 @@ _log.info("Initializing StatGPT application") +from aidial_sdk import DIALApp + def run_dial_app(app: DIALApp): import uvicorn diff --git a/src/statgpt/application/app_factory.py b/src/statgpt/application/app_factory.py index 0cbccf3..8655c08 100644 --- a/src/statgpt/application/app_factory.py +++ b/src/statgpt/application/app_factory.py @@ -51,7 +51,7 @@ def create_app(self) -> DIALApp: app.add_chat_completion_with_dependencies( "{deployment_id}", AppChatCompletion(), - heartbeat_interval=10, + heartbeat_interval=5, ) app.include_router(service_router) diff --git a/src/statgpt/chains/data_query/query_builder/query/finalize_query.py b/src/statgpt/chains/data_query/query_builder/query/finalize_query.py index 3773b32..7e70df8 100644 --- a/src/statgpt/chains/data_query/query_builder/query/finalize_query.py +++ b/src/statgpt/chains/data_query/query_builder/query/finalize_query.py @@ -329,8 +329,8 @@ def _map_dimension_ids_to_names(self, inputs: dict) -> DatasetDimensionTermNameT dataset: Sdmx21DataSet = datasets[dataset_id].data # type: ignore[assignment] dataset_dimension_id_to_name = {} for dimension, dimension_query in dataset_query.dimensions_queries_dict.items(): - id2name_mapping = dataset.map_dim_values_id_2_name( - value_ids=dimension_query.values, dimension_name=dimension + id2name_mapping = dataset.map_component_values_id_2_name( + value_ids=dimension_query.values, component_id=dimension ) # `None` is returned if the dimension has no corresponding code list, e.g., # when it's time period dimension. 
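A note on the CRC32 helpers added in src/common/utils/misc.py above: feeding each newline-terminated value into a running zlib.crc32 produces the same digest as hashing the fully joined string in one call, so the incremental variant only saves memory, and the async wrapper merely runs that same loop in a worker thread via asyncio.to_thread. The sketch below is illustrative only (plain zlib calls and made-up dimension ids, not code from this patch) and demonstrates that equivalence.

import zlib

values = sorted(["FREQ", "INDICATOR", "REF_AREA"])  # hypothetical dimension ids

# Incremental pass: update a running CRC32 with each newline-terminated value.
crc = 0
for value in values:
    crc = zlib.crc32(f"{value}\n".encode("utf-8"), crc)
incremental = crc & 0xFFFFFFFF

# Single-shot pass over the joined string: identical digest, but the whole
# payload has to be materialized in memory first.
single = zlib.crc32("".join(f"{v}\n" for v in values).encode("utf-8")) & 0xFFFFFFFF

assert incremental == single
print(f"{incremental:08x}")  # stable 8-hex-digit fingerprint of the sorted list

Keeping the newline separator inside each hashed chunk is what makes the two forms match byte for byte.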
diff --git a/src/statgpt/chains/main.py b/src/statgpt/chains/main.py index 7f7684f..844b404 100644 --- a/src/statgpt/chains/main.py +++ b/src/statgpt/chains/main.py @@ -1,10 +1,7 @@ -import typing as t - from langchain_core.runnables import Runnable, RunnablePassthrough from common.config import logger from common.schemas import ChannelConfig -from common.schemas.dial import Message as DialMessage from statgpt.chains.out_of_scope_checker import OutOfScopeChecker from statgpt.chains.parameters import ChainParameters from statgpt.chains.supreme_agent import SupremeAgentExecutor, ToolCaller @@ -24,13 +21,8 @@ async def _init_history(inputs: dict) -> History: state = ChainParameters.get_state(inputs) data_service = ChainParameters.get_data_service(inputs) - # NOTE: we introduced custom Message model that uses pydantic v2, - # since aidial_sdk uses pydantic v1 models. - # the interface should be the same, but this is source of potential bugs. - dial_messages: list[DialMessage] = t.cast(list[DialMessage], request.messages) - return await History.from_dial_with_interceptors( - messages=dial_messages, state=state, data_service=data_service + messages=request.messages, state=state, data_service=data_service ) async def create_chain(self) -> Runnable: diff --git a/src/statgpt/chains/supreme_agent.py b/src/statgpt/chains/supreme_agent.py index b0ae8ae..00dde6c 100644 --- a/src/statgpt/chains/supreme_agent.py +++ b/src/statgpt/chains/supreme_agent.py @@ -4,7 +4,10 @@ from datetime import datetime from typing import NamedTuple -from aidial_sdk.chat_completion import Choice, Role +from aidial_sdk.chat_completion import Choice, FunctionCall +from aidial_sdk.chat_completion import Message as DialMessage +from aidial_sdk.chat_completion import Role +from aidial_sdk.chat_completion import ToolCall as DialToolCall from langchain_core.messages import AIMessage, AIMessageChunk, SystemMessage, ToolCall, ToolMessage from langchain_core.prompts import ( ChatPromptTemplate, @@ -16,9 +19,6 @@ from common.auth.auth_context import AuthContext from common.config import multiline_logger as logger from common.schemas import ChannelConfig, FakeCall -from common.schemas.dial import FunctionCall -from common.schemas.dial import Message as DialMessage -from common.schemas.dial import ToolCall as DialToolCall from common.utils import InvalidLLMStreamResponse from common.utils.markdown import format_as_markdown_list from common.utils.models import get_chat_model diff --git a/src/statgpt/chains/web_search/response_producer.py b/src/statgpt/chains/web_search/response_producer.py index ffc5501..54b51e4 100644 --- a/src/statgpt/chains/web_search/response_producer.py +++ b/src/statgpt/chains/web_search/response_producer.py @@ -1,11 +1,11 @@ import abc from typing import Any +from aidial_sdk.chat_completion import Attachment from openai import APIError from common.config import multiline_logger as logger from common.schemas import StagesConfig -from common.schemas.dial import Attachment from statgpt.chains.parameters import ChainParameters from statgpt.config import StateVarsConfig from statgpt.utils import OpenAiToDialStreamer, openai diff --git a/src/statgpt/schemas/__init__.py b/src/statgpt/schemas/__init__.py index e5b9c3d..991f2aa 100644 --- a/src/statgpt/schemas/__init__.py +++ b/src/statgpt/schemas/__init__.py @@ -5,7 +5,7 @@ LLMSelectionCandidateBase, SelectedCandidates, ) -from .service import GitVersionResponse, SettingsResponse +from .service import DimTypesResponse, GitVersionResponse, SettingsResponse from .state 
import ChatState from .tool_artifact import ( BaseFileRagArtifact, diff --git a/src/statgpt/schemas/dial_app_configuration.py b/src/statgpt/schemas/dial_app_configuration.py index 9b29b9f..0a398a9 100644 --- a/src/statgpt/schemas/dial_app_configuration.py +++ b/src/statgpt/schemas/dial_app_configuration.py @@ -2,12 +2,10 @@ from functools import cached_property from zoneinfo import ZoneInfo -from pydantic import ConfigDict, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator -from common.schemas.dial import ExtraAllowModel - -class StatGPTConfiguration(ExtraAllowModel): +class StatGPTConfiguration(BaseModel): """ Dynamic DIAL configuration for StatGPT application. """ diff --git a/src/statgpt/schemas/service.py b/src/statgpt/schemas/service.py index 14e2534..44062a6 100644 --- a/src/statgpt/schemas/service.py +++ b/src/statgpt/schemas/service.py @@ -1,5 +1,7 @@ from pydantic import BaseModel, Field +from common.data.base.dimension import DimensionProcessingType + class GitVersionResponse(BaseModel): git_commit: str = Field() @@ -9,3 +11,11 @@ class SettingsResponse(BaseModel): enable_dev_commands: bool = Field() enable_direct_tool_calls: bool = Field() git_commit: str = Field() + + +class DimTypesResponse(BaseModel): + channel_name: str + n_datasets: int + dataset_dim_types: dict[str, dict[str, DimensionProcessingType]] = Field( + default_factory=dict, description="dataset -> dimension -> types mapping" + ) diff --git a/src/statgpt/services/chat_facade.py b/src/statgpt/services/chat_facade.py index ca279fa..1ef19b8 100644 --- a/src/statgpt/services/chat_facade.py +++ b/src/statgpt/services/chat_facade.py @@ -8,9 +8,9 @@ from typing import Any import pandas as pd -from aidial_sdk.chat_completion import Button, FormMetaclass -from aidial_sdk.pydantic_v1 import BaseModel as PydanticV1BaseModel -from aidial_sdk.pydantic_v1 import Field as PydanticV1Field +from aidial_sdk.chat_completion.form import Button, FormMetaclass +from aidial_sdk.pydantic.v2 import ConfigDict as DialConfigDict +from aidial_sdk.pydantic.v2 import Field as DialField from pydantic import BaseModel, ConfigDict, Field, computed_field from sqlalchemy.ext.asyncio import AsyncSession @@ -245,11 +245,10 @@ def dial_channel_configuration(self) -> dict[str, Any]: f"No conversation starters configuration found for channel {self._channel.title}" ) - class InitConfiguration(PydanticV1BaseModel, metaclass=FormMetaclass): - class Config: - chat_message_input_disabled = False + class InitConfiguration(BaseModel, metaclass=FormMetaclass): + model_config = DialConfigDict(chat_message_input_disabled=False) - return InitConfiguration.schema() + return InitConfiguration.model_json_schema() intro_text: str = conversation_starters_config.intro_text _log.info( f"Conversation starters configuration found for channel {self._channel.title}, {conversation_starters_config=}" @@ -264,21 +263,21 @@ class Config: for i, button in enumerate(conversation_starters_config.buttons) ] - class StatGPTConfiguration(PydanticV1BaseModel, metaclass=FormMetaclass): - class Config: - chat_message_input_disabled = False + class StatGPTConfiguration(BaseModel, metaclass=FormMetaclass): + model_config = DialConfigDict(chat_message_input_disabled=False) - starter: int | None = PydanticV1Field( + starter: int | None = DialField( + default=None, description=intro_text, buttons=buttons, ) - timezone: str = PydanticV1Field( + timezone: str = DialField( description="Timezone in IANA format, e.g. 'Europe/Berlin', 'America/New_York'. 
" "Used to interpret and display dates and times.", default="UTC", ) - return StatGPTConfiguration.schema() + return StatGPTConfiguration.model_json_schema() def get_named_entity_types(self) -> list[str]: return self.channel_config.list_named_entity_types() diff --git a/src/statgpt/services/onboarding.py b/src/statgpt/services/onboarding.py index 796f897..9deec48 100644 --- a/src/statgpt/services/onboarding.py +++ b/src/statgpt/services/onboarding.py @@ -4,8 +4,8 @@ from typing import Literal from aidial_sdk.chat_completion import Button, Choice, FormMetaclass -from aidial_sdk.pydantic_v1 import BaseModel as PydanticV1BaseModel -from aidial_sdk.pydantic_v1 import Field as PydanticV1Field +from aidial_sdk.pydantic.v2 import ConfigDict as DialConfigDict +from aidial_sdk.pydantic.v2 import Field as DialField from pydantic import BaseModel, Field from common.auth.auth_context import AuthContext @@ -59,9 +59,8 @@ def set_completion_button(self) -> None: self.current_path = [] -class CompletedSchema(PydanticV1BaseModel, metaclass=FormMetaclass): - class Config: - chat_message_input_disabled = True +class CompletedSchema(BaseModel, metaclass=FormMetaclass): + model_config = DialConfigDict(chat_message_input_disabled=True) class OnboardingService: @@ -119,10 +118,10 @@ def get_form_schema( if button_clicked == "complete" and state.is_completion_button(): # User clicked the final completion button state.set_completed() - return CompletedSchema.schema() + return CompletedSchema.model_json_schema() if state.is_completed(): - return CompletedSchema.schema() + return CompletedSchema.model_json_schema() if not button_clicked: # No button clicked, show current state @@ -239,25 +238,23 @@ def _create_navigation_form(self, state: OnboardingState, show_intro: bool = Fal description = "\n".join(description_parts) - class NavigationForm(PydanticV1BaseModel, metaclass=FormMetaclass): - class Config: - chat_message_input_disabled = True + class NavigationForm(BaseModel, metaclass=FormMetaclass): + model_config = DialConfigDict(chat_message_input_disabled=True) - choice: str | None = PydanticV1Field( + choice: str | None = DialField( description=description, buttons=buttons, ) - return NavigationForm.schema() + return NavigationForm.model_json_schema() def _create_completion_button_form(self) -> dict: """Create completion form shown when all topics have been explored.""" - class CompletionButtonForm(PydanticV1BaseModel, metaclass=FormMetaclass): - class Config: - chat_message_input_disabled = True + class CompletionButtonForm(BaseModel, metaclass=FormMetaclass): + model_config = DialConfigDict(chat_message_input_disabled=True) - completion: int | None = PydanticV1Field( + completion: int | None = DialField( buttons=[ Button( const="complete", @@ -268,7 +265,7 @@ class Config: ], ) - return CompletionButtonForm.schema() + return CompletionButtonForm.model_json_schema() def _create_completion_form(self) -> dict: """ @@ -276,16 +273,15 @@ def _create_completion_form(self) -> dict: This form simply disables input. 
""" - class CompletionForm(PydanticV1BaseModel, metaclass=FormMetaclass): - class Config: - chat_message_input_disabled = True + class CompletionForm(BaseModel, metaclass=FormMetaclass): + model_config = DialConfigDict(chat_message_input_disabled=True) - completion: int | None = PydanticV1Field( + completion: int | None = DialField( description=self.config.completion_message, buttons=[], ) - return CompletionForm.schema() + return CompletionForm.model_json_schema() def get_response_for_path(self, path: list[str]) -> Response | None: """ diff --git a/src/statgpt/utils/message_history.py b/src/statgpt/utils/message_history.py index cfe1ea4..1a54631 100644 --- a/src/statgpt/utils/message_history.py +++ b/src/statgpt/utils/message_history.py @@ -2,7 +2,9 @@ import typing as t from collections.abc import Sequence +from aidial_sdk.chat_completion import Message as DialMessage from aidial_sdk.chat_completion import Role +from aidial_sdk.chat_completion import ToolCall as DialToolCall from langchain_core.messages import ( AIMessage, AIMessageChunk, @@ -14,8 +16,6 @@ from langchain_core.messages import ToolMessage from common.config import multiline_logger as logger -from common.schemas.dial import Message as DialMessage -from common.schemas.dial import ToolCall as DialToolCall from statgpt.config import StateVarsConfig from statgpt.schemas.tool_artifact import ToolArtifact from statgpt.services.chat_facade import ChannelServiceFacade diff --git a/src/statgpt/utils/message_interceptors/base.py b/src/statgpt/utils/message_interceptors/base.py index c6f4022..a4690f6 100644 --- a/src/statgpt/utils/message_interceptors/base.py +++ b/src/statgpt/utils/message_interceptors/base.py @@ -2,7 +2,7 @@ import typing as t from abc import ABC, abstractmethod -from common.schemas.dial import Message as DialMessage +from aidial_sdk.chat_completion import Message as DialMessage _log = logging.getLogger(__name__) diff --git a/src/statgpt/utils/message_interceptors/commands_interceptor.py b/src/statgpt/utils/message_interceptors/commands_interceptor.py index ad75be7..cc59cf3 100644 --- a/src/statgpt/utils/message_interceptors/commands_interceptor.py +++ b/src/statgpt/utils/message_interceptors/commands_interceptor.py @@ -1,11 +1,11 @@ import re import typing as t +from aidial_sdk.chat_completion import Message as DialMessage from aidial_sdk.chat_completion import Role from pydantic import BaseModel from common.config import multiline_logger as logger -from common.schemas.dial import Message as DialMessage from statgpt.config import StateVarsConfig from statgpt.settings.dial_app import dial_app_settings diff --git a/src/statgpt/utils/message_interceptors/system_msg_interceptor.py b/src/statgpt/utils/message_interceptors/system_msg_interceptor.py index c3793b1..fb3f580 100644 --- a/src/statgpt/utils/message_interceptors/system_msg_interceptor.py +++ b/src/statgpt/utils/message_interceptors/system_msg_interceptor.py @@ -1,11 +1,11 @@ import logging import typing as t +from aidial_sdk.chat_completion import Message as DialMessage from aidial_sdk.chat_completion import Role from aidial_sdk.exceptions import InvalidRequestError from pydantic import ValidationError -from common.schemas.dial import Message as DialMessage from common.schemas.query import JsonQuery from statgpt.services.chat_facade import ChannelServiceFacade From 23f556a6fc6c9b79f6dcedde007d12d931c0745f Mon Sep 17 00:00:00 2001 From: Daniil Yarmalkevich Date: Mon, 17 Nov 2025 14:38:14 +0100 Subject: [PATCH 2/6] added new code --- 
src/admin_portal/services/dataset.py | 109 +++++++++++++++--- src/common/config/versions.py | 2 +- src/common/data/sdmx/v21/sdmx_client.py | 42 ++++--- src/common/models/models.py | 5 + src/common/schemas/channel_dataset.py | 3 + src/common/services/data_preloader.py | 1 + src/common/services/dataset.py | 1 + .../pg_vector_store/pg_vector_store.py | 73 +++++++++--- .../services/test_sdmx_datasets.py | 14 ++- 9 files changed, 200 insertions(+), 50 deletions(-) diff --git a/src/admin_portal/services/dataset.py b/src/admin_portal/services/dataset.py index df4db7e..371c9d0 100644 --- a/src/admin_portal/services/dataset.py +++ b/src/admin_portal/services/dataset.py @@ -713,6 +713,16 @@ async def _create_new_channel_dataset_version( await self._session.refresh(item) return item + async def _update_channel_dataset_status( + self, item: models.ChannelDataset, new_status: StatusEnum, do_commit: bool = True + ) -> None: + item.clearing_status = new_status + item.updated_at = func.now() + + if do_commit: + await self._session.commit() + await self._session.refresh(item) + async def _update_channel_dataset_version_status( self, item: models.ChannelDatasetVersion, @@ -966,7 +976,19 @@ async def reload_all_indicators( max_n_embeddings=max_n_embeddings, ) + for ch_ds in channel_datasets: + background_tasks.add_task( + clear_channel_dataset_data_in_background_task, + channel_dataset_id=ch_ds.id, + auth_context=auth_context, + ) + await self._update_channel_dataset_status( + ch_ds, StatusEnum.NOT_STARTED, do_commit=False + ) + await self._session.commit() + for ch_ds in channel_datasets: + await self._session.refresh(ch_ds) for version in new_versions.values(): await self._session.refresh(version) @@ -1034,6 +1056,13 @@ async def reload_indicators( max_n_embeddings=max_n_embeddings, ) + background_tasks.add_task( + clear_channel_dataset_data_in_background_task, + channel_dataset_id=channel_dataset.id, + auth_context=auth_context, + ) + await self._update_channel_dataset_status(channel_dataset, StatusEnum.NOT_STARTED) + latest_version = schemas.ChannelDatasetVersion.model_validate(version, from_attributes=True) last_completed_versions_mapping = ( await self._get_latest_successful_channel_dataset_versions( @@ -1316,26 +1345,31 @@ async def reload_channel_dataset_in_background( status_on_completion: StatusEnum = StatusEnum.COMPLETED, ) -> None: version = await self._get_channel_dataset_version_or_raise(channel_dataset_version_id) - channel_dataset = await self._get_channel_dataset_model_or_raise(version.channel_dataset_id) - channel: models.Channel = await ChannelService(self._session).get_model_by_id( - channel_dataset.channel_id - ) - db_dataset: models.DataSet = await self.get_model_by_id(channel_dataset.dataset_id) - - if await self._invalid_version_status(version): - return - - handler_class = await DataSourceTypeService( - self._session - ).get_data_source_handler_class_by_id(db_dataset.source.type_id) - config = handler_class.parse_config(db_dataset.source.details) - _log.info(f"Start processing {version} of {channel_dataset}") + _log.info(f"Start processing {version}") try: + if await self._invalid_version_status(version): + return + await self._update_channel_dataset_version_status( version, new_status=StatusEnum.IN_PROGRESS ) + channel_dataset = await self._get_channel_dataset_model_or_raise( + version.channel_dataset_id + ) + _log.info( + f"Processing version(id={version.id}, version={version.version}) of {channel_dataset}" + ) + channel = await ChannelService(self._session).get_model_by_id( + 
channel_dataset.channel_id + ) + db_dataset: models.DataSet = await self.get_model_by_id(channel_dataset.dataset_id) + handler_class = await DataSourceTypeService( + self._session + ).get_data_source_handler_class_by_id(db_dataset.source.type_id) + config = handler_class.parse_config(db_dataset.source.details) + handler = handler_class(config=config) dataset = await handler.get_dataset( entity_id=db_dataset.id_, @@ -1383,17 +1417,39 @@ async def reload_channel_dataset_in_background( await self._update_channel_dataset_version_status( version, new_status=status_on_completion ) + if status_on_completion is StatusEnum.COMPLETED: + await self._update_channel_dataset_status( + channel_dataset, new_status=StatusEnum.QUEUED + ) _log.info(f'Finished processing {version} of {channel_dataset}') except Exception as e: - _log.exception(f"Failed to reindex {version} of {channel_dataset}") + _log.exception(f"Failed to reindex {version}") await self._update_channel_dataset_version_status( version, new_status=StatusEnum.FAILED, reason_for_failure=str(e) ) - # In case of failure, we clear the data that might have been partially indexed - # In case of success, we clear previous version data to save space - if status_on_completion == StatusEnum.COMPLETED: # only clear in last indexing job - await self.clear_channel_dataset_versions_data(channel.id, db_dataset.id, auth_context) + async def clear_channel_dataset_data_in_background( + self, channel_dataset_id: int, auth_context: AuthContext + ) -> None: + channel_dataset = await self._get_channel_dataset_model_or_raise(channel_dataset_id) + + _log.info(f"Clear data after reindexing {channel_dataset}") + try: + await self._update_channel_dataset_status( + channel_dataset, new_status=StatusEnum.IN_PROGRESS + ) + + # In case of failure, we clear the data that might have been partially indexed + # In case of success, we clear previous version data to save space + await self.clear_channel_dataset_versions_data( + channel_dataset.channel_id, channel_dataset.dataset_id, auth_context + ) + await self._update_channel_dataset_status( + channel_dataset, new_status=StatusEnum.COMPLETED + ) + except Exception: + _log.exception(f"Failed to clear data after reindexing {channel_dataset}") + await self._update_channel_dataset_status(channel_dataset, new_status=StatusEnum.FAILED) @background_task @@ -1422,3 +1478,18 @@ async def reload_indicators_in_background_task( ) except Exception as e: _log.exception(e) + + +@background_task +async def clear_channel_dataset_data_in_background_task( + channel_dataset_id: int, auth_context: AuthContext +) -> None: + try: + async with models.get_session_contex_manager() as session: + service = AdminPortalDataSetService(session) + await service.clear_channel_dataset_data_in_background( + channel_dataset_id=channel_dataset_id, + auth_context=auth_context, + ) + except Exception as e: + _log.exception(e) diff --git a/src/common/config/versions.py b/src/common/config/versions.py index e612e44..84bd034 100644 --- a/src/common/config/versions.py +++ b/src/common/config/versions.py @@ -7,4 +7,4 @@ class Versions: # Please update this version when you create a new alembic revision. # Needed because alembic folder exist only in the admin_portal package. 
# (statgpt Dockerfile doesn't copy admin_portal package to the container) - ALEMBIC_TARGET_VERSION = 'd528d881ece8' + ALEMBIC_TARGET_VERSION = '65c149c7db9e' diff --git a/src/common/data/sdmx/v21/sdmx_client.py b/src/common/data/sdmx/v21/sdmx_client.py index 79ef43c..a348165 100644 --- a/src/common/data/sdmx/v21/sdmx_client.py +++ b/src/common/data/sdmx/v21/sdmx_client.py @@ -313,21 +313,35 @@ async def _perform_request(self, req: PreparedRequest, max_retries=3, delay=3) - try: while True: attempts += 1 - resp = await self._httpx_client.request( - method=req.method, # type: ignore[arg-type] - url=req.url, # type: ignore[arg-type] - headers=req.headers, - content=req.body, - ) - if attempts == max_retries or resp.status_code < 500: - resp.raise_for_status() - return resp - else: - logger.error( - f"Server failed to respond after {attempts} attempts: {resp.status_code} {resp.text}\n" - f"Retrying in {delay} seconds...\nRequest: {req.method} {req.url} body={req.body!r}" + try: + resp = await self._httpx_client.request( + method=req.method, # type: ignore[arg-type] + url=req.url, # type: ignore[arg-type] + headers=req.headers, + content=req.body, ) - await asyncio.sleep(delay) + if attempts == max_retries or resp.status_code < 500: + resp.raise_for_status() + return resp + else: + logger.error( + f"Server failed to respond after {attempts} attempts: {resp.status_code} {resp.text}\n" + f"Retrying in {delay} seconds...\nRequest: {req.method} {req.url} body={req.body!r}" + ) + await asyncio.sleep(delay) + except httpx.ConnectTimeout: + if attempts == max_retries: + logger.exception( + f"Connection timed out after {attempts} attempts: " + f"{req.method} {req.url} body={req.body!r}" + ) + raise + else: + logger.error( + f"Connection timed out after {attempts} attempts. Retrying in {delay} seconds..." + f"\nRequest: {req.method} {req.url} body={req.body!r}\n" + ) + await asyncio.sleep(delay) except Exception: logger.exception( f"Server failed to respond, after {attempts} attempts: " diff --git a/src/common/models/models.py b/src/common/models/models.py index 562d184..8f47a94 100644 --- a/src/common/models/models.py +++ b/src/common/models/models.py @@ -123,6 +123,11 @@ class ChannelDataset(DefaultBase): channel_id: Mapped[int] = mapped_column(ForeignKey("channels.id")) dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id")) + clearing_status: Mapped[PreprocessingStatusEnum] = mapped_column( + default=PreprocessingStatusEnum.NOT_STARTED + ) + """The status of the data clearing task run after reindexing is complete.""" + # relationships channel: Mapped[Channel] = relationship(back_populates="mapped_datasets") dataset: Mapped[DataSet] = relationship(back_populates="mapped_channels") diff --git a/src/common/schemas/channel_dataset.py b/src/common/schemas/channel_dataset.py index da942f0..0e9d1a5 100644 --- a/src/common/schemas/channel_dataset.py +++ b/src/common/schemas/channel_dataset.py @@ -45,6 +45,9 @@ class ChannelDatasetExpanded(ChannelDatasetBase): preprocessing_status: PreprocessingStatusEnum = Field( description="The preprocessing status of the latest version." ) + clearing_status: PreprocessingStatusEnum = Field( + description="The clearing status of the channel dataset." 
+ ) last_completed_version: ChannelDatasetVersion | None previous_completed_version: ChannelDatasetVersion | None = Field( diff --git a/src/common/services/data_preloader.py b/src/common/services/data_preloader.py index 2ee2a26..8ad2503 100644 --- a/src/common/services/data_preloader.py +++ b/src/common/services/data_preloader.py @@ -35,6 +35,7 @@ async def preload_data(allow_cached_datasets: bool) -> None: offset=0, auth_context=_DataPreloaderAuthContext(), allow_cached_datasets=allow_cached_datasets, + allow_offline=True, ) _log.info(f'{len(datasets)} datasets loaded') except Exception: diff --git a/src/common/services/dataset.py b/src/common/services/dataset.py index 273b842..ac59cc6 100644 --- a/src/common/services/dataset.py +++ b/src/common/services/dataset.py @@ -74,6 +74,7 @@ def db_to_schema( channel_id=item_db.channel_id, dataset_id=item_db.dataset_id, preprocessing_status=preprocessing_status, + clearing_status=item_db.clearing_status, dataset=dataset, latest_version=latest_version, last_completed_version=last_completed_versions.last_completed_version, diff --git a/src/common/vectorstore/pg_vector_store/pg_vector_store.py b/src/common/vectorstore/pg_vector_store/pg_vector_store.py index 3ec38e4..b9ee1dc 100644 --- a/src/common/vectorstore/pg_vector_store/pg_vector_store.py +++ b/src/common/vectorstore/pg_vector_store/pg_vector_store.py @@ -169,6 +169,9 @@ async def _acquire_dataset_lock(self, dataset_id: uuid.UUID) -> int: """ lock_key = self._dataset_lock_key(dataset_id) async with self._lock_session() as session: + # Ensure session is in a clean state before acquiring lock + if session.in_transaction() and not session.is_active: + await session.rollback() await session.execute( text("SELECT pg_advisory_lock(:lock_key)"), {"lock_key": lock_key} ) @@ -178,11 +181,26 @@ async def _acquire_dataset_lock(self, dataset_id: uuid.UUID) -> int: return lock_key async def _release_dataset_lock(self, lock_key: int) -> None: - """Releases a session-level advisory lock.""" + """Releases a session-level advisory lock. + + Handles PendingRollback state to ensure lock is always released. + """ async with self._lock_session() as session: - await session.execute( - text("SELECT pg_advisory_unlock(:lock_key)"), {"lock_key": lock_key} - ) + try: + # Rollback if session is in a failed transaction state + if session.in_transaction() and not session.is_active: + await session.rollback() + await session.execute( + text("SELECT pg_advisory_unlock(:lock_key)"), {"lock_key": lock_key} + ) + except Exception as e: + _log.error(f"Failed to release advisory lock (key={lock_key}): {e}") + # Still try to rollback to clean up the session + try: + await session.rollback() + except Exception: + pass + raise _log.debug(f"Released advisory lock (key={lock_key})") @asynccontextmanager @@ -804,9 +822,13 @@ def _validate_import_metadata_from_zipfile( async def _import_documents_from_zipfile( self, zip_file: zipfile.ZipFile, document_model: type[BaseDocument], file_path: str - ) -> int: - """Imports documents from Parquet file in zip archive in batches.""" + ) -> dict[int, int]: + """Imports documents from Parquet file in zip archive in batches. + + Returns a mapping from old document IDs to new auto-generated IDs. 
+ """ doc_count = 0 + id_mapping: dict[int, int] = {} # Read Parquet file from zip archive with zip_file.open(file_path) as f: @@ -820,7 +842,7 @@ async def _import_documents_from_zipfile( for i in range(len(batch_dict['id'])): try: doc = document_model( - id=batch_dict['id'][i], + # Don't set id - let PostgreSQL auto-generate it document=self._sanitize_text(batch_dict['document'][i]), embeddings=( batch_dict['embeddings'][i].tolist() @@ -836,11 +858,18 @@ async def _import_documents_from_zipfile( raise ValueError(f"Corrupt document data: {e}") self._session.add_all(doc_batch) + await self._session.flush() + + # Map old IDs to new auto-generated IDs + for i, doc in enumerate(doc_batch): + old_id = batch_dict['id'][i] + id_mapping[old_id] = doc.id + await self._session.commit() doc_count += len(doc_batch) _log.info(f"Imported {doc_count} documents...") - return doc_count + return id_mapping async def _import_mappings_from_zipfile( self, @@ -849,8 +878,13 @@ async def _import_mappings_from_zipfile( file_path: str, dataset_versions: dict[uuid.UUID, int], data_sources: dict[uuid.UUID, int], + id_mapping: dict[int, int], ) -> int: - """Imports metadata mappings from Parquet file in zip archive in batches.""" + """Imports metadata mappings from Parquet file in zip archive in batches. + + Args: + id_mapping: Maps old document IDs to new auto-generated IDs + """ mapping_count = 0 # Read Parquet file from zip archive @@ -865,6 +899,18 @@ async def _import_mappings_from_zipfile( for i in range(len(batch_dict['id'])): try: dataset_id = uuid.UUID(batch_dict['dataset_id'][i]) + old_document_id = batch_dict['document_id'][i] + + # Look up new document ID + new_document_id = id_mapping.get(old_document_id) + if new_document_id is None: + _log.error( + f"Mapping references non-existent document_id={old_document_id}. " + f"This indicates corrupted or mismatched export data." 
+ ) + raise ValueError( + f"Mapping references document_id={old_document_id} which was not found in documents" + ) # 'details' field processing details = json.loads(batch_dict['details'][i]) @@ -876,8 +922,8 @@ async def _import_mappings_from_zipfile( ) mapping = metadata_model( - id=batch_dict['id'][i], - document_id=batch_dict['document_id'][i], + # Don't set id - let PostgreSQL auto-generate it + document_id=new_document_id, # Use new document ID dataset_id=dataset_id, version_id=dataset_versions[dataset_id], details=details, @@ -943,13 +989,14 @@ async def import_from_zipfile( doc_count = 0 mapping_count = 0 else: - doc_count = await self._import_documents_from_zipfile( + id_mapping = await self._import_documents_from_zipfile( zip_file, document_model, documents_path ) + doc_count = len(id_mapping) _log.info(f"Imported {doc_count} documents") mapping_count = await self._import_mappings_from_zipfile( - zip_file, metadata_model, mappings_path, dataset_versions, data_sources + zip_file, metadata_model, mappings_path, dataset_versions, data_sources, id_mapping ) _log.info(f"Imported {mapping_count} mappings") diff --git a/tests/integration/services/test_sdmx_datasets.py b/tests/integration/services/test_sdmx_datasets.py index 5a21391..3adb258 100644 --- a/tests/integration/services/test_sdmx_datasets.py +++ b/tests/integration/services/test_sdmx_datasets.py @@ -7,7 +7,10 @@ from admin_portal.services import AdminPortalChannelService as ChannelService from admin_portal.services import AdminPortalDataSetService as DataSetService from admin_portal.services import AdminPortalDataSourceService as DataSourceService -from admin_portal.services.dataset import reload_indicators_in_background_task +from admin_portal.services.dataset import ( + clear_channel_dataset_data_in_background_task, + reload_indicators_in_background_task, +) from common import schemas from common.data.base import DatasetCitation, IndexerConfig from common.data.base.dataset import IndexerIndicatorConfig @@ -505,16 +508,21 @@ async def test_reload_all_indicators(session, clear_all, sdmx_clint_mock): assert channel_ds.latest_version.reason_for_failure is None assert isinstance(channel_ds.latest_version.id, int) + channel_dataset_ids = {channel_ds.id for channel_ds in res} version_ids = {channel_ds.latest_version.id for channel_ds in res if channel_ds.latest_version} - assert len(background_tasks.tasks) == 2 - for f, args, kwargs in background_tasks.tasks: + assert len(background_tasks.tasks) == 4 + for f, args, kwargs in background_tasks.tasks[:2]: assert f == reload_indicators_in_background_task assert kwargs['channel_dataset_version_id'] in version_ids assert kwargs['version_ids'] == version_ids assert kwargs['max_n_embeddings'] == 5 assert kwargs['status_on_completion'] == schemas.PreprocessingStatusEnum.COMPLETED + for f, args, kwargs in background_tasks.tasks[2:]: + assert f == clear_channel_dataset_data_in_background_task + assert kwargs['channel_dataset_id'] in channel_dataset_ids + res2 = await dataset_service.get_channel_dataset_schemas( limit=100, offset=0, channel_id=channel.id, auth_context=SystemUserAuthContext() ) From 010ce24c19d14e8340227ddfef5ce96552c4f7d2 Mon Sep 17 00:00:00 2001 From: Daniil Yarmalkevich Date: Wed, 3 Dec 2025 12:43:30 +0200 Subject: [PATCH 3/6] added new code --- src/admin_portal/admin.sh | 32 +++ ...e37902_reset_sequences_for_vector_store.py | 82 ++++++++ ..._add_cleanup_status_to_channel_dataset_.py | 35 ++++ src/admin_portal/app.py | 2 +- src/admin_portal/fix_statuses.py | 34 ++++ 
src/admin_portal/routers/channel.py | 30 ++- src/admin_portal/routers/data_source.py | 6 + src/admin_portal/routers/dataset.py | 3 + src/admin_portal/routers/glossary_of_terms.py | 3 + src/admin_portal/services/channel.py | 1 - src/admin_portal/services/dataset.py | 191 +++++++++++++++++- src/common/README.md | 4 + src/common/models/database.py | 87 ++++++-- src/common/models/health_checker.py | 40 ++-- src/common/prompts/assets/indexer.yaml | 10 +- src/common/schemas/__init__.py | 13 +- src/common/schemas/channel.py | 54 ++++- src/common/schemas/channel_dataset.py | 10 + src/common/schemas/enums.py | 15 ++ src/common/settings/database.py | 18 +- src/common/utils/cancel_dependency.py | 53 +++++ src/common/vectorstore/base.py | 16 ++ .../pg_vector_store/pg_vector_store.py | 117 +++++++++++ src/statgpt/README.md | 43 ++-- src/statgpt/application/app_factory.py | 6 + src/statgpt/application/application.py | 2 +- .../chains/candidates_selection_simple.py | 36 +++- .../data_query_artifacts_displayer.py | 10 +- .../query_builder/query/summarize_query.py | 5 +- src/statgpt/chains/supreme_agent.py | 9 +- .../default_prompts/assets/data_query.yaml | 9 +- src/statgpt/schemas/dial_app_configuration.py | 6 + src/statgpt/services/chat_facade.py | 29 +-- src/statgpt/services/hybrid_searcher.py | 70 ++++--- src/statgpt/settings/dial_app.py | 6 + tests/unit/test_alembic_version.py | 99 +++++++++ 36 files changed, 1052 insertions(+), 134 deletions(-) create mode 100644 src/admin_portal/admin.sh create mode 100644 src/admin_portal/alembic/versions/2025_11_15_0834-c64458e37902_reset_sequences_for_vector_store.py create mode 100644 src/admin_portal/alembic/versions/2025_11_17_1013-65c149c7db9e_add_cleanup_status_to_channel_dataset_.py create mode 100644 src/admin_portal/fix_statuses.py create mode 100644 src/common/utils/cancel_dependency.py create mode 100644 tests/unit/test_alembic_version.py diff --git a/src/admin_portal/admin.sh b/src/admin_portal/admin.sh new file mode 100644 index 0000000..396ed13 --- /dev/null +++ b/src/admin_portal/admin.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +echo "ADMIN_MODE = '$ADMIN_MODE'" + +case $ADMIN_MODE in + + APP) + uvicorn "admin_portal.app:app" --host "0.0.0.0" --port 8000 --lifespan on + ;; + + ALEMBIC_UPGRADE) + alembic upgrade head + ;; + + FIX_STATUSES) + python -m admin_portal.fix_statuses + ;; + + INIT) + alembic upgrade head + python -m admin_portal.fix_statuses + ;; + + *) + echo "Unknown ADMIN_MODE = '$ADMIN_MODE'. Possible values:" + echo " APP - start the admin portal application" + echo " ALEMBIC_UPGRADE - run alembic migrations to upgrade the database" + echo " FIX_STATUSES - fix inconsistent statuses in the database" + echo " INIT - run alembic migrations and fix inconsistent statuses" + exit 1 + ;; +esac diff --git a/src/admin_portal/alembic/versions/2025_11_15_0834-c64458e37902_reset_sequences_for_vector_store.py b/src/admin_portal/alembic/versions/2025_11_15_0834-c64458e37902_reset_sequences_for_vector_store.py new file mode 100644 index 0000000..0b56f68 --- /dev/null +++ b/src/admin_portal/alembic/versions/2025_11_15_0834-c64458e37902_reset_sequences_for_vector_store.py @@ -0,0 +1,82 @@ +"""Reset sequences for vector store + +Revision ID: c64458e37902 +Revises: d528d881ece8 +Create Date: 2025-11-15 08:34:48.017064 + +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = 'c64458e37902' +down_revision: str | None = 'd528d881ece8' +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Reset sequences for vector store tables to prevent duplicate key errors. + + This migration fixes sequence values for tables that had explicit IDs set during import, + which caused the sequences to fall behind the actual maximum ID values in the tables. + """ + conn = op.get_bind() + + # Get all tables in the collections schema matching the prefixes + result = conn.execute( + sa.text( + """ + SELECT table_name + FROM information_schema.tables + WHERE table_schema = 'collections' + AND table_type = 'BASE TABLE' + AND (table_name LIKE 'AvailableDimensions%' + OR table_name LIKE 'Indicators%' + OR table_name LIKE 'SpecialDimensions%') + ORDER BY table_name + """ + ) + ) + + tables = [row[0] for row in result] + + for table_name in tables: + # Get the sequence name for the id column + seq_result = conn.execute( + sa.text( + """ + SELECT pg_get_serial_sequence(:full_table_name, 'id') + """ + ), + {"full_table_name": f'collections."{table_name}"'}, + ) + + sequence_name = seq_result.scalar() + + if sequence_name: + # Reset the sequence to MAX(id) + 1 + # Use COALESCE to handle empty tables (set to 1 in that case) + conn.execute( + sa.text( + f""" + SELECT setval( + :sequence_name, + COALESCE((SELECT MAX(id) FROM collections."{table_name}"), 1) + ) + """ + ), + {"sequence_name": sequence_name}, + ) + + print(f"Reset sequence for table: {table_name} -> {sequence_name}") + else: + print(f"No sequence found for table: {table_name}") + + +def downgrade() -> None: + # This migration cannot be reversed as we don't know the original sequence values + pass diff --git a/src/admin_portal/alembic/versions/2025_11_17_1013-65c149c7db9e_add_cleanup_status_to_channel_dataset_.py b/src/admin_portal/alembic/versions/2025_11_17_1013-65c149c7db9e_add_cleanup_status_to_channel_dataset_.py new file mode 100644 index 0000000..9b204ce --- /dev/null +++ b/src/admin_portal/alembic/versions/2025_11_17_1013-65c149c7db9e_add_cleanup_status_to_channel_dataset_.py @@ -0,0 +1,35 @@ +"""Add cleanup status to channel dataset table + +Revision ID: 65c149c7db9e +Revises: d528d881ece8 +Create Date: 2025-11-14 16:33:52.089023 + +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision: str = '65c149c7db9e' +down_revision: str | None = 'c64458e37902' +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.add_column( + 'channel_datasets', + sa.Column( + 'clearing_status', + postgresql.ENUM(name='preprocessingstatusenum', create_type=False), + nullable=False, + server_default='NOT_STARTED', + ), + ) + + +def downgrade() -> None: + op.drop_column('channel_datasets', 'clearing_status') diff --git a/src/admin_portal/app.py b/src/admin_portal/app.py index 8e4864e..fdb5ce9 100644 --- a/src/admin_portal/app.py +++ b/src/admin_portal/app.py @@ -30,7 +30,7 @@ async def lifespan(app_: FastAPI): async with optional_msi_token_manager_context(): # Check resources' availability: - await DatabaseHealthChecker.check() + await DatabaseHealthChecker().check() # Start data preloading in the background asyncio.create_task(preload_data(allow_cached_datasets=False)) diff --git a/src/admin_portal/fix_statuses.py b/src/admin_portal/fix_statuses.py new file mode 100644 index 0000000..d1b26b0 --- /dev/null +++ b/src/admin_portal/fix_statuses.py @@ -0,0 +1,34 @@ +""" +Script to fix statuses for channel dataset versions after migrations. +Sets failed status for any channel dataset versions that were left in processing state. +""" + +import asyncio +import logging + +from admin_portal.services import AdminPortalDataSetService +from common.models import get_session_contex_manager + +_log = logging.getLogger(__name__) + + +async def fix_statuses(): + """Fix statuses from previous runs by setting failed status for stuck channel dataset versions.""" + async with get_session_contex_manager() as session: + service = AdminPortalDataSetService(session) + await service.set_failed_status_for_channel_dataset_version() + _log.info("Successfully fixed statuses for channel dataset versions") + + +async def main(): + try: + _log.info("Starting fix_statuses script...") + await fix_statuses() + _log.info("fix_statuses script completed successfully") + except Exception: + _log.exception("Error in fix_statuses script:") + raise + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/src/admin_portal/routers/channel.py b/src/admin_portal/routers/channel.py index c0d3d61..caf1525 100644 --- a/src/admin_portal/routers/channel.py +++ b/src/admin_portal/routers/channel.py @@ -10,12 +10,13 @@ import common.models as models import common.schemas as schemas from admin_portal.auth.auth_context import SystemUserAuthContext -from admin_portal.auth.user import User, require_jwt_auth +from admin_portal.auth.user import require_jwt_auth from admin_portal.services import AdminPortalChannelService as ChannelService from admin_portal.services import AdminPortalDataSetService as DataSetService from admin_portal.services import JobsService from admin_portal.settings.exim import JobsConfig from common.settings.dial import dial_settings +from common.utils.cancel_dependency import cancel_on_disconnect router = APIRouter( prefix="/channels", @@ -29,6 +30,7 @@ async def get_channels( limit: int = 100, offset: int = 0, session: AsyncSession = Depends(models.get_session), + _=Depends(cancel_on_disconnect), ) -> schemas.ListResponse[schemas.Channel]: """Returns a list of channels""" @@ -59,7 +61,7 @@ async def create_channel( async def get_channel_by_id( item_id: int, session: AsyncSession = Depends(models.get_session), - user: User = Depends(require_jwt_auth, use_cache=False), + _=Depends(cancel_on_disconnect), ) -> schemas.Channel: return await 
ChannelService(session).get_schema_by_id(item_id) @@ -119,6 +121,7 @@ async def get_jobs( limit: int = 100, offset: int = 0, session: AsyncSession = Depends(models.get_session), + _=Depends(cancel_on_disconnect), ) -> schemas.ListResponse[schemas.Job]: """Get a list of import/export jobs for the specified channel""" @@ -139,6 +142,7 @@ async def get_jobs( async def get_job_by_id( job_id: int, session: AsyncSession = Depends(models.get_session), + _=Depends(cancel_on_disconnect), ) -> schemas.Job: """Get information (e.g. status) about the import/export job""" @@ -149,6 +153,7 @@ async def get_job_by_id( async def download_job_result_by_id( job_id: int, session: AsyncSession = Depends(models.get_session), + _=Depends(cancel_on_disconnect), ) -> StreamingResponse: """Download the zip file with the exported channel data by job id. @@ -236,6 +241,7 @@ async def get_list_of_channel_datasets( limit: int = 100, offset: int = 0, session: AsyncSession = Depends(models.get_session), + _=Depends(cancel_on_disconnect), ) -> schemas.ListResponse[schemas.ChannelDatasetExpanded]: """Returns a list of datasets for the specified channel""" @@ -316,11 +322,29 @@ async def deduplicate_channel( ) +@router.get(path="/{channel_id}/index-status") +async def get_index_status( + channel_id: int, + scope: schemas.ChannelIndexStatusScope, + session: AsyncSession = Depends(models.get_session), + _=Depends(cancel_on_disconnect), +) -> schemas.ChannelIndexStatus: + """Get the index status for the specified channel.""" + + service = DataSetService(session) + return await service.check_index_status( + channel_id=channel_id, + auth_context=SystemUserAuthContext(), + scope=scope, + ) + + @router.get("/{channel_id}/datasets/{dataset_id}") async def get_channel_dataset( channel_id: int, dataset_id: int, session: AsyncSession = Depends(models.get_session), + _=Depends(cancel_on_disconnect), ) -> schemas.ChannelDatasetExpanded: return await DataSetService(session).get_channel_dataset_schema( channel_id=channel_id, dataset_id=dataset_id, auth_context=SystemUserAuthContext() @@ -386,6 +410,7 @@ async def get_channel_dataset_versions( limit: int = 100, offset: int = 0, session: AsyncSession = Depends(models.get_session), + _=Depends(cancel_on_disconnect), ) -> schemas.ListResponse[schemas.ChannelDatasetVersion]: """Returns a list of dataset versions for the specified channel and dataset""" @@ -414,6 +439,7 @@ async def is_channel_dataset_latest_version_up_to_date( channel_id: int, dataset_id: int, session: AsyncSession = Depends(models.get_session), + _=Depends(cancel_on_disconnect), ) -> schemas.ChangesBetweenVersionAndActualData: """Check if the latest completed version of the specified channel dataset is up to date.""" return await DataSetService(session).is_channel_dataset_latest_version_up_to_date( diff --git a/src/admin_portal/routers/data_source.py b/src/admin_portal/routers/data_source.py index 24bfcc8..8b78851 100644 --- a/src/admin_portal/routers/data_source.py +++ b/src/admin_portal/routers/data_source.py @@ -7,6 +7,7 @@ from admin_portal.services import AdminPortalDataSetService as DataSetService from admin_portal.services import AdminPortalDataSourceService as DataSourceService from common.services import DataSourceTypeService +from common.utils.cancel_dependency import cancel_on_disconnect router = APIRouter( prefix="/data-sources", tags=["data-sources"], dependencies=[Depends(require_jwt_auth)] @@ -18,6 +19,7 @@ async def get_data_source_types( limit: int = 100, offset: int = 0, session: AsyncSession = 
Depends(models.get_session), + _=Depends(cancel_on_disconnect), ) -> schemas.ListResponse[schemas.DataSourceType]: service = DataSourceTypeService(session) data_source_types = await service.get_data_source_types(limit=limit, offset=offset) @@ -36,6 +38,7 @@ async def get_data_source_types( async def get_schema_config_of_data_source_type( item_id: int, session: AsyncSession = Depends(models.get_session), + _=Depends(cancel_on_disconnect), ): """Returns the JSON schema for a specific data source type.""" @@ -48,6 +51,7 @@ async def get_data_sources( limit: int = 100, offset: int = 0, session: AsyncSession = Depends(models.get_session), + _=Depends(cancel_on_disconnect), ) -> schemas.ListResponse[schemas.DataSource]: """Returns a list of data sources""" @@ -78,6 +82,7 @@ async def create_data_source( async def get_data_source_by_id( item_id: int, session: AsyncSession = Depends(models.get_session), + _=Depends(cancel_on_disconnect), ) -> schemas.DataSource: return await DataSourceService(session).get_schema_by_id(item_id) @@ -86,6 +91,7 @@ async def get_data_source_by_id( async def get_available_datasets( item_id: int, session: AsyncSession = Depends(models.get_session), + _=Depends(cancel_on_disconnect), ) -> schemas.ListResponse[schemas.DataSetDescriptor]: """Returns a list of datasets that can be loaded from the data source""" diff --git a/src/admin_portal/routers/dataset.py b/src/admin_portal/routers/dataset.py index fad5253..5b22c39 100644 --- a/src/admin_portal/routers/dataset.py +++ b/src/admin_portal/routers/dataset.py @@ -6,6 +6,7 @@ from admin_portal.auth.auth_context import SystemUserAuthContext from admin_portal.auth.user import require_jwt_auth from admin_portal.services import AdminPortalDataSetService as DataSetService +from common.utils.cancel_dependency import cancel_on_disconnect router = APIRouter(prefix="/datasets", tags=["datasets"], dependencies=[Depends(require_jwt_auth)]) @@ -17,6 +18,7 @@ async def get_datasets( data_source_id: int | None = None, channel_id: int | None = None, session: AsyncSession = Depends(models.get_session), + _=Depends(cancel_on_disconnect), ) -> schemas.ListResponse[schemas.DataSet]: """Returns a list of added datasets""" @@ -56,6 +58,7 @@ async def register_dataset( async def get_dataset_by_id( item_id: int, session: AsyncSession = Depends(models.get_session), + _=Depends(cancel_on_disconnect), ) -> schemas.DataSet: return await DataSetService(session).get_schema_by_id( item_id, auth_context=SystemUserAuthContext(), allow_offline=True diff --git a/src/admin_portal/routers/glossary_of_terms.py b/src/admin_portal/routers/glossary_of_terms.py index 5a5ae4e..d55fcd2 100644 --- a/src/admin_portal/routers/glossary_of_terms.py +++ b/src/admin_portal/routers/glossary_of_terms.py @@ -5,6 +5,7 @@ import common.schemas as schemas from admin_portal.auth.user import require_jwt_auth from admin_portal.services import AdminPortalGlossaryOfTermsService as GlossaryOfTermsService +from common.utils.cancel_dependency import cancel_on_disconnect terms_router = APIRouter(prefix="/terms", tags=["glossary_of_terms"]) channel_terms_router = APIRouter( @@ -20,6 +21,7 @@ async def get_channel_glossary_terms( limit: int = 100, offset: int = 0, session: AsyncSession = Depends(models.get_session), + _=Depends(cancel_on_disconnect), ) -> schemas.ListResponse[schemas.GlossaryTerm]: """Returns a list of glossary terms for the channel.""" @@ -93,6 +95,7 @@ async def delete_terms_bulk( async def get_glossary_term_by_id( item_id: int, session: AsyncSession = 
Depends(models.get_session), + _=Depends(cancel_on_disconnect), ) -> schemas.GlossaryTerm: """Returns a glossary term by id.""" diff --git a/src/admin_portal/services/channel.py b/src/admin_portal/services/channel.py index ca2d96f..5dd2189 100644 --- a/src/admin_portal/services/channel.py +++ b/src/admin_portal/services/channel.py @@ -260,7 +260,6 @@ async def deduplicate_all_dimensions( """ channel = await self.get_model_by_id(channel_id) - # Start background task background_tasks.add_task( deduplicate_dimensions_in_background_task, channel_id=channel_id, diff --git a/src/admin_portal/services/dataset.py b/src/admin_portal/services/dataset.py index 371c9d0..8ddab5c 100644 --- a/src/admin_portal/services/dataset.py +++ b/src/admin_portal/services/dataset.py @@ -10,7 +10,7 @@ from fastapi import BackgroundTasks, HTTPException, status from pydantic import ValidationError from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.sql.expression import func +from sqlalchemy.sql.expression import func, text, update import common.models as models import common.schemas as schemas @@ -21,7 +21,7 @@ from common.data import base from common.data.base.dataset import DataSetConfigType from common.hybrid_indexer import Indexer -from common.schemas import HybridSearchConfig +from common.schemas import ChannelIndexStatusScope, HybridSearchConfig from common.schemas import PreprocessingStatusEnum as StatusEnum from common.services import ChannelDataSetSerializer, DataSetSerializer, DataSetService from common.services.dataset import LastCompletedVersions @@ -73,7 +73,7 @@ async def _export_vector_store_data( version_ids: set[int] = set() for versions in latest_completed_versions.values(): if versions.last_completed_version: - version_ids.add(versions.last_completed_version.id) + version_ids.add(versions.last_completed_version.version_data_id) _log.info(f"Exporting {len(version_ids)} version(s): {sorted(version_ids)}") @@ -917,6 +917,34 @@ async def clear_channel_dataset_versions_data( else: _log.info("No versions to clear data for.") + async def set_failed_status_for_channel_dataset_version(self) -> None: + """Sets the status of all not-completed channel dataset versions to FAILED.""" + + _log.info("Setting FAILED status for all non-completed channel dataset versions...") + + query = ( + update(models.ChannelDatasetVersion) + .where( + models.ChannelDatasetVersion.preprocessing_status.notin_( + StatusEnum.final_statuses() + ), + models.ChannelDatasetVersion.updated_at < text("NOW() - INTERVAL '12 hours'"), + ) + .values( + preprocessing_status=StatusEnum.FAILED, + reason_for_failure=func.coalesce( + models.ChannelDatasetVersion.reason_for_failure, + "The version had invalid status.", + ), + updated_at=func.now(), + ) + ) + + result = await self._session.execute(query) + await self._session.commit() + + _log.info(f"Updated {result.rowcount} channel dataset version(s) to FAILED status") + async def reload_all_indicators( self, background_tasks: BackgroundTasks, @@ -1451,6 +1479,163 @@ async def clear_channel_dataset_data_in_background( _log.exception(f"Failed to clear data after reindexing {channel_dataset}") await self._update_channel_dataset_status(channel_dataset, new_status=StatusEnum.FAILED) + async def _get_deduplication_status_by_versions( + self, + available_dims_store: VectorStore, + special_dims_store: VectorStore, + indicator_dims_store: VectorStore, + versions: set[int], + ) -> schemas.DeduplicationStatus: + available_has_duplicates, available_count = ( + await 
available_dims_store.has_duplicates_in_versions(version_ids=versions) + ) + special_has_duplicates, special_count = await special_dims_store.has_duplicates_in_versions( + version_ids=versions + ) + _, indicator_count = await indicator_dims_store.has_duplicates_in_versions( + version_ids=versions + ) + + # we only consider available and special dimensions for deduplication requirement + deduplication_required = available_has_duplicates or special_has_duplicates + total_duplicates = available_count + special_count + indicator_count + + return schemas.DeduplicationStatus( + deduplication_required=deduplication_required, + total_duplicate_count=total_duplicates, + available_dimensions_duplicate_count=available_count, + special_dimensions_duplicate_count=special_count, + indicator_dimensions_duplicate_count=indicator_count, + ) + + async def _get_full_deduplication_status( + self, + available_dims_store: VectorStore, + special_dims_store: VectorStore, + indicator_dims_store: VectorStore, + ) -> schemas.DeduplicationStatus: + available_has_duplicates, available_count = await available_dims_store.has_duplicates() + special_has_duplicates, special_count = await special_dims_store.has_duplicates() + _, indicator_count = await indicator_dims_store.has_duplicates() + + # we only consider available and special dimensions for deduplication requirement + deduplication_required = available_has_duplicates or special_has_duplicates + total_duplicates = available_count + special_count + indicator_count + + return schemas.DeduplicationStatus( + deduplication_required=deduplication_required, + total_duplicate_count=total_duplicates, + available_dimensions_duplicate_count=available_count, + special_dimensions_duplicate_count=special_count, + indicator_dimensions_duplicate_count=indicator_count, + ) + + async def _check_latest_versions_status( + self, + channel: models.Channel, + available_dims_store: VectorStore, + special_dims_store: VectorStore, + indicator_dims_store: VectorStore, + ) -> schemas.ChannelIndexStatus: + latest_successful_versions = await self.get_latest_successful_dataset_versions_for_channel( + channel_id=channel.id + ) + versions = { + v.last_completed_version.version_data_id + for v in latest_successful_versions.values() + if v.last_completed_version is not None + } + + deduplication_status = await self._get_deduplication_status_by_versions( + available_dims_store, + special_dims_store, + indicator_dims_store, + versions, + ) + sizes = schemas.VectorStoreSizes( + available_dimensions_size=await available_dims_store.get_size(version_ids=versions), + special_dimensions_size=await special_dims_store.get_size(version_ids=versions), + indicator_dimensions_size=await indicator_dims_store.get_size(version_ids=versions), + ) + + vector_store_status = schemas.VectorStoreStatus( + deduplication=deduplication_status, + sizes=sizes, + ) + return schemas.ChannelIndexStatus( + vector_store=vector_store_status, + scope=ChannelIndexStatusScope.LATEST_COMPLETED_VERSIONS, + ) + + async def _check_full_index_status( + self, + channel: models.Channel, + available_dims_store: VectorStore, + special_dims_store: VectorStore, + indicator_dims_store: VectorStore, + ) -> schemas.ChannelIndexStatus: + deduplication_status = await self._get_full_deduplication_status( + available_dims_store, + special_dims_store, + indicator_dims_store, + ) + sizes = schemas.VectorStoreSizes( + available_dimensions_size=await available_dims_store.get_total_size(), + special_dimensions_size=await special_dims_store.get_total_size(), + 
indicator_dimensions_size=await indicator_dims_store.get_total_size(), + ) + + vector_store_status = schemas.VectorStoreStatus( + deduplication=deduplication_status, + sizes=sizes, + ) + return schemas.ChannelIndexStatus( + vector_store=vector_store_status, scope=ChannelIndexStatusScope.FULL + ) + + async def check_index_status( + self, + channel_id: int, + auth_context: AuthContext, + scope: schemas.ChannelIndexStatusScope, + ) -> schemas.ChannelIndexStatus: + """Checks index status for channel""" + channel = await ChannelService(self._session).get_model_by_id(channel_id) + vector_store_factory = VectorStoreFactory(session=self._session) + + available_dims_store = await vector_store_factory.get_vector_store( + collection_name=channel.available_dimensions_table_name, + auth_context=auth_context, + embedding_model_name=channel.llm_model, + ) + special_dims_store = await vector_store_factory.get_vector_store( + collection_name=channel.special_dimensions_table_name, + auth_context=auth_context, + embedding_model_name=channel.llm_model, + ) + indicator_dims_store = await vector_store_factory.get_vector_store( + collection_name=channel.indicator_table_name, + auth_context=auth_context, + embedding_model_name=channel.llm_model, + ) + + if scope == schemas.ChannelIndexStatusScope.FULL: + return await self._check_full_index_status( + channel=channel, + available_dims_store=available_dims_store, + special_dims_store=special_dims_store, + indicator_dims_store=indicator_dims_store, + ) + elif scope == schemas.ChannelIndexStatusScope.LATEST_COMPLETED_VERSIONS: + return await self._check_latest_versions_status( + channel=channel, + available_dims_store=available_dims_store, + special_dims_store=special_dims_store, + indicator_dims_store=indicator_dims_store, + ) + else: + raise ValueError(f"Unknown scope: {scope}") + @background_task async def reload_indicators_in_background_task( diff --git a/src/common/README.md b/src/common/README.md index dadebf0..37cdfb0 100644 --- a/src/common/README.md +++ b/src/common/README.md @@ -20,6 +20,10 @@ application-specific environment variables. | PGVECTOR_USE_MSI | No | If enabled, the application uses MSI to authenticate to the database. Otherwise, it uses the password from the environment variable. | `true`, `false` | `false` | | PGVECTOR_MSI_SCOPE | No | The scope used to obtain the MSI token. | | `"https://ossrdbms-aad.database.windows.net/.default"` | | PGVECTOR_MSI_TOKEN_REFRESH_TIMEOUT | No | Delay between MSI token refreshes. 
| | `23 * 3600` | +| PGVECTOR_CONNECTION_MAX_RETRIES | No | Maximum number of retries for the Postgres connection | | `5` | +| PGVECTOR_CONNECTION_RETRY_INTERVAL | No | Initial retry interval in seconds (uses exponential backoff) | | `10.0` | +| PGVECTOR_ALEMBIC_MAX_RETRIES | No | Maximum number of alembic version check retry attempts | | `5` | +| PGVECTOR_ALEMBIC_RETRY_INTERVAL | No | Initial retry interval in seconds for alembic version check (uses exponential backoff) | | `10.0` | | ELASTIC_CONNECTION_STRING | No | Connection string for the ElasticSearch instance | | | | ELASTIC_AUTH_USER | No | User for the ElasticSearch instance | | | | ELASTIC_AUTH_PASSWORD | No | Password for the ElasticSearch instance | | | diff --git a/src/common/models/database.py b/src/common/models/database.py index 14663d9..da144a6 100644 --- a/src/common/models/database.py +++ b/src/common/models/database.py @@ -3,7 +3,7 @@ from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -from sqlalchemy import event +from sqlalchemy import event, text from sqlalchemy.ext.asyncio import ( AsyncAttrs, AsyncEngine, @@ -27,6 +27,11 @@ class Base(AsyncAttrs, DeclarativeBase): metadata = Base.metadata # for Alembic migrations + +class DatabaseConnectionError(RuntimeError): + pass + + # The MSI token manager is used to store and update the MSI token for Postgres in the background. _MSI_TOKEN_MANAGER: ValueUpdater[msi.MsiTokenResponse] | None = None @@ -96,32 +101,82 @@ def __init__(self, engine_config: dict): self._engine_config = engine_config self._postgres_settings = PostgresSettings() + @staticmethod + async def _test_connection(engine: AsyncEngine) -> bool: + """Test if database connection is working""" + try: + async with engine.connect() as conn: + await conn.execute(text("SELECT 1")) + return True + except Exception as e: + _log.debug(f"Connection test failed: {e}") + return False + + async def _create_engine_with_retry(self, engine_factory, description: str) -> AsyncEngine: + """Create engine with retry logic""" + max_retries = self._postgres_settings.connection_max_retries + retry_interval = self._postgres_settings.connection_retry_interval + + for attempt in range(max_retries): + try: + _log.info( + f"Attempting to create {description} engine (attempt {attempt + 1}/{max_retries})" + ) + engine = await engine_factory() + + # Test the connection + _log.debug("Testing database connection...") + if await self._test_connection(engine): + _log.info(f"{description} engine created and connection verified") + return engine + else: + _log.warning("Connection test failed, closing engine") + await engine.dispose() + + except Exception as e: + _log.warning( + f"Failed to create {description} engine (attempt {attempt + 1}/{max_retries}): {e}" + ) + + if attempt < max_retries - 1: + # Exponential backoff + sleep_time = retry_interval * (2**attempt) + _log.info(f"Retrying in {sleep_time:.2f} seconds...") + await asyncio.sleep(sleep_time) + + raise DatabaseConnectionError(f"Failed to connect to database after {max_retries} attempts") + async def _create_default_engine(self) -> AsyncEngine: _log.debug(f"Creating default engine with config: {self._engine_config}") - engine = create_async_engine( - self._postgres_settings.create_default_uri(), **self._engine_config - ) - _log.debug(f"Default engine created: {engine}") - return engine + + async def factory(): + return create_async_engine( + self._postgres_settings.create_default_uri(), **self._engine_config + ) + + return await 
self._create_engine_with_retry(factory, "default") async def _create_msi_engine(self) -> AsyncEngine: _log.debug(f"Creating MSI engine with config: {self._engine_config}") - engine = create_async_engine( - self._postgres_settings.create_msi_uri(), **self._engine_config - ) msi_token_manager = _get_msi_token_manager() if msi_token_manager is None or not msi_token_manager.is_initialized: raise RuntimeError("Cannot create engine before MSI token manager is initialized") - # event not supported for async engine - provide token must be synchronous - @event.listens_for(engine.sync_engine, "do_connect") - def provide_token(dialect, conn_rec, cargs, cparams): - _log.debug("Providing MSI token for database connection") - cparams["password"] = msi_token_manager.value.access_token + async def factory(): + engine = create_async_engine( + self._postgres_settings.create_msi_uri(), **self._engine_config + ) - _log.debug(f"MSI engine created: {engine}") - return engine + # event not supported for async engine - provide token must be synchronous + @event.listens_for(engine.sync_engine, "do_connect") + def provide_token(dialect, conn_rec, cargs, cparams): + _log.debug("Providing MSI token for database connection") + cparams["password"] = msi_token_manager.value.access_token + + return engine + + return await self._create_engine_with_retry(factory, "MSI") async def create_engine(self) -> AsyncEngine: _log.debug(f"Creating engine (USE_MSI={self._postgres_settings.use_msi})") diff --git a/src/common/models/health_checker.py b/src/common/models/health_checker.py index aa35d7e..a95c183 100644 --- a/src/common/models/health_checker.py +++ b/src/common/models/health_checker.py @@ -1,15 +1,14 @@ +import asyncio + from sqlalchemy import text from common.config import Versions from common.config import multiline_logger as logger +from common.settings.database import PostgresSettings from .database import get_session_contex_manager -class DatabaseConnectionError(RuntimeError): - pass - - class AlembicTableNotFoundError(RuntimeError): pass @@ -19,16 +18,9 @@ class WrongAlembicVersionError(RuntimeError): class DatabaseHealthChecker: - @classmethod - async def check_connection(cls): - try: - async with get_session_contex_manager() as session: - await session.execute(text("SELECT 1;")) - except Exception as e: - msg = "Connection to database failed. Check if the database is running and the connection string is correct." 
-            logger.error(msg)
-            logger.error(e)
-            raise DatabaseConnectionError(msg)
+
+    def __init__(self):
+        self._postgres_settings = PostgresSettings()
 
     @classmethod
     async def check_alembic_version(cls) -> None:
@@ -51,7 +43,20 @@ async def check_alembic_version(cls) -> None:
             )
             raise WrongAlembicVersionError("Alembic version is not correct")
 
-    @classmethod
-    async def check(cls):
-        await cls.check_connection()
-        await cls.check_alembic_version()
+    async def check(self) -> None:
+        for attempt in range(self._postgres_settings.alembic_max_retries):
+            try:
+                await self.check_alembic_version()
+                return
+            except (AlembicTableNotFoundError, WrongAlembicVersionError) as e:
+                logger.error(
+                    f"Failed to check alembic version on attempt {attempt + 1}"
+                    f"/{self._postgres_settings.alembic_max_retries}: {str(e)}"
+                )
+
+                if attempt < self._postgres_settings.alembic_max_retries - 1:
+                    sleep_time = self._postgres_settings.alembic_retry_interval * (2**attempt)
+                    logger.info(f"Retrying in {sleep_time:.2f} seconds...")
+                    await asyncio.sleep(sleep_time)
+                else:
+                    raise
diff --git a/src/common/prompts/assets/indexer.yaml b/src/common/prompts/assets/indexer.yaml
index 7f83668..5fb56f7 100644
--- a/src/common/prompts/assets/indexer.yaml
+++ b/src/common/prompts/assets/indexer.yaml
@@ -80,12 +80,12 @@ search_normalize:
     {input}
 
     Examples:
-    1. if case is described as JSON {{"Input": "Provide the total unemployment data as a percentage of the labor force for Germany for the last 15 years", "Named Entities": ["last 15 years (Time frequency)", "Germany (Country/Reference area)"], "Time Period": "from 2009-08-08 to 2024-08-08"] then expected cleaned input is "total unemployment data as a percent of the labor force"
-    2. if case is described as JSON {{"Input": "Get the Gross Domestic Product per capita growth rate for the United Kingdom since Brexit", "Named Entities": ["United Kingdom (Country/Reference area)"], "Time Period": "from 2020-01-31"] then expected cleaned input is "Get Gross Domestic Product per capita growth rate"
-    3. if case is described as JSON {{"Input": "I need to calculate the misery index for Austria. What indicators would you recommend?", "Named Entities": ["Austria (Country/Reference area)"]}} then expected cleaned input is "calculate the misery index"
+    1. if case is described as JSON {{"Input": "Provide the total unemployment data as a percentage of the labor force for Germany for the last 15 years", "Named Entities": ["last 15 years (Time frequency) (REMOVE)", "Germany (Country/Reference area) (REMOVE)"], "Time Period": "from 2009-08-08 to 2024-08-08"}} then expected cleaned input is "total unemployment data as a percent of the labor force"
+    2. if case is described as JSON {{"Input": "Get the Gross Domestic Product per capita growth rate for the United Kingdom since Brexit", "Named Entities": ["United Kingdom (Country/Reference area) (REMOVE)"], "Time Period": "from 2020-01-31 (REMOVE)"}} then expected cleaned input is "Gross Domestic Product per capita growth rate"
+    3. if case is described as JSON {{"Input": "I need to calculate the misery index for Austria. What indicators would you recommend?", "Named Entities": ["Austria (Country/Reference area) (REMOVE)"]}} then expected cleaned input is "misery index"
     4. if case is described as JSON {{"Input": "Query the unemployment rate and inflation rate to calculate the Misery Index"]}} then expected cleaned input is "unemployment rate and inflation rate to calculate the Misery Index"
-    5. 
if case is described as JSON {{"Input": "I would like to evaluate net trade flow of manufacturing items between the United States and Eurozone (as a group). What indicators would you suggest?", "Named Entities": ["United States (Country/Reference area)", Euro Area (Counterpart area)]}} then expected cleaned input is "evaluate net trade flow of manufacturing items" - + 5. if case is described as JSON {{"Input": "I would like to evaluate net trade flow of manufacturing items between the United States and Eurozone (as a group). What indicators would you suggest?", "Named Entities": ["United States (Country/Reference area) (REMOVE)", "Euro Area (Counterpart area) (REMOVE)""]}} then expected cleaned input is "net trade flow of manufacturing items" + 6. if case is described as JSON {{"Input": "Number of stays of Poland citizens in hotels?", "Named Entities": ["Poland (Country/Reference area) (REMOVE)", "Poland citizen (Citizenship of a certain country (citizen of Finland/Uzbek/foreigner from Greece/etc.) (DO NOT REMOVE)"]] then expected cleaned input is "Number of stays of Poland citizens in hotels" separate_subjects: systemPrompt: |- You are an expert in statistical indicators. diff --git a/src/common/schemas/__init__.py b/src/common/schemas/__init__.py index b39a4aa..a7f4f2b 100644 --- a/src/common/schemas/__init__.py +++ b/src/common/schemas/__init__.py @@ -1,5 +1,15 @@ from .base import ListResponse -from .channel import Channel, ChannelBase, ChannelConfig, ChannelUpdate, SupremeAgentConfig +from .channel import ( + Channel, + ChannelBase, + ChannelConfig, + ChannelIndexStatus, + ChannelUpdate, + DeduplicationStatus, + SupremeAgentConfig, + VectorStoreSizes, + VectorStoreStatus, +) from .channel_dataset import ( ChangesBetweenVersionAndActualData, ChannelDatasetBase, @@ -12,6 +22,7 @@ from .data_source import DataSource, DataSourceBase, DataSourceType, DataSourceUpdate from .dataset import DataSet, DataSetBase, DataSetDescriptor, DataSetUpdate from .enums import ( + ChannelIndexStatusScope, DecoderOfLatestEnum, IndexerVersion, IndicatorSelectionVersion, diff --git a/src/common/schemas/channel.py b/src/common/schemas/channel.py index 82f652b..8714c38 100644 --- a/src/common/schemas/channel.py +++ b/src/common/schemas/channel.py @@ -1,7 +1,7 @@ from pydantic import BaseModel, Field from .base import BaseYamlModel, DbDefaultBase -from .enums import LocaleEnum +from .enums import ChannelIndexStatusScope, LocaleEnum from .model_config import LLMModelConfig from .onboarding import OnboardingConfig from .tools import ( @@ -198,3 +198,55 @@ class ChannelUpdate(BaseModel): class Channel(DbDefaultBase, ChannelBase): pass + + +class DeduplicationStatus(BaseModel): + """Status information about channel deduplication requirements.""" + + deduplication_required: bool = Field( + description="Whether deduplication is required for the channel" + ) + total_duplicate_count: int = Field( + description="Total number of duplicate documents across all dimension stores" + ) + available_dimensions_duplicate_count: int = Field( + description="Number of duplicate documents in available dimensions store" + ) + special_dimensions_duplicate_count: int = Field( + description="Number of duplicate documents in special dimensions store" + ) + indicator_dimensions_duplicate_count: int = Field( + description="Number of duplicate documents in indicator dimensions store" + ) + + +class VectorStoreSizes(BaseModel): + """Size information about channel vector store index""" + + available_dimensions_size: int = Field( + 
description="Number of documents in available dimensions store" + ) + special_dimensions_size: int = Field( + description="Number of documents in special dimensions store" + ) + indicator_dimensions_size: int = Field( + description="Number of documents in indicator dimensions store" + ) + + +class VectorStoreStatus(BaseModel): + """Status information about channel vector store index""" + + deduplication: DeduplicationStatus = Field( + description="Deduplication status information for the channel" + ) + sizes: VectorStoreSizes = Field(description="Size information for the channel vector store") + + +class ChannelIndexStatus(BaseModel): + """Status information about channel index""" + + scope: ChannelIndexStatusScope = Field(description="The scope of the channel index status") + vector_store: VectorStoreStatus = Field( + description="Vector store status information for the channel" + ) diff --git a/src/common/schemas/channel_dataset.py b/src/common/schemas/channel_dataset.py index 0e9d1a5..ed98811 100644 --- a/src/common/schemas/channel_dataset.py +++ b/src/common/schemas/channel_dataset.py @@ -34,6 +34,16 @@ class ChannelDatasetVersion(DbDefaultBase): non_indicator_dimensions_hash: str | None special_dimensions_hash: str | None + @property + def all_hashes_dict(self) -> dict[str, str | None]: + parts = { + 'structure': self.structure_hash, + 'indicator_dims': self.indicator_dimensions_hash, + 'non_indicator_dims': self.non_indicator_dimensions_hash, + 'special_dims': self.special_dimensions_hash, + } + return parts + @property def version_data_id(self) -> int: """The ID of the version which contains the actual data for this version.""" diff --git a/src/common/schemas/enums.py b/src/common/schemas/enums.py index 2ec5f57..629cc33 100644 --- a/src/common/schemas/enums.py +++ b/src/common/schemas/enums.py @@ -73,10 +73,20 @@ class AvailableDatasetsVersion(StrEnum): full = "full" +_LANGUAGE_NAMES = { + "en": "English", + "uk": "Ukrainian", +} + + class LocaleEnum(StrEnum): EN = "en" UK = "uk" + def get_language_name(self) -> str: + """Return the full language name for this locale.""" + return _LANGUAGE_NAMES[self] + class DataRequestStatus(StrEnum): SUCCESS = "SUCCESS" @@ -89,3 +99,8 @@ class DataParsingStatus(StrEnum): SUCCESS = "SUCCESS" FAILED = "FAILED" PARTIALLY_FAILED = "PARTIALLY_FAILED" + + +class ChannelIndexStatusScope(StrEnum): + FULL = "full" + LATEST_COMPLETED_VERSIONS = "latest_completed_versions" diff --git a/src/common/settings/database.py b/src/common/settings/database.py index 15a8cb1..047f8db 100644 --- a/src/common/settings/database.py +++ b/src/common/settings/database.py @@ -1,4 +1,4 @@ -from pydantic import Field, model_validator +from pydantic import Field, PositiveFloat, PositiveInt, model_validator from pydantic_settings import BaseSettings, SettingsConfigDict @@ -34,6 +34,22 @@ class PostgresSettings(BaseSettings): batch_size: int = Field(default=1000, description="Batch size for vector store operations") + connection_max_retries: PositiveInt = Field( + default=5, description="Maximum number of connection retry attempts" + ) + + connection_retry_interval: PositiveFloat = Field( + default=10.0, description="Initial retry interval in seconds (uses exponential backoff)" + ) + + alembic_max_retries: PositiveInt = Field( + default=3, description="Maximum number of alembic version check retry attempts" + ) + alembic_retry_interval: PositiveFloat = Field( + default=10.0, + description="Initial alembic retry interval in seconds (uses exponential backoff)", + ) + 
@model_validator(mode="after") def validate_config(self): """Validate that required fields are present based on authentication type""" diff --git a/src/common/utils/cancel_dependency.py b/src/common/utils/cancel_dependency.py new file mode 100644 index 0000000..69387ac --- /dev/null +++ b/src/common/utils/cancel_dependency.py @@ -0,0 +1,53 @@ +import asyncio +import logging +from collections.abc import AsyncGenerator + +from fastapi import HTTPException, Request + +_log = logging.getLogger(__name__) + + +async def cancel_on_disconnect(request: Request) -> AsyncGenerator[None, None]: + """A function developed to be a FastAPI dependency that cancels processing of a user request if the client disconnects. + + Use it to save resources when the client closes the connection on a READ request. + DO NOT use for other operations (e.g. CREATE, UPDATE, DELETE) as this may cause an unexpected app state. + + Notes: + 1. This dependency properly works with other dependencies. For example, a database dependency + will still close the session if the request is canceled. + 2. Context managers also work as expected. + 3. Coroutines scheduled by `asyncio.gather` or `asyncio.TaskGroup()` will be closed (canceled) immediately, + without waiting for completion. + 4. WARNING: Tasks in separate threads started by `asyncio.to_thread` will not be canceled and will be finished completely. + 5. WARNING: This dependency does not work properly with `fastapi.BackgroundTasks`. + A normal (non-canceled) request will fail after sending a response to the user, + and background tasks will be canceled almost immediately. + """ + + handler_task = asyncio.current_task() + if handler_task is None: + yield + return + + watch_task = asyncio.create_task(_watcher(request, handler_task)) + try: + yield + except asyncio.CancelledError: + # Optional: map to 499 like nginx "Client Closed Request" + _log.warning("Request was cancelled due to client disconnect") + raise HTTPException(status_code=499, detail="Client closed request") from None + finally: + watch_task.cancel() + + +async def _watcher(request: Request, handler_task: asyncio.Task, interval: float = 2.0) -> None: + _log.info(f"Watching {request.url}") + + while True: + if await request.is_disconnected(): + _log.info(f"Cancelling handler task for {request.url}") + handler_task.cancel() + else: + _log.debug(f"Request {request.url} still connected") + await asyncio.sleep(interval) diff --git a/src/common/vectorstore/base.py b/src/common/vectorstore/base.py index 1ff579f..b048f33 100644 --- a/src/common/vectorstore/base.py +++ b/src/common/vectorstore/base.py @@ -50,6 +50,22 @@ async def search_with_similarity_score( async def deduplicate_by_document_content(self) -> None: """Removes and remaps duplicate documents based on document content.""" + @abstractmethod + async def has_duplicates(self) -> tuple[bool, int]: + """Checks if there are duplicate documents based on document content.""" + + @abstractmethod + async def has_duplicates_in_versions(self, version_ids: set[int]) -> tuple[bool, int]: + """Checks if there are duplicate documents based on document content in the specified version IDs.""" + + @abstractmethod + async def get_total_size(self) -> int: + """Returns the total number of documents in the vector store.""" + + @abstractmethod + async def get_size(self, version_ids: set[int]) -> int: + """Returns the number of documents in the vector store for the specified version IDs.""" + @abstractmethod async def export_to_folder(self, folder_path: str, version_ids: set[int]) -> 
None: """Exports the vector store data to the specified folder.""" diff --git a/src/common/vectorstore/pg_vector_store/pg_vector_store.py b/src/common/vectorstore/pg_vector_store/pg_vector_store.py index b9ee1dc..8036841 100644 --- a/src/common/vectorstore/pg_vector_store/pg_vector_store.py +++ b/src/common/vectorstore/pg_vector_store/pg_vector_store.py @@ -21,6 +21,7 @@ from sqlalchemy.schema import CreateTable from sqlalchemy.sql import text from sqlalchemy.sql.elements import TextClause +from sqlalchemy.sql.expression import func from common.services.base import DbServiceBase from common.settings.database import PostgresSettings @@ -451,6 +452,122 @@ def _build_delete_orphaned_query(self, document_table: str, metadata_table: str) """ ) + @staticmethod + async def _get_duplicates( + session: AsyncSession, query: TextClause, params=None + ) -> tuple[bool, int]: + result = await session.execute(statement=query, params=params) + row = result.fetchone() + + if row: + duplicate_count = int(row.duplicate_count) + has_duplicates = duplicate_count > 0 + return (has_duplicates, duplicate_count) + + return (False, 0) + + async def has_duplicates(self) -> tuple[bool, int]: + """Checks if there are duplicate documents based on document content.""" + document_model = await self._get_document_model() + metadata_model = await self._get_metadata_model() + + async with self._lock_session() as session: + for model in [document_model, metadata_model]: + if not await self._check_if_table_exists(session, model.__tablename__): + _log.info(f"Table {model.__tablename__!r} does not exist.") + return False, 0 + + duplicate_check_query = text( + f""" + WITH duplicate_groups AS ( + SELECT d.document, COUNT(DISTINCT d.id) as doc_count + FROM collections."{document_model.__tablename__}" d + INNER JOIN collections."{metadata_model.__tablename__}" m ON d.id = m.document_id + GROUP BY d.document + HAVING COUNT(DISTINCT d.id) > 1 + ) + SELECT COUNT(*) as group_count, + COALESCE(SUM(doc_count - 1), 0) as duplicate_count + FROM duplicate_groups + """ + ) + + return await self._get_duplicates(session, duplicate_check_query, params=None) + + async def has_duplicates_in_versions(self, version_ids: set[int]) -> tuple[bool, int]: + """Checks if there are duplicate documents based on document content. 
+ + Returns: + tuple[bool, int]: (has_duplicates, duplicate_count) + """ + document_model = await self._get_document_model() + metadata_model = await self._get_metadata_model() + + async with self._lock_session() as session: + for model in [document_model, metadata_model]: + if not await self._check_if_table_exists(session, model.__tablename__): + _log.info(f"Table {model.__tablename__!r} does not exist.") + return False, 0 + + # Build WHERE clause for version_ids filter + where_clause = "" + params = {} + if version_ids: + where_clause = "WHERE m.version_id = ANY(:version_ids)" + params = {"version_ids": list(version_ids)} + + duplicate_check_query = text( + f""" + WITH duplicate_groups AS ( + SELECT d.document, COUNT(DISTINCT d.id) as doc_count + FROM collections."{document_model.__tablename__}" d + INNER JOIN collections."{metadata_model.__tablename__}" m ON d.id = m.document_id + {where_clause} + GROUP BY d.document + HAVING COUNT(DISTINCT d.id) > 1 + ) + SELECT COUNT(*) as group_count, + COALESCE(SUM(doc_count - 1), 0) as duplicate_count + FROM duplicate_groups + """ + ) + + return await self._get_duplicates(session, duplicate_check_query, params=params) + + async def get_total_size(self) -> int: + """Returns the total number of documents in the vector store.""" + metadata_model = await self._get_metadata_model() + + async with self._lock_session() as session: + if not await self._check_if_table_exists(session, metadata_model.__tablename__): + _log.info(f"Table {metadata_model.__tablename__!r} does not exist.") + return 0 + + size_query = select(func.count(func.distinct(metadata_model.document_id))).select_from( + metadata_model + ) + result = await session.execute(size_query) + size = result.scalar_one() + return size + + async def get_size(self, version_ids: set[int]) -> int: + """Returns the number of documents in the vector store.""" + metadata_model = await self._get_metadata_model() + + async with self._lock_session() as session: + if not await self._check_if_table_exists(session, metadata_model.__tablename__): + _log.info(f"Table {metadata_model.__tablename__!r} does not exist.") + return 0 + + size_query = ( + select(func.count(func.distinct(metadata_model.document_id))) + .select_from(metadata_model) + .where(metadata_model.version_id.in_(version_ids)) + ) + result = await session.execute(size_query) + size = result.scalar_one() + return size + async def deduplicate_by_document_content(self) -> None: """Removes and remaps duplicate documents based on `document` field content. diff --git a/src/statgpt/README.md b/src/statgpt/README.md index 76c50cf..7f52ca4 100644 --- a/src/statgpt/README.md +++ b/src/statgpt/README.md @@ -5,24 +5,25 @@ Below are the environment variables specific to the Chat Application. Other required variables can be found in the [common README file](../common/README.md). -| Variable | Required | Description | Available Values | Default values | -|---------------------------|:--------:|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------|--------------------------------------------| -| DIAL_APP_NAME | No | Name of the DIAL app | | `talk-to-your-data` | -| DIAL_AUTH_MODE | No | Define the authentication mode for the LLM models used by the application. `USER_TOKEN` means that the application requests models with a user token. 
`API_KEY` means that all requests to LLM are sent with a single application key. | `USER_TOKEN`, `API_KEY` | `USER_TOKEN` | -| DIAL_LOG_LEVEL | No | Log level for the DIAL app | `DEBUG`, `INFO`, `WARN`, `ERROR`, `CRITICAL` | `INFO` | -| DIAL_SHOW_STAGE_SECONDS | No | Whether to show the stage seconds in the DIAL app | `true`, `false` | `false` | -| DIAL_SHOW_DEBUG_STAGES | No | Whether to show the debug stages in the DIAL app | `true`, `false` | `false` | -| ENABLE_DEV_COMMANDS | No | Whether to enable developer commands in chat. Some commands, such as `!show_debug_stages`, are allowed even if this environment variable is set to False. Must be disabled in the production. | `true`, `false` | `false` | -| ENABLE_DIRECT_TOOL_CALLS | No | Whether to allow the user to call tools directly, bypassing the `out of scope` check and `supreme agent` orchestration. | `true`, `false` | `false` | -| OFFICIAL_DATASET_LABEL | No | A label for official datasets to mark them for the user in the chat | | `⭐` | -| SKIP_OUT_OF_SCOPE_CHECK | No | Whether to skip the out of scope check for the chat | `true`, `false` | `false` | -| CMD_OUT_OF_SCOPE_ONLY | No | Whether to stop processing user query right after out-of-scope check | `true`, `false` | `false` | -| CMD_RAG_PREFILTER_ONLY | No | Whether to use pre-filters only for the RAG | `true`, `false` | `false` | -| EVAL_DIAL_ROLE | No | Allows user requests without a JWT token. A user request with this role has system access to the data. Added for evaluation in development environments and must be disabled (not set) in production. | | | -| DIAL_RAG_DEPLOYMENT_ID | No | Deployment ID for the RAG | | `dial-rag-pgvector` | -| DIAL_RAG_PGVECTOR_URL | No | URL for the RAG with pgvector, only for local development | | | -| DIAL_RAG_PGVECTOR_API_KEY | No | API key for the RAG with pgvector, only for local development | | | -| STAGE_INDICATORS | No | Stage name for indicators selection | | `Selecting Indicators` | -| STAGE_DATASET_QUERIES | No | Stage name for selection of non-Indicator Dimension Values " | | `Selecting non-Indicator Dimension Values` | -| STAGE_QUERIES_EXECUTION | No | Stage name for queries execution | | `Executing Data Queries` | -| TTYD_TOOL_PLAIN_CONTENT_* | No | Environment variables for the Plain Content tool to replace in the files content. Replace `*` with the variable name. | | | +| Variable | Required | Description | Available Values | Default values | +|-----------------------------|:--------:|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------|--------------------------------------------| +| DIAL_APP_NAME | No | Name of the DIAL app | | `talk-to-your-data` | +| DIAL_AUTH_MODE | No | Define the authentication mode for the LLM models used by the application. `USER_TOKEN` means that the application requests models with a user token. `API_KEY` means that all requests to LLM are sent with a single application key. 
| `USER_TOKEN`, `API_KEY` | `USER_TOKEN` | +| DIAL_LOG_LEVEL | No | Log level for the DIAL app | `DEBUG`, `INFO`, `WARN`, `ERROR`, `CRITICAL` | `INFO` | +| DIAL_SHOW_STAGE_SECONDS | No | Whether to show the stage seconds in the DIAL app | `true`, `false` | `false` | +| DIAL_SHOW_DEBUG_STAGES | No | Whether to show the debug stages in the DIAL app | `true`, `false` | `false` | +| DIAL_SHOW_DEBUG_ATTACHMENTS | No | Whether to show the debug attachments in the chat completion responses | `true`, `false` | `false` | +| ENABLE_DEV_COMMANDS | No | Whether to enable developer commands in chat. Some commands, such as `!show_debug_stages`, are allowed even if this environment variable is set to False. Must be disabled in the production. | `true`, `false` | `false` | +| ENABLE_DIRECT_TOOL_CALLS | No | Whether to allow the user to call tools directly, bypassing the `out of scope` check and `supreme agent` orchestration. | `true`, `false` | `false` | +| OFFICIAL_DATASET_LABEL | No | A label for official datasets to mark them for the user in the chat | | `⭐` | +| SKIP_OUT_OF_SCOPE_CHECK | No | Whether to skip the out of scope check for the chat | `true`, `false` | `false` | +| CMD_OUT_OF_SCOPE_ONLY | No | Whether to stop processing user query right after out-of-scope check | `true`, `false` | `false` | +| CMD_RAG_PREFILTER_ONLY | No | Whether to use pre-filters only for the RAG | `true`, `false` | `false` | +| EVAL_DIAL_ROLE | No | Allows user requests without a JWT token. A user request with this role has system access to the data. Added for evaluation in development environments and must be disabled (not set) in production. | | | +| DIAL_RAG_DEPLOYMENT_ID | No | Deployment ID for the RAG | | `dial-rag-pgvector` | +| DIAL_RAG_PGVECTOR_URL | No | URL for the RAG with pgvector, only for local development | | | +| DIAL_RAG_PGVECTOR_API_KEY | No | API key for the RAG with pgvector, only for local development | | | +| STAGE_INDICATORS | No | Stage name for indicators selection | | `Selecting Indicators` | +| STAGE_DATASET_QUERIES | No | Stage name for selection of non-Indicator Dimension Values " | | `Selecting non-Indicator Dimension Values` | +| STAGE_QUERIES_EXECUTION | No | Stage name for queries execution | | `Executing Data Queries` | +| TTYD_TOOL_PLAIN_CONTENT_* | No | Environment variables for the Plain Content tool to replace in the files content. Replace `*` with the variable name. 
| | | diff --git a/src/statgpt/application/app_factory.py b/src/statgpt/application/app_factory.py index 8655c08..cb5eb39 100644 --- a/src/statgpt/application/app_factory.py +++ b/src/statgpt/application/app_factory.py @@ -4,8 +4,10 @@ from aidial_sdk.chat_completion import ChatCompletion, ConfigurationRequest, Request, Response from aidial_sdk.telemetry.types import MetricsConfig, TelemetryConfig, TracingConfig from fastapi import Request as FastAPIRequest +from fastapi.params import Depends from common.settings.application import application_settings +from common.utils.cancel_dependency import cancel_on_disconnect from statgpt.settings.dial_app import dial_app_settings from .application import StatGPTApp @@ -48,10 +50,14 @@ def create_app(self) -> DIALApp: ), ) + dependencies = [Depends(cancel_on_disconnect)] + app.add_chat_completion_with_dependencies( "{deployment_id}", AppChatCompletion(), heartbeat_interval=5, + chat_completion_dependencies=dependencies, + configuration_dependencies=dependencies, ) app.include_router(service_router) diff --git a/src/statgpt/application/application.py b/src/statgpt/application/application.py index 08ee2e0..b8a14be 100644 --- a/src/statgpt/application/application.py +++ b/src/statgpt/application/application.py @@ -19,7 +19,7 @@ async def lifespan(app: "StatGPTApp"): async with optional_msi_token_manager_context(): # Check resources' availability: - await DatabaseHealthChecker.check() + await DatabaseHealthChecker().check() # Start data preloading in the background asyncio.create_task(preload_data(allow_cached_datasets=True)) diff --git a/src/statgpt/chains/candidates_selection_simple.py b/src/statgpt/chains/candidates_selection_simple.py index 77f5472..50ba357 100644 --- a/src/statgpt/chains/candidates_selection_simple.py +++ b/src/statgpt/chains/candidates_selection_simple.py @@ -1,5 +1,3 @@ -from operator import itemgetter - from langchain_core.output_parsers import PydanticOutputParser from langchain_core.prompts import ChatPromptTemplate from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough @@ -35,6 +33,9 @@ def get_output_type(): def _get_candidates(self, inputs: dict) -> list[LLMSelectionCandidateBase]: return inputs[self._candidates_key] + def _get_llm_response(self, inputs: dict) -> SelectedCandidates: + return inputs["parsed_response"] + def _route_based_on_candidates_presence(self, inputs: dict) -> Runnable | SelectedCandidates: candidates = self._get_candidates(inputs) if not candidates: @@ -68,7 +69,8 @@ def _route_based_on_candidates_presence(self, inputs: dict) -> Runnable | Select | parser ) | self._remove_hallucinations - | itemgetter("parsed_response") + | self._display_llm_response_in_stage + | self._get_llm_response ) logger.info( f"{self.__class__.__name__} using LLM model: {self._llm_model_config.deployment.deployment_id}" @@ -83,7 +85,7 @@ def _format_candidates(self, inputs: dict) -> str: text = candidates[0].candidates_to_llm_string(candidates) # type: ignore[arg-type] return text - def _display_formatted_candidates_in_stage(self, inputs: dict): + def _display_formatted_candidates_in_stage(self, inputs: dict) -> dict: choice = ChainParameters.get_choice(inputs) state = ChainParameters.get_state(inputs) show_debug_stages = state.get(StateVarsConfig.SHOW_DEBUG_STAGES) @@ -100,21 +102,37 @@ def _display_formatted_candidates_in_stage(self, inputs: dict): return inputs + def _display_llm_response_in_stage(self, inputs: dict) -> dict: + choice = ChainParameters.get_choice(inputs) + state = 
ChainParameters.get_state(inputs) + show_debug_stages = state.get(StateVarsConfig.SHOW_DEBUG_STAGES) + + if not show_debug_stages: + return inputs + + response = self._get_llm_response(inputs) + response_formatted = response.model_dump_json(indent=2) + with choice.create_stage(name='[DEBUG] Non-Indicator LLM Selection Response') as stage: + content = f'```json\n{response_formatted}\n```' + stage.append_content(content) + + return inputs + def _remove_hallucinations(self, inputs: dict): candidates = self._get_candidates(inputs) - parsed_response: SelectedCandidates = inputs["parsed_response"] + response = self._get_llm_response(inputs) candidates_ids = {x._id for x in candidates} - parsed_ids = set(parsed_response.ids) + selected_ids = set(response.ids) - hallucinations = parsed_ids.difference(candidates_ids) + hallucinations = selected_ids.difference(candidates_ids) if hallucinations: logger.warning( f"!HALLUCINATION in Selection chain! " f"{len(hallucinations)} unexpected ids found: {hallucinations}" ) - parsed_response.ids = list(parsed_ids.intersection(candidates_ids)) - inputs["parsed_response"] = parsed_response # let's be explicit + response.ids = list(selected_ids.intersection(candidates_ids)) # inplace update + inputs["parsed_response"] = response # not required, but explicit return inputs def create_chain(self) -> Runnable: diff --git a/src/statgpt/chains/data_query/data_query_artifacts_displayer.py b/src/statgpt/chains/data_query/data_query_artifacts_displayer.py index c48f526..f1340ba 100644 --- a/src/statgpt/chains/data_query/data_query_artifacts_displayer.py +++ b/src/statgpt/chains/data_query/data_query_artifacts_displayer.py @@ -13,6 +13,7 @@ from common.schemas.enums import DataParsingStatus from common.utils import AttachmentsStorage, MediaTypes, attachments_storage_factory from common.utils.async_utils import catch_and_log_async +from statgpt.schemas.dial_app_configuration import StatGPTConfiguration from statgpt.schemas.tool_artifact import DataQueryArtifact from statgpt.utils import get_json_markdown, get_python_code_markdown @@ -31,21 +32,22 @@ def __init__( self, choice: Choice, config: DataQueryAttachments, + chat_config: StatGPTConfiguration, max_cells: int, auth_context: AuthContext, ): self._choice = choice self._config = config + self._chat_config = chat_config self._auth_context = auth_context self._max_cells = max_cells async def display(self, data_query_artifacts: dict[str, DataQueryArtifact]) -> None: data_query_artifacts_list = list(data_query_artifacts.values()) responses = self._merge_data_responses(data_query_artifacts_list) - tasks = [ - self._display_data_responses(responses), - self._display_eval_attachments(data_query_artifacts), - ] + tasks = [self._display_data_responses(responses)] + if self._chat_config.enable_debug_attachments: + tasks.append(self._display_eval_attachments(data_query_artifacts)) await asyncio.gather(*tasks) async def get_system_message_content( diff --git a/src/statgpt/chains/data_query/query_builder/query/summarize_query.py b/src/statgpt/chains/data_query/query_builder/query/summarize_query.py index b654498..6076e36 100644 --- a/src/statgpt/chains/data_query/query_builder/query/summarize_query.py +++ b/src/statgpt/chains/data_query/query_builder/query/summarize_query.py @@ -77,13 +77,14 @@ async def _enrich_queries(inputs: dict) -> dict: def create_chain(self, inputs: dict) -> Runnable: auth_context = ChainParameters.get_auth_context(inputs) + locale = ChainParameters.get_data_service(inputs).channel_config.locale 
prompt_template = ChatPromptTemplate.from_messages( [ SystemMessagePromptTemplate.from_template(self._system_prompt), HumanMessagePromptTemplate.from_template("{formatted_queries}"), ], - ) + ).partial(language=locale.get_language_name()) chat_model = get_chat_model( api_key=auth_context.api_key, model_config=self._llm_model_config ).with_structured_output(schema=QuerySummaries, method="json_schema") @@ -94,7 +95,7 @@ def create_chain(self, inputs: dict) -> Runnable: RunnablePassthrough.assign(formatted_queries=self._format_queries) | prompt_template | chat_model - ) + ), ) | self._enrich_queries ) diff --git a/src/statgpt/chains/supreme_agent.py b/src/statgpt/chains/supreme_agent.py index 00dde6c..c163a62 100644 --- a/src/statgpt/chains/supreme_agent.py +++ b/src/statgpt/chains/supreme_agent.py @@ -243,10 +243,11 @@ async def stream_response(self, inputs: dict) -> str: data_query_artifacts: dict[str, DataQueryArtifact] = {} assert self._channel_config.data_query is not None, "data_query must be configured" data_displayer = DataQueryArtifactDisplayer( - choice, - self._channel_config.data_query.details.attachments, - self._channel_config.data_query.details.tool_response_max_cells, - auth_context, + choice=choice, + config=self._channel_config.data_query.details.attachments, + chat_config=ChainParameters.get_configuration(inputs), + max_cells=self._channel_config.data_query.details.tool_response_max_cells, + auth_context=auth_context, ) for i in range(self._channel_config.supreme_agent.max_agent_iterations): diff --git a/src/statgpt/default_prompts/assets/data_query.yaml b/src/statgpt/default_prompts/assets/data_query.yaml index e733ee6..35a74d7 100644 --- a/src/statgpt/default_prompts/assets/data_query.yaml +++ b/src/statgpt/default_prompts/assets/data_query.yaml @@ -318,7 +318,7 @@ validationSystemPrompt: >- Please provide only REQUESTED dimension values as an output. Do not include any additional data, comments or text. - Make sure you analyzed the WHOLE provdided list of dimension values and did not skip anything requested by user. + Make sure you analyzed the WHOLE provided list of dimension values and did not skip anything requested by user. Format your answer as JSON: {format_instructions} validationUserPrompt: >- @@ -350,8 +350,9 @@ datasetSelectionPrompts: (1) select dataset, (2) remove dataset reference from user query and (3) keep indicator in user query - To distinguish between indicators and datasets, pay attention to words and grammar. Datasets usually go after words like "from", "according to". - - If user asked for datasets by provider only and mentions indicator matching a dataset name, select - ALL datasets with that provider, not just the datasets that contain that indicator. + - If user asked to filter datasets only by provider + and mentions indicator matching a dataset name, + select ALL datasets with that provider, not just the datasets that contain that indicator. - Partial matches of provider names are allowed. For example, if user asks for datasets by IMF, datasets with provider "IMF.STA" or "IMF.RES" or similar should also be selected, not just datasets with provider "IMF". @@ -432,4 +433,4 @@ summarizeQueriesPrompt: >- You are an expert in Official Statistics. Your task is to interpret and summarize the data query to natural language. The summary should be concise and easy - to understand for a general audience. The summary must be in the language of the query. + to understand for a general audience. The summary **must be in {language}**. 
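The two hunks above work together: `summarize_query.py` resolves the channel locale to a language name and binds it into the prompt with `ChatPromptTemplate.partial(...)`, while `summarizeQueriesPrompt` now carries the matching `{language}` placeholder. A minimal sketch of that wiring follows; it assumes `langchain_core` is installed, inlines a copy of `LocaleEnum`/`_LANGUAGE_NAMES` from `src/common/schemas/enums.py`, and abbreviates the prompt text, so it is an illustration rather than code from this patch.

```python
# Sketch: bind the channel locale into the summarize-queries prompt.
# LocaleEnum here mirrors src/common/schemas/enums.py; the prompt text is abbreviated.
from enum import StrEnum

from langchain_core.prompts import ChatPromptTemplate

_LANGUAGE_NAMES = {"en": "English", "uk": "Ukrainian"}


class LocaleEnum(StrEnum):
    EN = "en"
    UK = "uk"

    def get_language_name(self) -> str:
        # StrEnum members compare and hash as plain strings, so this lookup works.
        return _LANGUAGE_NAMES[self]


SYSTEM_PROMPT = (
    "Interpret and summarize the data query for a general audience. "
    "The summary **must be in {language}**."
)

locale = LocaleEnum.UK  # in the app this comes from channel_config.locale
prompt = ChatPromptTemplate.from_messages(
    [("system", SYSTEM_PROMPT), ("human", "{formatted_queries}")]
).partial(language=locale.get_language_name())

# Only {formatted_queries} is left to fill at invocation time; here we just
# render the messages instead of piping them into a chat model.
print(prompt.format_messages(formatted_queries="GDP per capita, annual, Ukraine"))
```

Binding the language at chain-construction time keeps the runtime invocation unchanged: callers still supply only `formatted_queries`, and the summary language follows the channel locale automatically.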
diff --git a/src/statgpt/schemas/dial_app_configuration.py b/src/statgpt/schemas/dial_app_configuration.py index 0a398a9..2fe3448 100644 --- a/src/statgpt/schemas/dial_app_configuration.py +++ b/src/statgpt/schemas/dial_app_configuration.py @@ -4,6 +4,8 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator +from statgpt.settings.dial_app import dial_app_settings + class StatGPTConfiguration(BaseModel): """ @@ -17,6 +19,10 @@ class StatGPTConfiguration(BaseModel): "Used to interpret and display dates and times.", default="UTC", ) + enable_debug_attachments: bool = Field( + description="Enable debug attachments in the chat responses.", + default=dial_app_settings.dial_show_debug_attachments, + ) @field_validator('timezone') @classmethod diff --git a/src/statgpt/services/chat_facade.py b/src/statgpt/services/chat_facade.py index 1ef19b8..86a734d 100644 --- a/src/statgpt/services/chat_facade.py +++ b/src/statgpt/services/chat_facade.py @@ -43,6 +43,7 @@ from common.utils.timer import debug_timer from common.vectorstore import ScoredVectorStoreDocument, VectorStore, VectorStoreFactory from statgpt import utils +from statgpt.settings.dial_app import dial_app_settings _log = logging.getLogger(__name__) @@ -166,6 +167,20 @@ class VersionedDataSet: data: DataSet +class BaseChannelConfiguration(BaseModel, metaclass=FormMetaclass): + model_config = DialConfigDict(chat_message_input_disabled=False) + + timezone: str = DialField( + description="Timezone in IANA format, e.g. 'Europe/Berlin', 'America/New_York'. " + "Used to interpret and display dates and times.", + default="UTC", + ) + enable_debug_attachments: bool = DialField( + description="Enable debug attachments in the chat responses.", + default=dial_app_settings.dial_show_debug_attachments, + ) + + class ChannelServiceFacade(DbServiceBase): def __init__(self, session: AsyncSession, channel: models.Channel) -> None: super().__init__(session, asyncio.Lock()) @@ -245,10 +260,7 @@ def dial_channel_configuration(self) -> dict[str, Any]: f"No conversation starters configuration found for channel {self._channel.title}" ) - class InitConfiguration(BaseModel, metaclass=FormMetaclass): - model_config = DialConfigDict(chat_message_input_disabled=False) - - return InitConfiguration.model_json_schema() + return BaseChannelConfiguration.model_json_schema() intro_text: str = conversation_starters_config.intro_text _log.info( f"Conversation starters configuration found for channel {self._channel.title}, {conversation_starters_config=}" @@ -263,19 +275,12 @@ class InitConfiguration(BaseModel, metaclass=FormMetaclass): for i, button in enumerate(conversation_starters_config.buttons) ] - class StatGPTConfiguration(BaseModel, metaclass=FormMetaclass): - model_config = DialConfigDict(chat_message_input_disabled=False) - + class StatGPTConfiguration(BaseChannelConfiguration): starter: int | None = DialField( default=None, description=intro_text, buttons=buttons, ) - timezone: str = DialField( - description="Timezone in IANA format, e.g. 'Europe/Berlin', 'America/New_York'. 
" - "Used to interpret and display dates and times.", - default="UTC", - ) return StatGPTConfiguration.model_json_schema() diff --git a/src/statgpt/services/hybrid_searcher.py b/src/statgpt/services/hybrid_searcher.py index 8e92837..acdbdc7 100644 --- a/src/statgpt/services/hybrid_searcher.py +++ b/src/statgpt/services/hybrid_searcher.py @@ -753,8 +753,10 @@ async def _normalize_input( if named_entities and named_entities_to_remove: for entity in named_entities: if entity.entity_type in named_entities_to_remove: - entities_str += f" - {entity.entity} ({entity.entity_type})\n" - if entities_str != "": + entities_str += f" - {entity.entity} ({entity.entity_type}) (REMOVE)\n" + else: + entities_str += f" - {entity.entity} ({entity.entity_type}) (DO NOT REMOVE)\n" + if entities_str: entities_str = "Named Entities:\n" + entities_str period_str = "" @@ -768,24 +770,32 @@ async def _normalize_input( period_str = "Time Period:\n" + period_str removal_step = "" - if entities_str != "" or period_str != "": - if entities_str and period_str: - removal_step = "- from the input remove parts related to the Named Entities and Time Period. Only listed entities and period" - elif entities_str: - removal_step = "- from the input remove parts related to the Named Entities. Only listed entities" - elif period_str: - removal_step = "- from the input remove parts related to Time Period. Only period" - - forbidden_str = "" - if forbidden and len(forbidden) > 0: - forbidden_str = ", ".join(forbidden) - forbidden_str = f"Forbidden to remove words:\n{forbidden_str}\n" + if entities_str and period_str: + removal_step = ( + "- from the input: " + "keep entities marked (DO NOT REMOVE), " + "remove entities marked (REMOVE) " + "and remove all parts related to Time Period. " + "If an entity or part of entity appears in multiple categories " + "and at least one instance is marked (DO NOT REMOVE), keep entity" + ) + elif entities_str: + removal_step = ( + "- from the input: " + "keep entities marked (DO NOT REMOVE), " + "remove entities marked (REMOVE). " + "If an entity or part of entity appears in multiple categories " + "and at least one instance is marked (DO NOT REMOVE), keep entity" + ) + elif period_str: + removal_step = "- from the input remove all parts related to Time Period. 
Only period" + forbidden_to_remove_str = "" forbidden_step = "" - if forbidden_str != "": - forbidden_step = ( - "- do not remove forbidden to remove words from the input if they present in input" - ) + if forbidden: + forbidden_to_remove_str = ", ".join(forbidden) + forbidden_to_remove_str = f"Forbidden to remove words:\n{forbidden_to_remove_str}\n" + forbidden_step = "- do not remove forbidden to remove words from the input if they are present in input" output = await self._normalize_chain.ainvoke( { @@ -794,22 +804,22 @@ async def _normalize_input( "input": query, "entities": entities_str, "period": period_str, - "forbidden": forbidden_str, + "forbidden": forbidden_to_remove_str, } ) return output['cleaned_input'] async def _separate_subjects(self, query: str, forbidden: set[str]) -> list[str]: forbidden_str = "" - if forbidden and len(forbidden) > 0: + if forbidden: for item in forbidden: if len(item.split()) > 0: forbidden_str += f" - {item}" - if forbidden_str != "": + if forbidden_str: forbidden_str = f"Forbidden to split phrases:\n{forbidden_str}\n" forbidden_step = "" - if forbidden_str != "": + if forbidden_str: forbidden_step = "- do not split the input into separate queries in the middle of the forbidden to split phrases if they present in input" output = await self._separate_subjects_chain.ainvoke( @@ -871,17 +881,21 @@ async def search( ) logger.info(f"[search], {good_candidates=}") logger.info(f"[search], {candidates=}") + normalized = await self._normalize_input(query, named_entities, period, forbidden) + normalized = normalized.lower() + elapsed = time.perf_counter() - pc0 if stage: + stage.append_content("> [raw input query]:\n") + stage.append_content(f"```\n{query}\n```\n") + + stage.append_content("> [normalized input query for search]:\n") + stage.append_content(f"```\n{normalized}\n```\n") + stage.append_content("> [full text] potential known terms:\n") - stage.append_content("```\n") forbidden_str = "[" + "] [".join(forbidden) + "]" - stage.append_content(f"{forbidden_str}\n") - stage.append_content("```\n") + stage.append_content(f"```\n{forbidden_str}\n```\n") - normalized = await self._normalize_input(query, named_entities, period, forbidden) - normalized = normalized.lower() - elapsed = time.perf_counter() - pc0 logger.info(f"[search], {normalized=}, (elapsed {elapsed:0.3f} sec)") queries = await self._separate_subjects(normalized, good_candidates) diff --git a/src/statgpt/settings/dial_app.py b/src/statgpt/settings/dial_app.py index bfd66ec..5a09e60 100644 --- a/src/statgpt/settings/dial_app.py +++ b/src/statgpt/settings/dial_app.py @@ -56,6 +56,12 @@ class DialAppSettings(BaseSettings): default=False, alias="DIAL_SHOW_DEBUG_STAGES", description="Show debug stages information" ) + dial_show_debug_attachments: bool = Field( + default=False, + alias="DIAL_SHOW_DEBUG_ATTACHMENTS", + description="Show debug attachments in chat completion responses", + ) + enable_dev_commands: bool = Field( default=False, alias="ENABLE_DEV_COMMANDS", description="Enable development commands" ) diff --git a/tests/unit/test_alembic_version.py b/tests/unit/test_alembic_version.py new file mode 100644 index 0000000..4339d98 --- /dev/null +++ b/tests/unit/test_alembic_version.py @@ -0,0 +1,99 @@ +"""Unit test to verify that the ALEMBIC_TARGET_VERSION matches the latest migration.""" + +import ast +from pathlib import Path + +from src.common.config.versions import Versions + + +def extract_revision_from_migration(file_path: Path) -> str: + """Extract the revision ID from an alembic 
migration file. + + Args: + file_path: Path to the migration file + + Returns: + The revision ID string + + Raises: + ValueError: If revision cannot be found in the file + """ + with open(file_path, 'r') as f: + content = f.read() + + tree = ast.parse(content) + + for node in ast.walk(tree): + # Handle both annotated (revision: str = 'x') and regular (revision = 'x') assignments + if isinstance(node, ast.AnnAssign): + if isinstance(node.target, ast.Name) and node.target.id == 'revision': + if isinstance(node.value, ast.Constant): + return node.value.value + elif isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name) and target.id == 'revision': + if isinstance(node.value, ast.Constant): + return node.value.value + + raise ValueError(f"Could not find revision in {file_path}") + + +def get_latest_migration_revision() -> str: + """Get the revision ID of the latest alembic migration. + + Returns: + The revision ID of the most recent migration file + + Raises: + FileNotFoundError: If no migration files are found + """ + base_dir = Path(__file__).parent.parent.parent + versions_dir = base_dir / "src" / "admin_portal" / "alembic" / "versions" + + if not versions_dir.exists(): + raise FileNotFoundError(f"Alembic versions directory not found: {versions_dir}") + + migration_files = sorted(versions_dir.glob("*.py")) + + if not migration_files: + raise FileNotFoundError(f"No migration files found in {versions_dir}") + + latest_migration = migration_files[-1] # Files sorted by timestamp in filename + return extract_revision_from_migration(latest_migration) + + +class TestAlembicVersion: + """Test that the ALEMBIC_TARGET_VERSION matches the latest migration.""" + + def test_version_matches_latest_migration(self): + """Verify that ALEMBIC_TARGET_VERSION equals the latest migration revision.""" + configured_version = Versions.ALEMBIC_TARGET_VERSION + latest_revision = get_latest_migration_revision() + + assert configured_version == latest_revision, ( + f"ALEMBIC_TARGET_VERSION ({configured_version}) does not match " + f"the latest migration revision ({latest_revision}). " + f"Please update ALEMBIC_TARGET_VERSION in src/common/config/versions.py" + ) + + def test_configured_version_not_unknown(self): + """Verify that ALEMBIC_TARGET_VERSION is not set to a placeholder value.""" + configured_version = Versions.ALEMBIC_TARGET_VERSION + + assert configured_version != 'unknown', ( + "ALEMBIC_TARGET_VERSION should not be 'unknown'. " + "Please set it to the latest migration revision." + ) + + assert configured_version != '', ( + "ALEMBIC_TARGET_VERSION should not be empty. " + "Please set it to the latest migration revision." 
+ ) + + def test_latest_migration_extraction(self): + """Test that we can successfully extract a revision from migration files.""" + latest_revision = get_latest_migration_revision() + + assert isinstance(latest_revision, str) + assert len(latest_revision) == 12 + assert all(c in '0123456789abcdef' for c in latest_revision) From d0fdcdf9559209b1dcc98655366299243deca8db Mon Sep 17 00:00:00 2001 From: Daniil Yarmalkevich Date: Wed, 3 Dec 2025 12:44:59 +0200 Subject: [PATCH 4/6] updated admin Dockerfile --- docker/admin.Dockerfile | 7 +++---- scripts/admin.sh | 18 ------------------ 2 files changed, 3 insertions(+), 22 deletions(-) delete mode 100644 scripts/admin.sh diff --git a/docker/admin.Dockerfile b/docker/admin.Dockerfile index fd7d74b..dc2e42a 100644 --- a/docker/admin.Dockerfile +++ b/docker/admin.Dockerfile @@ -20,9 +20,8 @@ COPY pyproject.toml . COPY poetry.lock . RUN poetry export -f requirements.txt --without-hashes | pip install $PIP_ARGS -r /dev/stdin -# Copy scripts and source code +# Copy source code COPY ./src/alembic.ini $APP_HOME/alembic.ini -COPY ./scripts/admin.sh $APP_HOME/admin.sh COPY ./src/admin_portal $APP_HOME/admin_portal COPY ./src/common $APP_HOME/common @@ -30,7 +29,7 @@ COPY ./src/common $APP_HOME/common RUN adduser -u 5678 --system --disabled-password --gecos "" app && chown -R app $APP_HOME USER app -ENV APP_MODE="DIAL" ENV WEB_CONCURRENCY=1 +ENV PYDANTIC_V2=True -CMD ["sh", "admin.sh"] +CMD ["sh", "admin_portal/admin.sh"] diff --git a/scripts/admin.sh b/scripts/admin.sh deleted file mode 100644 index 50ae5a6..0000000 --- a/scripts/admin.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash - -echo "ADMIN_MODE = '$ADMIN_MODE'" - -case $ADMIN_MODE in - - APP) - uvicorn "admin_portal.app:app" --host "0.0.0.0" --port 8000 --lifespan on - ;; - - ALEMBIC_UPGRADE) - alembic upgrade head - ;; - - *) - echo "Unknown ADMIN_MODE = '$ADMIN_MODE'. Possible values: 'APP' or 'ALEMBIC_UPGRADE'" - ;; -esac From 1781a66a434907beaf9818148a817de34f6c97da Mon Sep 17 00:00:00 2001 From: Daniil Yarmalkevich Date: Wed, 3 Dec 2025 13:24:32 +0200 Subject: [PATCH 5/6] updated chat Dockerfile --- docker/chat.Dockerfile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docker/chat.Dockerfile b/docker/chat.Dockerfile index cd2efce..ec25308 100644 --- a/docker/chat.Dockerfile +++ b/docker/chat.Dockerfile @@ -20,7 +20,7 @@ COPY pyproject.toml . COPY poetry.lock . 
RUN poetry export -f requirements.txt --without-hashes | pip install $PIP_ARGS -r /dev/stdin -# Copy scripts and source code +# Copy source code COPY ./src/statgpt $APP_HOME/statgpt COPY ./src/common $APP_HOME/common @@ -30,6 +30,7 @@ USER app ENV APP_MODE="DIAL" ENV WEB_CONCURRENCY=1 +ENV PYDANTIC_V2=True EXPOSE 5000 From 82d1e2e7f7f374fe2aa8ab5bf9213b61129ad2b6 Mon Sep 17 00:00:00 2001 From: Daniil Yarmalkevich Date: Fri, 5 Dec 2025 11:14:34 +0200 Subject: [PATCH 6/6] added new code --- src/common/data/quanthub/v21/dataset.py | 5 ++-- .../data/quanthub/v21/qh_sdmx_client.py | 12 ++++++-- src/common/data/sdmx/v21/dataset.py | 4 ++- src/common/data/sdmx/v21/sdmx_client.py | 30 +++++++++++-------- src/common/prompts/assets/indexer.yaml | 15 ++++++++-- src/common/schemas/data_query_tool.py | 5 ++++ src/statgpt/README.md | 2 +- src/statgpt/application/app_factory.py | 5 ++-- src/statgpt/application/application.py | 9 +++++- src/statgpt/application/middleware.py | 26 ++++++++++++++++ src/statgpt/services/hybrid_searcher.py | 15 +++++++--- src/statgpt/settings/dial_app.py | 4 --- 12 files changed, 100 insertions(+), 32 deletions(-) create mode 100644 src/statgpt/application/middleware.py diff --git a/src/common/data/quanthub/v21/dataset.py b/src/common/data/quanthub/v21/dataset.py index 88be666..11bd589 100644 --- a/src/common/data/quanthub/v21/dataset.py +++ b/src/common/data/quanthub/v21/dataset.py @@ -229,9 +229,10 @@ def _availability_result_to_query( ) -> DataSetAvailabilityQuery: result = super()._availability_result_to_query(availability_result) - constraint: ContentConstraint = list(availability_result.constraint.values())[0] + constraints_iterator = iter(availability_result.constraint.values()) + constraint: ContentConstraint | None = next(constraints_iterator, None) - if "TIME_PERIOD" not in result: + if constraint is not None and "TIME_PERIOD" not in result: start, end = self._parse_time_period_from(constraint.annotations) result.time_period_start, result.time_period_end = start, end diff --git a/src/common/data/quanthub/v21/qh_sdmx_client.py b/src/common/data/quanthub/v21/qh_sdmx_client.py index 2f9ae85..e0ca201 100644 --- a/src/common/data/quanthub/v21/qh_sdmx_client.py +++ b/src/common/data/quanthub/v21/qh_sdmx_client.py @@ -217,8 +217,16 @@ async def _qh_available_constraint( headers=headers, json=req_body_obj.model_dump(mode='json', exclude_none=True, by_alias=True), ).prepare() - async with self._rate_limiter.availability_limiter(): - response = await self._perform_request(req) + + try: + async with self._rate_limiter.availability_limiter(): + response = await self._perform_request(req) + except httpx.HTTPStatusError as e: + if e.response.status_code in [400, 404]: + logger.error(f"Bad request for URL {url!r}: {e.response.text}") + logger.info(f"Request body: {req.body!r}") + return StructureMessage() # Return empty StructureMessage on bad request + raise resp_body_obj = QhAvailabilityResponseBody.model_validate(response.json()) structure_msg = resp_body_obj.to_sdmx1() diff --git a/src/common/data/sdmx/v21/dataset.py b/src/common/data/sdmx/v21/dataset.py index 71241a8..904563c 100644 --- a/src/common/data/sdmx/v21/dataset.py +++ b/src/common/data/sdmx/v21/dataset.py @@ -1007,7 +1007,9 @@ def _availability_result_to_query( ) -> DataSetAvailabilityQuery: constraints = list(availability_result.constraint.values()) - if len(constraints) != 1: + if len(constraints) == 0: + return DataSetAvailabilityQuery() # empty query + elif len(constraints) != 1: raise ValueError("Unexpected 
quantity of constraints in structure message") constraint = constraints[0] if len(constraint.data_content_region) != 1: diff --git a/src/common/data/sdmx/v21/sdmx_client.py b/src/common/data/sdmx/v21/sdmx_client.py index a348165..bb7348f 100644 --- a/src/common/data/sdmx/v21/sdmx_client.py +++ b/src/common/data/sdmx/v21/sdmx_client.py @@ -209,18 +209,24 @@ async def _get_availability( use_cache: bool = False, tofile: os.PathLike | IO | None = None, ) -> Message: - async with self._rate_limiter.availability_limiter(): - return await self._get( - resource_type=resource_type, - resource_id=resource_id, - agency_id=agency_id, - version=version, - key=key, - params=params, - dsd=dsd, - use_cache=use_cache, - tofile=tofile, - ) + try: + async with self._rate_limiter.availability_limiter(): + return await self._get( + resource_type=resource_type, + resource_id=resource_id, + agency_id=agency_id, + version=version, + key=key, + params=params, + dsd=dsd, + use_cache=use_cache, + tofile=tofile, + ) + except httpx.HTTPStatusError as e: + if e.response.status_code in [400, 404]: + logger.error(f"Bad request for URL {e.request.url!r}: {e.response.text}") + return StructureMessage() # Return empty StructureMessage on bad request + raise async def _get_data( self, diff --git a/src/common/prompts/assets/indexer.yaml b/src/common/prompts/assets/indexer.yaml index 5fb56f7..730a43c 100644 --- a/src/common/prompts/assets/indexer.yaml +++ b/src/common/prompts/assets/indexer.yaml @@ -117,18 +117,29 @@ separate_subjects: relevance: systemPrompt: |- You are an expert in statistical indicators. + By picking most relevant series of indicators you help answer the user's question with best suited statistical data. As an output provide relevance score for each input case as JSON object {{"relevance": [{{"number": "score"}}]}}. Only JSON, no markdown userPrompt: |- Instruction steps: - - input consists of separate group of items. Analyze all groups + - input consists of separate groups of items. Analyze all groups - each numbered item describes statistical indicator. Analyze all items - use the number from round brackets for the output reference and all other info as a source for relevance context. All items that have number in a round brackets should be present in output together with score - use all parent levels of an item in the relevance context - for each input indicator provide relevancy score as 0, 1, 2, 3 to the statement, - where: 0 - irrelevant, 1 - somewhat relevant, 2 - highly relevant, 3 - extremely relevant + where: + 3 - ideal indicators that perfectly match the statement; + 2 - best relevant items, but not ideal; + 1 - other relevant; + 0 - irrelevant; + - if there's no ideal matches, then best relevant items should get score 2 + - if there are ideal matches, then score 2 is assigned to the next best relevant items + their purpose is to provide extended answer for the statement + - items that are more detailed compared to the statement usually can't be ideal matches, because they answer more specific question than the statement asks for. + BUT there can be exceptions, e.g. where details can be guessed from the statement, and indicator just provides more precise definition + - if the statement is general and there are general items - they should get score 3 and all more detailed items should get score 2 or lower. 
- if the input indicator is at least somewhat relevant to the statement then the relevancy score can't be score 0 - if the input statement has extra clarification description then it is mandatory to have relevant part in the candidate item to get score 3. This extra clarification description could be essential to distinguish relevant and extremely relevant items. diff --git a/src/common/schemas/data_query_tool.py b/src/common/schemas/data_query_tool.py index c1cec63..6c89a6b 100644 --- a/src/common/schemas/data_query_tool.py +++ b/src/common/schemas/data_query_tool.py @@ -151,6 +151,10 @@ class SpecialDimensionsProcessor(BaseYamlModel): prompt: SystemUserPrompt +class HybridSearchPrompts(BaseYamlModel): + relevancy_prompts: SystemUserPrompt | None = Field(default=None) + + class HybridSearchConfig(BaseYamlModel): """Configuration for the Hybrid Search and Indexer.""" @@ -228,6 +232,7 @@ class HybridSearchConfig(BaseYamlModel): ge=0, le=3, ) + prompts: HybridSearchPrompts = Field(default_factory=HybridSearchPrompts) class DataQueryDetails(BaseToolDetails): diff --git a/src/statgpt/README.md b/src/statgpt/README.md index 7f52ca4..887049e 100644 --- a/src/statgpt/README.md +++ b/src/statgpt/README.md @@ -9,7 +9,7 @@ the [common README file](../common/README.md). |-----------------------------|:--------:|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------|--------------------------------------------| | DIAL_APP_NAME | No | Name of the DIAL app | | `talk-to-your-data` | | DIAL_AUTH_MODE | No | Define the authentication mode for the LLM models used by the application. `USER_TOKEN` means that the application requests models with a user token. `API_KEY` means that all requests to LLM are sent with a single application key. 
| `USER_TOKEN`, `API_KEY` | `USER_TOKEN` | -| DIAL_LOG_LEVEL | No | Log level for the DIAL app | `DEBUG`, `INFO`, `WARN`, `ERROR`, `CRITICAL` | `INFO` | +| DIAL_SDK_LOG | No | Log level for the DIAL SDK | `DEBUG`, `INFO`, `WARN`, `ERROR`, `CRITICAL` | `WARNING` | | DIAL_SHOW_STAGE_SECONDS | No | Whether to show the stage seconds in the DIAL app | `true`, `false` | `false` | | DIAL_SHOW_DEBUG_STAGES | No | Whether to show the debug stages in the DIAL app | `true`, `false` | `false` | | DIAL_SHOW_DEBUG_ATTACHMENTS | No | Whether to show the debug attachments in the chat completion responses | `true`, `false` | `false` | diff --git a/src/statgpt/application/app_factory.py b/src/statgpt/application/app_factory.py index cb5eb39..6c58a3b 100644 --- a/src/statgpt/application/app_factory.py +++ b/src/statgpt/application/app_factory.py @@ -4,10 +4,8 @@ from aidial_sdk.chat_completion import ChatCompletion, ConfigurationRequest, Request, Response from aidial_sdk.telemetry.types import MetricsConfig, TelemetryConfig, TracingConfig from fastapi import Request as FastAPIRequest -from fastapi.params import Depends from common.settings.application import application_settings -from common.utils.cancel_dependency import cancel_on_disconnect from statgpt.settings.dial_app import dial_app_settings from .application import StatGPTApp @@ -50,7 +48,8 @@ def create_app(self) -> DIALApp: ), ) - dependencies = [Depends(cancel_on_disconnect)] + # dependencies = [Depends(cancel_on_disconnect)] + dependencies: list = [] app.add_chat_completion_with_dependencies( "{deployment_id}", diff --git a/src/statgpt/application/application.py b/src/statgpt/application/application.py index b8a14be..5f1dcbb 100644 --- a/src/statgpt/application/application.py +++ b/src/statgpt/application/application.py @@ -1,8 +1,9 @@ import asyncio +import logging from collections.abc import Sequence from contextlib import asynccontextmanager -from aidial_sdk import DIALApp +from aidial_sdk import DIALApp, logger from aidial_sdk.chat_completion import ChatCompletion from aidial_sdk.deployment.configuration import ConfigurationRequest from aidial_sdk.deployment.tokenize import TokenizeRequest @@ -13,6 +14,7 @@ from common.models import DatabaseHealthChecker, optional_msi_token_manager_context from common.services.data_preloader import preload_data from common.settings.dial import dial_settings +from statgpt.application.middleware import DebugRequestLoggingMiddleware @asynccontextmanager @@ -36,6 +38,11 @@ def __init__(self, **kwargs): lifespan=lifespan, **kwargs, ) + if logger.isEnabledFor(logging.DEBUG): + self.add_middleware( + DebugRequestLoggingMiddleware, + patterns=[r"/chat/completions$"], + ) def add_chat_completion_with_dependencies( self, diff --git a/src/statgpt/application/middleware.py b/src/statgpt/application/middleware.py new file mode 100644 index 0000000..1fd4024 --- /dev/null +++ b/src/statgpt/application/middleware.py @@ -0,0 +1,26 @@ +import re + +from aidial_sdk import logger +from starlette.middleware.base import BaseHTTPMiddleware, Response +from starlette.requests import Request +from starlette.types import ASGIApp + + +class DebugRequestLoggingMiddleware(BaseHTTPMiddleware): + """Middleware to log raw request bodies for matching endpoints.""" + + def __init__(self, app: ASGIApp, patterns: list[str]): + """ + Args: + app: The ASGI application. + patterns: List of regex patterns to match request paths. 
+ """ + super().__init__(app) + self._patterns = [re.compile(p) for p in patterns] + + async def dispatch(self, request: Request, call_next) -> Response: + if any(p.search(request.url.path) for p in self._patterns): + body = await request.body() + logger.debug(f"Request [{request.url.path}]: {body.decode('utf-8')}") + + return await call_next(request) diff --git a/src/statgpt/services/hybrid_searcher.py b/src/statgpt/services/hybrid_searcher.py index acdbdc7..6e82cd1 100644 --- a/src/statgpt/services/hybrid_searcher.py +++ b/src/statgpt/services/hybrid_searcher.py @@ -7,6 +7,7 @@ from typing import Any, NamedTuple from aidial_sdk.chat_completion import Stage +from langchain_core.prompts import ChatPromptTemplate from pydantic import BaseModel from common.config.logging import logger @@ -732,10 +733,16 @@ def __init__( IndexerPrompts.get_separate_subjects_prompts() | self._llm.with_structured_output(method="json_mode") ) - self._relevance_chain = ( - IndexerPrompts.get_relevance_prompts() - | self._llm.with_structured_output(method="json_mode") - ) + if system_user_prompt := config.prompts.relevancy_prompts: + prompt = ChatPromptTemplate.from_messages( + [ + ("system", system_user_prompt.system_message), + ("human", system_user_prompt.user_message), + ], + ) + else: + prompt = IndexerPrompts.get_relevance_prompts() + self._relevance_chain = prompt | self._llm.with_structured_output(method="json_mode") @property def config(self) -> HybridSearchConfig: diff --git a/src/statgpt/settings/dial_app.py b/src/statgpt/settings/dial_app.py index 5a09e60..3db4e2a 100644 --- a/src/statgpt/settings/dial_app.py +++ b/src/statgpt/settings/dial_app.py @@ -42,10 +42,6 @@ class DialAppSettings(BaseSettings): description="Authentication mode for DIAL API calls", ) - dial_log_level: str = Field( - default="INFO", alias="DIAL_LOG_LEVEL", description="Log level for DIAL application" - ) - dial_show_stage_seconds: bool = Field( default=False, alias="DIAL_SHOW_STAGE_SECONDS",