From 3a6d32da7c24a8cc4d6866f0c10a78987f969a07 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Tue, 19 Mar 2024 16:21:22 -0700 Subject: [PATCH 01/58] Port KV Store to Postgres (#1227) --- .../173cae5bba26_port_config_store.py | 29 ++++++++++++++ backend/danswer/configs/app_configs.py | 4 +- .../connectors/gmail/connector_auth.py | 2 +- .../connectors/google_drive/connector_auth.py | 2 +- backend/danswer/danswerbot/slack/tokens.py | 2 +- backend/danswer/db/engine.py | 4 ++ backend/danswer/db/models.py | 8 ++++ backend/danswer/dynamic_configs/__init__.py | 13 ------ backend/danswer/dynamic_configs/factory.py | 16 ++++++++ .../danswer/dynamic_configs/port_configs.py | 40 +++++++++++++++++++ .../{file_system => }/store.py | 40 +++++++++++++++++++ backend/danswer/llm/utils.py | 2 +- backend/danswer/main.py | 3 ++ .../danswer/server/danswer_api/ingestion.py | 2 +- .../danswer/server/manage/administrative.py | 2 +- backend/danswer/utils/acl.py | 2 +- backend/danswer/utils/telemetry.py | 2 +- 17 files changed, 150 insertions(+), 23 deletions(-) create mode 100644 backend/alembic/versions/173cae5bba26_port_config_store.py create mode 100644 backend/danswer/dynamic_configs/factory.py create mode 100644 backend/danswer/dynamic_configs/port_configs.py rename backend/danswer/dynamic_configs/{file_system => }/store.py (52%) diff --git a/backend/alembic/versions/173cae5bba26_port_config_store.py b/backend/alembic/versions/173cae5bba26_port_config_store.py new file mode 100644 index 0000000000..4087086bf1 --- /dev/null +++ b/backend/alembic/versions/173cae5bba26_port_config_store.py @@ -0,0 +1,29 @@ +"""Port Config Store + +Revision ID: 173cae5bba26 +Revises: e50154680a5c +Create Date: 2024-03-19 15:30:44.425436 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision = "173cae5bba26" +down_revision = "e50154680a5c" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "key_value_store", + sa.Column("key", sa.String(), nullable=False), + sa.Column("value", postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.PrimaryKeyConstraint("key"), + ) + + +def downgrade() -> None: + op.drop_table("key_value_store") diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index cff1b8e5c7..08ac2fc23d 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -224,8 +224,8 @@ ##### # Miscellaneous ##### -DYNAMIC_CONFIG_STORE = os.environ.get( - "DYNAMIC_CONFIG_STORE", "FileSystemBackedDynamicConfigStore" +DYNAMIC_CONFIG_STORE = ( + os.environ.get("DYNAMIC_CONFIG_STORE") or "PostgresBackedDynamicConfigStore" ) DYNAMIC_CONFIG_DIR_PATH = os.environ.get("DYNAMIC_CONFIG_DIR_PATH", "/home/storage") JOB_TIMEOUT = 60 * 60 * 6 # 6 hours default diff --git a/backend/danswer/connectors/gmail/connector_auth.py b/backend/danswer/connectors/gmail/connector_auth.py index f6cfa5a748..39dd9aacf8 100644 --- a/backend/danswer/connectors/gmail/connector_auth.py +++ b/backend/danswer/connectors/gmail/connector_auth.py @@ -24,7 +24,7 @@ from danswer.connectors.gmail.constants import SCOPES from danswer.db.credentials import update_credential_json from danswer.db.models import User -from danswer.dynamic_configs import get_dynamic_config_store +from danswer.dynamic_configs.factory import get_dynamic_config_store from danswer.server.documents.models import CredentialBase from danswer.server.documents.models import GoogleAppCredentials from danswer.server.documents.models import GoogleServiceAccountKey diff --git a/backend/danswer/connectors/google_drive/connector_auth.py b/backend/danswer/connectors/google_drive/connector_auth.py index f65e177724..65c34393c7 100644 --- a/backend/danswer/connectors/google_drive/connector_auth.py +++ b/backend/danswer/connectors/google_drive/connector_auth.py @@ -24,7 +24,7 @@ from danswer.connectors.google_drive.constants import SCOPES from danswer.db.credentials import update_credential_json from danswer.db.models import User -from danswer.dynamic_configs import get_dynamic_config_store +from danswer.dynamic_configs.factory import get_dynamic_config_store from danswer.server.documents.models import CredentialBase from danswer.server.documents.models import GoogleAppCredentials from danswer.server.documents.models import GoogleServiceAccountKey diff --git a/backend/danswer/danswerbot/slack/tokens.py b/backend/danswer/danswerbot/slack/tokens.py index c9c1286282..34d2b79a30 100644 --- a/backend/danswer/danswerbot/slack/tokens.py +++ b/backend/danswer/danswerbot/slack/tokens.py @@ -1,7 +1,7 @@ import os from typing import cast -from danswer.dynamic_configs import get_dynamic_config_store +from danswer.dynamic_configs.factory import get_dynamic_config_store from danswer.server.manage.models import SlackBotTokens diff --git a/backend/danswer/db/engine.py b/backend/danswer/db/engine.py index 146f11ef81..6803c51aa3 100644 --- a/backend/danswer/db/engine.py +++ b/backend/danswer/db/engine.py @@ -10,6 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from danswer.configs.app_configs import POSTGRES_DB from danswer.configs.app_configs import POSTGRES_HOST @@ -80,3 +81,6 @@ async def get_async_session() 
-> AsyncGenerator[AsyncSession, None]: get_sqlalchemy_async_engine(), expire_on_commit=False ) as async_session: yield async_session + + +SessionFactory = sessionmaker(bind=get_sqlalchemy_engine()) diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 5ca3bdbe94..2405ffed6d 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -35,6 +35,7 @@ from danswer.configs.constants import MessageType from danswer.configs.constants import SearchFeedbackType from danswer.connectors.models import InputType +from danswer.dynamic_configs.interface import JSON_ro from danswer.search.models import RecencyBiasSetting from danswer.search.models import SearchType @@ -851,3 +852,10 @@ class TaskQueueState(Base): register_time: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now() ) + + +class KVStore(Base): + __tablename__ = "key_value_store" + + key: Mapped[str] = mapped_column(String, primary_key=True) + value: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=False) diff --git a/backend/danswer/dynamic_configs/__init__.py b/backend/danswer/dynamic_configs/__init__.py index 0fc2233fa9..e69de29bb2 100644 --- a/backend/danswer/dynamic_configs/__init__.py +++ b/backend/danswer/dynamic_configs/__init__.py @@ -1,13 +0,0 @@ -from danswer.configs.app_configs import DYNAMIC_CONFIG_DIR_PATH -from danswer.configs.app_configs import DYNAMIC_CONFIG_STORE -from danswer.dynamic_configs.file_system.store import FileSystemBackedDynamicConfigStore -from danswer.dynamic_configs.interface import DynamicConfigStore - - -def get_dynamic_config_store() -> DynamicConfigStore: - dynamic_config_store_type = DYNAMIC_CONFIG_STORE - if dynamic_config_store_type == FileSystemBackedDynamicConfigStore.__name__: - return FileSystemBackedDynamicConfigStore(DYNAMIC_CONFIG_DIR_PATH) - - # TODO: change exception type - raise Exception("Unknown dynamic config store type") diff --git a/backend/danswer/dynamic_configs/factory.py b/backend/danswer/dynamic_configs/factory.py new file mode 100644 index 0000000000..a82bc315c8 --- /dev/null +++ b/backend/danswer/dynamic_configs/factory.py @@ -0,0 +1,16 @@ +from danswer.configs.app_configs import DYNAMIC_CONFIG_DIR_PATH +from danswer.configs.app_configs import DYNAMIC_CONFIG_STORE +from danswer.dynamic_configs.interface import DynamicConfigStore +from danswer.dynamic_configs.store import FileSystemBackedDynamicConfigStore +from danswer.dynamic_configs.store import PostgresBackedDynamicConfigStore + + +def get_dynamic_config_store() -> DynamicConfigStore: + dynamic_config_store_type = DYNAMIC_CONFIG_STORE + if dynamic_config_store_type == FileSystemBackedDynamicConfigStore.__name__: + return FileSystemBackedDynamicConfigStore(DYNAMIC_CONFIG_DIR_PATH) + if dynamic_config_store_type == PostgresBackedDynamicConfigStore.__name__: + return PostgresBackedDynamicConfigStore() + + # TODO: change exception type + raise Exception("Unknown dynamic config store type") diff --git a/backend/danswer/dynamic_configs/port_configs.py b/backend/danswer/dynamic_configs/port_configs.py new file mode 100644 index 0000000000..34abcff741 --- /dev/null +++ b/backend/danswer/dynamic_configs/port_configs.py @@ -0,0 +1,40 @@ +import json +from pathlib import Path + +from danswer.configs.app_configs import DYNAMIC_CONFIG_DIR_PATH +from danswer.dynamic_configs.factory import PostgresBackedDynamicConfigStore +from danswer.dynamic_configs.interface import ConfigNotFoundError + + +def read_file_system_store(directory_path: str) -> 
dict:
+    store = {}
+    base_path = Path(directory_path)
+    for file_path in base_path.iterdir():
+        if file_path.is_file() and "." not in file_path.name:
+            with open(file_path, "r") as file:
+                key = file_path.stem
+                value = json.load(file)
+
+                if value:
+                    store[key] = value
+    return store
+
+
+def insert_into_postgres(store_data: dict) -> None:
+    port_once_key = "file_store_ported"
+    config_store = PostgresBackedDynamicConfigStore()
+    try:
+        config_store.load(port_once_key)
+        return
+    except ConfigNotFoundError:
+        pass
+
+    for key, value in store_data.items():
+        config_store.store(key, value)
+
+    config_store.store(port_once_key, True)
+
+
+def port_filesystem_to_postgres(directory_path: str = DYNAMIC_CONFIG_DIR_PATH) -> None:
+    store_data = read_file_system_store(directory_path)
+    insert_into_postgres(store_data)
diff --git a/backend/danswer/dynamic_configs/file_system/store.py b/backend/danswer/dynamic_configs/store.py
similarity index 52%
rename from backend/danswer/dynamic_configs/file_system/store.py
rename to backend/danswer/dynamic_configs/store.py
index 75cc0d7407..043d762d47 100644
--- a/backend/danswer/dynamic_configs/file_system/store.py
+++ b/backend/danswer/dynamic_configs/store.py
@@ -1,10 +1,15 @@
 import json
 import os
+from collections.abc import Iterator
+from contextlib import contextmanager
 from pathlib import Path
 from typing import cast
 
 from filelock import FileLock
+from sqlalchemy.orm import Session
 
+from danswer.db.engine import SessionFactory
+from danswer.db.models import KVStore
 from danswer.dynamic_configs.interface import ConfigNotFoundError
 from danswer.dynamic_configs.interface import DynamicConfigStore
 from danswer.dynamic_configs.interface import JSON_ro
@@ -46,3 +51,38 @@ def delete(self, key: str) -> None:
         lock = _get_file_lock(file_path)
         with lock.acquire(timeout=FILE_LOCK_TIMEOUT):
             os.remove(file_path)
+
+
+class PostgresBackedDynamicConfigStore(DynamicConfigStore):
+    @contextmanager
+    def get_session(self) -> Iterator[Session]:
+        session: Session = SessionFactory()
+        try:
+            yield session
+        finally:
+            session.close()
+
+    def store(self, key: str, val: JSON_ro) -> None:
+        with self.get_session() as session:
+            obj = session.query(KVStore).filter_by(key=key).first()
+            if obj:
+                obj.value = val
+            else:
+                obj = KVStore(key=key, value=val)  # type: ignore
+                session.query(KVStore).filter_by(key=key).delete()  # just in case
+                session.add(obj)
+            session.commit()
+
+    def load(self, key: str) -> JSON_ro:
+        with self.get_session() as session:
+            obj = session.query(KVStore).filter_by(key=key).first()
+            if not obj:
+                raise ConfigNotFoundError
+            return cast(JSON_ro, obj.value)
+
+    def delete(self, key: str) -> None:
+        with self.get_session() as session:
+            result = session.query(KVStore).filter_by(key=key).delete()  # type: ignore
+            if result == 0:
+                raise ConfigNotFoundError
+            session.commit()
diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py
index 5685213b4b..b1983180ca 100644
--- a/backend/danswer/llm/utils.py
+++ b/backend/danswer/llm/utils.py
@@ -30,7 +30,7 @@
 from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
 from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.db.models import ChatMessage
-from danswer.dynamic_configs import get_dynamic_config_store
+from danswer.dynamic_configs.factory import get_dynamic_config_store
 from danswer.dynamic_configs.interface import ConfigNotFoundError
 from danswer.indexing.models import InferenceChunk
 from danswer.llm.interfaces import LLM
diff --git a/backend/danswer/main.py 
b/backend/danswer/main.py index ad7bb14b5c..e770cc8abb 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -51,6 +51,7 @@ from danswer.db.index_attempt import cancel_indexing_attempts_past_model from danswer.db.index_attempt import expire_index_attempts from danswer.document_index.factory import get_default_document_index +from danswer.dynamic_configs.port_configs import port_filesystem_to_postgres from danswer.llm.factory import get_default_llm from danswer.llm.utils import get_default_llm_version from danswer.search.search_nlp_models import warm_up_models @@ -168,6 +169,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}" ) + port_filesystem_to_postgres() + with Session(engine) as db_session: db_embedding_model = get_current_db_embedding_model(db_session) secondary_db_embedding_model = get_secondary_db_embedding_model(db_session) diff --git a/backend/danswer/server/danswer_api/ingestion.py b/backend/danswer/server/danswer_api/ingestion.py index 8856e20d64..7fce8d1d38 100644 --- a/backend/danswer/server/danswer_api/ingestion.py +++ b/backend/danswer/server/danswer_api/ingestion.py @@ -19,7 +19,7 @@ from danswer.db.engine import get_session from danswer.document_index.document_index_utils import get_both_index_names from danswer.document_index.factory import get_default_document_index -from danswer.dynamic_configs import get_dynamic_config_store +from danswer.dynamic_configs.factory import get_dynamic_config_store from danswer.dynamic_configs.interface import ConfigNotFoundError from danswer.indexing.embedder import DefaultIndexingEmbedder from danswer.indexing.indexing_pipeline import build_indexing_pipeline diff --git a/backend/danswer/server/manage/administrative.py b/backend/danswer/server/manage/administrative.py index dea338eeae..26e2dfd54c 100644 --- a/backend/danswer/server/manage/administrative.py +++ b/backend/danswer/server/manage/administrative.py @@ -24,7 +24,7 @@ from danswer.db.models import User from danswer.document_index.document_index_utils import get_both_index_names from danswer.document_index.factory import get_default_document_index -from danswer.dynamic_configs import get_dynamic_config_store +from danswer.dynamic_configs.factory import get_dynamic_config_store from danswer.dynamic_configs.interface import ConfigNotFoundError from danswer.llm.exceptions import GenAIDisabledException from danswer.llm.factory import get_default_llm diff --git a/backend/danswer/utils/acl.py b/backend/danswer/utils/acl.py index 268457bfdc..8fbadb3000 100644 --- a/backend/danswer/utils/acl.py +++ b/backend/danswer/utils/acl.py @@ -11,7 +11,7 @@ from danswer.document_index.factory import get_default_document_index from danswer.document_index.interfaces import UpdateRequest from danswer.document_index.vespa.index import VespaIndex -from danswer.dynamic_configs import get_dynamic_config_store +from danswer.dynamic_configs.factory import get_dynamic_config_store from danswer.dynamic_configs.interface import ConfigNotFoundError from danswer.utils.logger import setup_logger diff --git a/backend/danswer/utils/telemetry.py b/backend/danswer/utils/telemetry.py index 65e9f4709f..0a21cf66e8 100644 --- a/backend/danswer/utils/telemetry.py +++ b/backend/danswer/utils/telemetry.py @@ -6,7 +6,7 @@ import requests from danswer.configs.app_configs import DISABLE_TELEMETRY -from danswer.dynamic_configs import get_dynamic_config_store +from danswer.dynamic_configs.factory import get_dynamic_config_store from 
danswer.dynamic_configs.interface import ConfigNotFoundError
 
 CUSTOMER_UUID_KEY = "customer_uuid"
 

From 6a776648b3ba439df392e12d5e56e42115734395 Mon Sep 17 00:00:00 2001
From: Weves
Date: Tue, 19 Mar 2024 17:57:15 -0700
Subject: [PATCH 02/58] Fix LLM max tokens

---
 backend/danswer/llm/utils.py | 22 +++++++++++++++++++---
 1 file changed, 19 insertions(+), 3 deletions(-)

diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py
index b1983180ca..f36f285461 100644
--- a/backend/danswer/llm/utils.py
+++ b/backend/danswer/llm/utils.py
@@ -5,6 +5,7 @@
 from typing import Any
 from typing import cast
 
+import litellm  # type: ignore
 import tiktoken
 from langchain.prompts.base import StringPromptValue
 from langchain.prompts.chat import ChatPromptValue
@@ -15,7 +16,6 @@
 from langchain.schema.messages import BaseMessageChunk
 from langchain.schema.messages import HumanMessage
 from langchain.schema.messages import SystemMessage
-from litellm import get_max_tokens  # type: ignore
 from tiktoken.core import Encoding
 
 from danswer.configs.app_configs import LOG_LEVEL
@@ -247,12 +247,28 @@ def get_llm_max_tokens(
         return GEN_AI_MAX_TOKENS
 
     model_name = model_name or get_default_llm_version()[0]
+    # NOTE: we previously used `litellm.get_max_tokens()`, but despite the name, this actually
+    # returns the max OUTPUT tokens. Under the hood, this uses the `litellm.model_cost` dict,
+    # and there is no other interface to get what we want. This should be okay though, since the
+    # `model_cost` dict is a named public interface:
+    # https://litellm.vercel.app/docs/completion/token_usage#7-model_cost
+    litellm_model_map = litellm.model_cost
    try:
-        if model_provider == "openai":
-            return get_max_tokens(model_name)
-        return get_max_tokens("/".join([model_provider, model_name]))
+        if model_provider == "openai":
+            model_obj = litellm_model_map[model_name]
+        else:
+            model_obj = litellm_model_map[f"{model_provider}/{model_name}"]
+
+        # prefer the input context size - checking `max_tokens` first would just
+        # reproduce the output-tokens bug this change is meant to fix
+        if "max_input_tokens" in model_obj:
+            return model_obj["max_input_tokens"]
+        if "max_tokens" in model_obj:
+            return model_obj["max_tokens"]
+
+        raise RuntimeError("No max tokens found for LLM")
     except Exception:
+        logger.exception(
+            f"Failed to get max tokens for LLM with name {model_name}. Defaulting to 4096."
+        )
         return 4096
 

From d66b6c0559ed7339c1fb5f2a457190b8c9a1ac29 Mon Sep 17 00:00:00 2001
From: Yuhong Sun
Date: Thu, 21 Mar 2024 12:27:56 -0700
Subject: [PATCH 03/58] Fix Tag Document Source Enum (#1240)

---
 ...3b470d1a_remove_documentsource_from_tag.py | 36 +++++++++++++++++++
 1 file changed, 36 insertions(+)
 create mode 100644 backend/alembic/versions/91fd3b470d1a_remove_documentsource_from_tag.py

diff --git a/backend/alembic/versions/91fd3b470d1a_remove_documentsource_from_tag.py b/backend/alembic/versions/91fd3b470d1a_remove_documentsource_from_tag.py
new file mode 100644
index 0000000000..b8f1a72922
--- /dev/null
+++ b/backend/alembic/versions/91fd3b470d1a_remove_documentsource_from_tag.py
@@ -0,0 +1,36 @@
+"""Remove DocumentSource from Tag
+
+Revision ID: 91fd3b470d1a
+Revises: 173cae5bba26
+Create Date: 2024-03-21 12:05:23.956734
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from danswer.configs.constants import DocumentSource
+
+# revision identifiers, used by Alembic. 
+revision = "91fd3b470d1a" +down_revision = "173cae5bba26" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.alter_column( + "tag", + "source", + type_=sa.String(length=50), + existing_type=sa.Enum(DocumentSource, native_enum=False), + existing_nullable=False, + ) + + +def downgrade() -> None: + op.alter_column( + "tag", + "source", + type_=sa.Enum(DocumentSource, native_enum=False), + existing_type=sa.String(length=50), + existing_nullable=False, + ) From 8dbe5cbaa69a0f6658ea6ee2d9a345b03a9d2f05 Mon Sep 17 00:00:00 2001 From: Weves Date: Thu, 21 Mar 2024 18:29:22 -0700 Subject: [PATCH 04/58] Add private Persona / Document Set migration --- ...df4e935ef_private_personas_documentsets.py | 116 ++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 backend/alembic/versions/e91df4e935ef_private_personas_documentsets.py diff --git a/backend/alembic/versions/e91df4e935ef_private_personas_documentsets.py b/backend/alembic/versions/e91df4e935ef_private_personas_documentsets.py new file mode 100644 index 0000000000..2e1a3297a4 --- /dev/null +++ b/backend/alembic/versions/e91df4e935ef_private_personas_documentsets.py @@ -0,0 +1,116 @@ +"""Private Personas DocumentSets +Revision ID: e91df4e935ef +Revises: 91fd3b470d1a +Create Date: 2024-03-17 11:47:24.675881 +""" +import fastapi_users_db_sqlalchemy +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = "e91df4e935ef" +down_revision = "91fd3b470d1a" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "document_set__user", + sa.Column("document_set_id", sa.Integer(), nullable=False), + sa.Column( + "user_id", + fastapi_users_db_sqlalchemy.generics.GUID(), + nullable=False, + ), + sa.ForeignKeyConstraint( + ["document_set_id"], + ["document_set.id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + ), + sa.PrimaryKeyConstraint("document_set_id", "user_id"), + ) + op.create_table( + "persona__user", + sa.Column("persona_id", sa.Integer(), nullable=False), + sa.Column( + "user_id", + fastapi_users_db_sqlalchemy.generics.GUID(), + nullable=False, + ), + sa.ForeignKeyConstraint( + ["persona_id"], + ["persona.id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + ), + sa.PrimaryKeyConstraint("persona_id", "user_id"), + ) + op.create_table( + "document_set__user_group", + sa.Column("document_set_id", sa.Integer(), nullable=False), + sa.Column( + "user_group_id", + sa.Integer(), + nullable=False, + ), + sa.ForeignKeyConstraint( + ["document_set_id"], + ["document_set.id"], + ), + sa.ForeignKeyConstraint( + ["user_group_id"], + ["user_group.id"], + ), + sa.PrimaryKeyConstraint("document_set_id", "user_group_id"), + ) + op.create_table( + "persona__user_group", + sa.Column("persona_id", sa.Integer(), nullable=False), + sa.Column( + "user_group_id", + sa.Integer(), + nullable=False, + ), + sa.ForeignKeyConstraint( + ["persona_id"], + ["persona.id"], + ), + sa.ForeignKeyConstraint( + ["user_group_id"], + ["user_group.id"], + ), + sa.PrimaryKeyConstraint("persona_id", "user_group_id"), + ) + + op.add_column( + "document_set", + sa.Column("is_public", sa.Boolean(), nullable=True), + ) + # fill in is_public for existing rows + op.execute("UPDATE document_set SET is_public = true WHERE is_public IS NULL") + op.alter_column("document_set", "is_public", nullable=False) + + op.add_column( + "persona", + sa.Column("is_public", sa.Boolean(), nullable=True), + ) + # fill in is_public for existing rows + 
op.execute("UPDATE persona SET is_public = true WHERE is_public IS NULL") + op.alter_column("persona", "is_public", nullable=False) + + +def downgrade() -> None: + op.drop_column("persona", "is_public") + + op.drop_column("document_set", "is_public") + + op.drop_table("persona__user") + op.drop_table("document_set__user") + op.drop_table("persona__user_group") + op.drop_table("document_set__user_group") From c28a95e3678d95c5557f960ea72105864becc5c0 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Thu, 21 Mar 2024 20:10:08 -0700 Subject: [PATCH 05/58] Port File Store from Volume to PG (#1241) --- .../versions/4738e4b3bae1_pg_file_store.py | 28 ++++++ ...df4e935ef_private_personas_documentsets.py | 2 + backend/danswer/background/celery/celery.py | 4 - .../cross_connector_utils/file_utils.py | 22 ++--- backend/danswer/connectors/file/connector.py | 78 ++++++++------- .../connectors/google_drive/connector.py | 2 +- .../connectors/google_site/connector.py | 10 +- backend/danswer/connectors/web/connector.py | 4 + backend/danswer/db/file_store.py | 96 +++++++++++++++++++ backend/danswer/db/models.py | 6 ++ backend/danswer/db/pg_file_store.py | 93 ++++++++++++++++++ backend/danswer/server/documents/connector.py | 19 ++-- .../danswer/server/manage/administrative.py | 8 ++ .../danswer/utils/variable_functionality.py | 2 - 14 files changed, 315 insertions(+), 59 deletions(-) create mode 100644 backend/alembic/versions/4738e4b3bae1_pg_file_store.py create mode 100644 backend/danswer/db/file_store.py create mode 100644 backend/danswer/db/pg_file_store.py diff --git a/backend/alembic/versions/4738e4b3bae1_pg_file_store.py b/backend/alembic/versions/4738e4b3bae1_pg_file_store.py new file mode 100644 index 0000000000..a57102dbe9 --- /dev/null +++ b/backend/alembic/versions/4738e4b3bae1_pg_file_store.py @@ -0,0 +1,28 @@ +"""PG File Store + +Revision ID: 4738e4b3bae1 +Revises: e91df4e935ef +Create Date: 2024-03-20 18:53:32.461518 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. 
+revision = "4738e4b3bae1" +down_revision = "e91df4e935ef" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "file_store", + sa.Column("file_name", sa.String(), nullable=False), + sa.Column("lobj_oid", sa.Integer(), nullable=False), + sa.PrimaryKeyConstraint("file_name"), + ) + + +def downgrade() -> None: + op.drop_table("file_store") diff --git a/backend/alembic/versions/e91df4e935ef_private_personas_documentsets.py b/backend/alembic/versions/e91df4e935ef_private_personas_documentsets.py index 2e1a3297a4..c18084563d 100644 --- a/backend/alembic/versions/e91df4e935ef_private_personas_documentsets.py +++ b/backend/alembic/versions/e91df4e935ef_private_personas_documentsets.py @@ -1,7 +1,9 @@ """Private Personas DocumentSets + Revision ID: e91df4e935ef Revises: 91fd3b470d1a Create Date: 2024-03-17 11:47:24.675881 + """ import fastapi_users_db_sqlalchemy from alembic import op diff --git a/backend/danswer/background/celery/celery.py b/backend/danswer/background/celery/celery.py index 80a8a2a135..fbd823f224 100644 --- a/backend/danswer/background/celery/celery.py +++ b/backend/danswer/background/celery/celery.py @@ -226,8 +226,4 @@ def clean_old_temp_files_task( "task": "check_for_document_sets_sync_task", "schedule": timedelta(seconds=5), }, - "clean-old-temp-files": { - "task": "clean_old_temp_files_task", - "schedule": timedelta(minutes=30), - }, } diff --git a/backend/danswer/connectors/cross_connector_utils/file_utils.py b/backend/danswer/connectors/cross_connector_utils/file_utils.py index b0a9c723fe..c7f662d9af 100644 --- a/backend/danswer/connectors/cross_connector_utils/file_utils.py +++ b/backend/danswer/connectors/cross_connector_utils/file_utils.py @@ -2,8 +2,7 @@ import os import re import zipfile -from collections.abc import Generator -from pathlib import Path +from collections.abc import Iterator from typing import Any from typing import IO @@ -78,11 +77,11 @@ def is_macos_resource_fork_file(file_name: str) -> bool: # to the zip file. 
This file should contain a list of objects with the following format: # [{ "filename": "file1.txt", "link": "https://example.com/file1.txt" }] def load_files_from_zip( - zip_location: str | Path, + zip_file_io: IO, ignore_macos_resource_fork_files: bool = True, ignore_dirs: bool = True, -) -> Generator[tuple[zipfile.ZipInfo, IO[Any], dict[str, Any]], None, None]: - with zipfile.ZipFile(zip_location, "r") as zip_file: +) -> Iterator[tuple[zipfile.ZipInfo, IO[Any], dict[str, Any]]]: + with zipfile.ZipFile(zip_file_io, "r") as zip_file: zip_metadata = {} try: metadata_file_info = zip_file.getinfo(".danswer_metadata.json") @@ -109,18 +108,19 @@ def load_files_from_zip( yield file_info, file, zip_metadata.get(file_info.filename, {}) -def detect_encoding(file_path: str | Path) -> str: - with open(file_path, "rb") as file: - raw_data = file.read(50000) # Read a portion of the file to guess encoding - return chardet.detect(raw_data)["encoding"] or "utf-8" +def detect_encoding(file: IO[bytes]) -> str: + raw_data = file.read(50000) + encoding = chardet.detect(raw_data)["encoding"] or "utf-8" + file.seek(0) + return encoding def read_file( - file_reader: IO[Any], encoding: str = "utf-8", errors: str = "replace" + file: IO, encoding: str = "utf-8", errors: str = "replace" ) -> tuple[str, dict]: metadata = {} file_content_raw = "" - for ind, line in enumerate(file_reader): + for ind, line in enumerate(file): try: line = line.decode(encoding) if isinstance(line, bytes) else line except UnicodeDecodeError: diff --git a/backend/danswer/connectors/file/connector.py b/backend/danswer/connectors/file/connector.py index f6aeef649e..fa290a4969 100644 --- a/backend/danswer/connectors/file/connector.py +++ b/backend/danswer/connectors/file/connector.py @@ -1,11 +1,13 @@ import os -from collections.abc import Generator +from collections.abc import Iterator from datetime import datetime from datetime import timezone from pathlib import Path from typing import Any from typing import IO +from sqlalchemy.orm import Session + from danswer.configs.app_configs import INDEX_BATCH_SIZE from danswer.configs.constants import DocumentSource from danswer.connectors.cross_connector_utils.file_utils import detect_encoding @@ -20,37 +22,40 @@ from danswer.connectors.models import BasicExpertInfo from danswer.connectors.models import Document from danswer.connectors.models import Section +from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.file_store import get_default_file_store from danswer.utils.logger import setup_logger logger = setup_logger() -def _open_files_at_location( - file_path: str | Path, -) -> Generator[tuple[str, IO[Any], dict[str, Any]], Any, None]: - extension = get_file_ext(file_path) +def _read_files_and_metadata( + file_name: str, + db_session: Session, +) -> Iterator[tuple[str, IO, dict[str, Any]]]: + """Reads the file into IO, in the case of a zip file, yields each individual + file contained within, also includes the metadata dict if packaged in the zip""" + extension = get_file_ext(file_name) metadata: dict[str, Any] = {} + directory_path = os.path.dirname(file_name) + + file_content = get_default_file_store(db_session).read_file(file_name, mode="b") if extension == ".zip": for file_info, file, metadata in load_files_from_zip( - file_path, ignore_dirs=True + file_content, ignore_dirs=True ): - yield file_info.filename, file, metadata - elif extension in [".txt", ".md", ".mdx"]: - encoding = detect_encoding(file_path) - with open(file_path, "r", encoding=encoding, errors="replace") as 
file: - yield os.path.basename(file_path), file, metadata - elif extension == ".pdf": - with open(file_path, "rb") as file: - yield os.path.basename(file_path), file, metadata + yield os.path.join(directory_path, file_info.filename), file, metadata + elif extension in [".txt", ".md", ".mdx", ".pdf"]: + yield file_name, file_content, metadata else: - logger.warning(f"Skipping file '{file_path}' with extension '{extension}'") + logger.warning(f"Skipping file '{file_name}' with extension '{extension}'") def _process_file( file_name: str, file: IO[Any], - metadata: dict[str, Any] = {}, + metadata: dict[str, Any] | None = None, pdf_pass: str | None = None, ) -> list[Document]: extension = get_file_ext(file_name) @@ -65,8 +70,9 @@ def _process_file( file=file, file_name=file_name, pdf_pass=pdf_pass ) else: - file_content_raw, file_metadata = read_file(file) - all_metadata = {**metadata, **file_metadata} + encoding = detect_encoding(file) + file_content_raw, file_metadata = read_file(file, encoding=encoding) + all_metadata = {**metadata, **file_metadata} if metadata else file_metadata # If this is set, we will show this in the UI as the "name" of the file file_display_name_override = all_metadata.get("file_display_name") @@ -114,7 +120,8 @@ def _process_file( Section(link=all_metadata.get("link"), text=file_content_raw.strip()) ], source=DocumentSource.FILE, - semantic_identifier=file_display_name_override or file_name, + semantic_identifier=file_display_name_override + or os.path.basename(file_name), doc_updated_at=final_time_updated, primary_owners=p_owners, secondary_owners=s_owners, @@ -140,24 +147,27 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None def load_from_state(self) -> GenerateDocumentsOutput: documents: list[Document] = [] - for file_location in self.file_locations: - current_datetime = datetime.now(timezone.utc) - files = _open_files_at_location(file_location) - - for file_name, file, metadata in files: - metadata["time_updated"] = metadata.get( - "time_updated", current_datetime - ) - documents.extend( - _process_file(file_name, file, metadata, self.pdf_pass) + with Session(get_sqlalchemy_engine()) as db_session: + for file_path in self.file_locations: + current_datetime = datetime.now(timezone.utc) + files = _read_files_and_metadata( + file_name=str(file_path), db_session=db_session ) - if len(documents) >= self.batch_size: - yield documents - documents = [] + for file_name, file, metadata in files: + metadata["time_updated"] = metadata.get( + "time_updated", current_datetime + ) + documents.extend( + _process_file(file_name, file, metadata, self.pdf_pass) + ) + + if len(documents) >= self.batch_size: + yield documents + documents = [] - if documents: - yield documents + if documents: + yield documents if __name__ == "__main__": diff --git a/backend/danswer/connectors/google_drive/connector.py b/backend/danswer/connectors/google_drive/connector.py index 15c9894a65..ea7ef60db7 100644 --- a/backend/danswer/connectors/google_drive/connector.py +++ b/backend/danswer/connectors/google_drive/connector.py @@ -388,7 +388,7 @@ def _process_folder_paths( def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None: """Checks for two different types of credentials. - (1) A credential which holds a token acquired via a user going thorugh + (1) A credential which holds a token acquired via a user going thorough the Google OAuth flow. 
(2) A credential which holds a service account key JSON file, which can then be used to impersonate any user in the workspace. diff --git a/backend/danswer/connectors/google_site/connector.py b/backend/danswer/connectors/google_site/connector.py index 2a2be5ebe3..38d6e0b143 100644 --- a/backend/danswer/connectors/google_site/connector.py +++ b/backend/danswer/connectors/google_site/connector.py @@ -5,6 +5,7 @@ from bs4 import BeautifulSoup from bs4 import Tag +from sqlalchemy.orm import Session from danswer.configs.app_configs import INDEX_BATCH_SIZE from danswer.configs.constants import DocumentSource @@ -15,6 +16,8 @@ from danswer.connectors.interfaces import LoadConnector from danswer.connectors.models import Document from danswer.connectors.models import Section +from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.file_store import get_default_file_store from danswer.utils.logger import setup_logger logger = setup_logger() @@ -66,8 +69,13 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None def load_from_state(self) -> GenerateDocumentsOutput: documents: list[Document] = [] + with Session(get_sqlalchemy_engine()) as db_session: + file_content_io = get_default_file_store(db_session).read_file( + self.zip_path, mode="b" + ) + # load the HTML files - files = load_files_from_zip(self.zip_path) + files = load_files_from_zip(file_content_io) count = 0 for file_info, file_io, _metadata in files: # skip non-published files diff --git a/backend/danswer/connectors/web/connector.py b/backend/danswer/connectors/web/connector.py index 8acfaca425..4415a88fab 100644 --- a/backend/danswer/connectors/web/connector.py +++ b/backend/danswer/connectors/web/connector.py @@ -149,6 +149,10 @@ def __init__( self.to_visit_list = extract_urls_from_sitemap(_ensure_valid_url(base_url)) elif web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.UPLOAD: + logger.warning( + "This is not a UI supported Web Connector flow, " + "are you sure you want to do this?" + ) self.to_visit_list = _read_urls_file(base_url) else: diff --git a/backend/danswer/db/file_store.py b/backend/danswer/db/file_store.py new file mode 100644 index 0000000000..f0a44bf5da --- /dev/null +++ b/backend/danswer/db/file_store.py @@ -0,0 +1,96 @@ +from abc import ABC +from abc import abstractmethod +from typing import IO + +from sqlalchemy.orm import Session + +from danswer.db.pg_file_store import create_populate_lobj +from danswer.db.pg_file_store import delete_lobj_by_id +from danswer.db.pg_file_store import delete_pgfilestore_by_file_name +from danswer.db.pg_file_store import get_pgfilestore_by_file_name +from danswer.db.pg_file_store import read_lobj +from danswer.db.pg_file_store import upsert_pgfilestore + + +class FileStore(ABC): + """ + An abstraction for storing files and large binary objects. + """ + + @abstractmethod + def save_file(self, file_name: str, content: IO) -> None: + """ + Save a file to the blob store + + Parameters: + - connector_name: Name of the CC-Pair (as specified by the user in the UI) + - file_name: Name of the file to save + - content: Contents of the file + """ + raise NotImplementedError + + @abstractmethod + def read_file(self, file_name: str, mode: str | None) -> IO: + """ + Read the content of a given file by the name + + Parameters: + - file_name: Name of file to read + + Returns: + Contents of the file and metadata dict + """ + + @abstractmethod + def delete_file(self, file_name: str) -> None: + """ + Delete a file by its name. 
+
+        Parameters:
+        - file_name: Name of file to delete
+        """
+
+
+class PostgresBackedFileStore(FileStore):
+    def __init__(self, db_session: Session):
+        self.db_session = db_session
+
+    def save_file(self, file_name: str, content: IO) -> None:
+        try:
+            # The large objects in postgres are saved as special objects and can be listed with
+            # SELECT * FROM pg_largeobject_metadata;
+            obj_id = create_populate_lobj(content=content, db_session=self.db_session)
+            upsert_pgfilestore(
+                file_name=file_name, lobj_oid=obj_id, db_session=self.db_session
+            )
+            self.db_session.commit()
+        except Exception:
+            self.db_session.rollback()
+            raise
+
+    def read_file(self, file_name: str, mode: str | None = None) -> IO:
+        file_record = get_pgfilestore_by_file_name(
+            file_name=file_name, db_session=self.db_session
+        )
+        return read_lobj(
+            lobj_oid=file_record.lobj_oid, db_session=self.db_session, mode=mode
+        )
+
+    def delete_file(self, file_name: str) -> None:
+        try:
+            file_record = get_pgfilestore_by_file_name(
+                file_name=file_name, db_session=self.db_session
+            )
+            delete_lobj_by_id(file_record.lobj_oid, db_session=self.db_session)
+            delete_pgfilestore_by_file_name(
+                file_name=file_name, db_session=self.db_session
+            )
+            self.db_session.commit()
+        except Exception:
+            self.db_session.rollback()
+            raise
+
+
+def get_default_file_store(db_session: Session) -> FileStore:
+    # The only supported file store now is the Postgres File Store
+    return PostgresBackedFileStore(db_session=db_session)
diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py
index 2405ffed6d..d0628a795c 100644
--- a/backend/danswer/db/models.py
+++ b/backend/danswer/db/models.py
@@ -859,3 +859,9 @@ class KVStore(Base):
 
     key: Mapped[str] = mapped_column(String, primary_key=True)
     value: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=False)
+
+
+class PGFileStore(Base):
+    __tablename__ = "file_store"
+    file_name = mapped_column(String, primary_key=True)
+    lobj_oid = mapped_column(Integer, nullable=False)
diff --git a/backend/danswer/db/pg_file_store.py b/backend/danswer/db/pg_file_store.py
new file mode 100644
index 0000000000..91a57adab7
--- /dev/null
+++ b/backend/danswer/db/pg_file_store.py
@@ -0,0 +1,93 @@
+from io import BytesIO
+from typing import IO
+
+from psycopg2.extensions import connection
+from sqlalchemy.orm import Session
+
+from danswer.db.models import PGFileStore
+from danswer.utils.logger import setup_logger
+
+logger = setup_logger()
+
+
+def get_pg_conn_from_session(db_session: Session) -> connection:
+    return db_session.connection().connection.connection  # type: ignore
+
+
+def create_populate_lobj(
+    content: IO,
+    db_session: Session,
+) -> int:
+    """Note, this does not commit the changes to the DB
+    This is because the commit should happen with the PGFileStore row creation
+    That step finalizes both the Large Object and the table tracking it
+    """
+    pg_conn = get_pg_conn_from_session(db_session)
+    large_object = pg_conn.lobject()
+
+    large_object.write(content.read())
+    large_object.close()
+
+    return large_object.oid
+
+
+def read_lobj(lobj_oid: int, db_session: Session, mode: str | None = None) -> IO:
+    pg_conn = get_pg_conn_from_session(db_session)
+    large_object = (
+        pg_conn.lobject(lobj_oid, mode=mode) if mode else pg_conn.lobject(lobj_oid)
+    )
+    return BytesIO(large_object.read())
+
+
+def delete_lobj_by_id(
+    lobj_oid: int,
+    db_session: Session,
+) -> None:
+    pg_conn = get_pg_conn_from_session(db_session)
+    pg_conn.lobject(lobj_oid).unlink()
+
+
+def upsert_pgfilestore(
+    file_name: 
str, lobj_oid: int, db_session: Session, commit: bool = False
+) -> PGFileStore:
+    pgfilestore = db_session.query(PGFileStore).filter_by(file_name=file_name).first()
+
+    if pgfilestore:
+        try:
+            # This should not happen in normal execution
+            delete_lobj_by_id(lobj_oid=pgfilestore.lobj_oid, db_session=db_session)
+        except Exception:
+            # If the delete fails as well, the large object doesn't exist anyway and even if it
+            # fails to delete, it's not too terrible as most file sizes are insignificant
+            logger.error(
+                f"Failed to delete large object with oid {pgfilestore.lobj_oid}"
+            )
+
+        pgfilestore.lobj_oid = lobj_oid
+    else:
+        pgfilestore = PGFileStore(file_name=file_name, lobj_oid=lobj_oid)
+        db_session.add(pgfilestore)
+
+    if commit:
+        db_session.commit()
+
+    return pgfilestore
+
+
+def get_pgfilestore_by_file_name(
+    file_name: str,
+    db_session: Session,
+) -> PGFileStore:
+    pgfilestore = db_session.query(PGFileStore).filter_by(file_name=file_name).first()
+
+    if not pgfilestore:
+        raise RuntimeError(f"File by name {file_name} does not exist or was deleted")
+
+    return pgfilestore
+
+
+def delete_pgfilestore_by_file_name(
+    file_name: str,
+    db_session: Session,
+) -> None:
+    db_session.query(PGFileStore).filter_by(file_name=file_name).delete()
diff --git a/backend/danswer/server/documents/connector.py b/backend/danswer/server/documents/connector.py
index 8c3e50936e..ada53b29ce 100644
--- a/backend/danswer/server/documents/connector.py
+++ b/backend/danswer/server/documents/connector.py
@@ -1,3 +1,5 @@
+import os
+import uuid
 from typing import cast
 
 from fastapi import APIRouter
@@ -13,7 +15,6 @@
 from danswer.auth.users import current_user
 from danswer.background.celery.celery_utils import get_deletion_status
 from danswer.configs.constants import DocumentSource
-from danswer.connectors.file.utils import write_temp_files
 from danswer.connectors.gmail.connector_auth import delete_gmail_service_account_key
 from danswer.connectors.gmail.connector_auth import delete_google_app_gmail_cred
 from danswer.connectors.gmail.connector_auth import get_gmail_auth_url
@@ -57,6 +58,7 @@
 from danswer.db.document import get_document_cnts_for_cc_pairs
 from danswer.db.embedding_model import get_current_db_embedding_model
 from danswer.db.engine import get_session
+from danswer.db.file_store import get_default_file_store
 from danswer.db.index_attempt import cancel_indexing_attempts_for_connector
 from danswer.db.index_attempt import cancel_indexing_attempts_past_model
 from danswer.db.index_attempt import create_index_attempt
@@ -335,18 +337,23 @@ def admin_google_drive_auth(
 @router.post("/admin/connector/file/upload")
 def upload_files(
-    files: list[UploadFile], _: User = Depends(current_admin_user)
+    files: list[UploadFile],
+    _: User = Depends(current_admin_user),
+    db_session: Session = Depends(get_session),
 ) -> FileUploadResponse:
     for file in files:
         if not file.filename:
             raise HTTPException(status_code=400, detail="File name cannot be empty")
     try:
-        file_paths = write_temp_files(
-            [(cast(str, file.filename), file.file) for file in files]
-        )
+        file_store = get_default_file_store(db_session)
+        deduped_file_paths = []
+        for file in files:
+            file_path = os.path.join(str(uuid.uuid4()), cast(str, file.filename))
+            deduped_file_paths.append(file_path)
+            file_store.save_file(file_name=file_path, content=file.file)
     except ValueError as e:
         raise HTTPException(status_code=400, detail=str(e))
-    return FileUploadResponse(file_paths=file_paths)
+    return FileUploadResponse(file_paths=deduped_file_paths)
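
Taken together, the pieces this patch introduces give the following round trip. This is an illustrative sketch rather than code from the patch: the file name and payload are invented, and only the session setup mirrors what the file connector in this patch already does.

from io import BytesIO

from sqlalchemy.orm import Session

from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.file_store import get_default_file_store

with Session(get_sqlalchemy_engine()) as db_session:
    file_store = get_default_file_store(db_session)
    # save_file streams the bytes into a Postgres large object and commits a
    # file_store row mapping the name to the large object's OID
    file_store.save_file(file_name="some-uuid/notes.txt", content=BytesIO(b"hello"))
    # read_file resolves the OID by name; mode="b" opens the large object in
    # binary mode, the same way the file connector reads uploads back
    assert file_store.read_file("some-uuid/notes.txt", mode="b").read() == b"hello"
    # delete_file unlinks the large object and removes the tracking row
    file_store.delete_file("some-uuid/notes.txt")
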
@router.get("/admin/connector/indexing-status") diff --git a/backend/danswer/server/manage/administrative.py b/backend/danswer/server/manage/administrative.py index 26e2dfd54c..d3a9c4d3b7 100644 --- a/backend/danswer/server/manage/administrative.py +++ b/backend/danswer/server/manage/administrative.py @@ -11,6 +11,7 @@ from danswer.auth.users import current_admin_user from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ +from danswer.configs.constants import DocumentSource from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY from danswer.configs.constants import GEN_AI_DETECTED_MODEL from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER @@ -21,6 +22,7 @@ from danswer.db.feedback import fetch_docs_ranked_by_boost from danswer.db.feedback import update_document_boost from danswer.db.feedback import update_document_hidden +from danswer.db.file_store import get_default_file_store from danswer.db.models import User from danswer.document_index.document_index_utils import get_both_index_names from danswer.document_index.factory import get_default_document_index @@ -254,3 +256,9 @@ def create_deletion_attempt_for_connector_id( cleanup_connector_credential_pair_task.apply_async( kwargs=dict(connector_id=connector_id, credential_id=credential_id), ) + + if cc_pair.connector.source == DocumentSource.FILE: + connector = cc_pair.connector + file_store = get_default_file_store(db_session) + for file_name in connector.connector_specific_config["file_locations"]: + file_store.delete_file(file_name) diff --git a/backend/danswer/utils/variable_functionality.py b/backend/danswer/utils/variable_functionality.py index d365fdcac6..fd913e63be 100644 --- a/backend/danswer/utils/variable_functionality.py +++ b/backend/danswer/utils/variable_functionality.py @@ -3,7 +3,6 @@ from typing import Any from danswer.utils.logger import setup_logger -from danswer.utils.timing import log_function_time logger = setup_logger() @@ -23,7 +22,6 @@ def get_is_ee_version(self) -> bool: global_version = DanswerVersion() -@log_function_time(print_only=True, include_args=True) @functools.lru_cache(maxsize=128) def fetch_versioned_implementation(module: str, attribute: str) -> Any: logger.info("Fetching versioned implementation for %s.%s", module, attribute) From ec48142a2d388090a4cdc877ce2442042488d0ce Mon Sep 17 00:00:00 2001 From: Weves Date: Fri, 22 Mar 2024 15:20:12 -0700 Subject: [PATCH 06/58] Move some of the user re-work stuff to MIT repo --- backend/danswer/auth/users.py | 33 ++++++++++++++++++-------- backend/danswer/server/manage/users.py | 4 ++-- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index 31bdc41a20..975358b6cd 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -279,13 +279,32 @@ async def logout( # take care of that in `double_check_user` ourself. 
This is needed, since # we want the /me endpoint to still return a user even if they are not # yet verified, so that the frontend knows they exist -optional_valid_user = fastapi_users.current_user(active=True, optional=True) +optional_fastapi_current_user = fastapi_users.current_user(active=True, optional=True) -async def double_check_user( +async def optional_user_( request: Request, user: User | None, db_session: Session, +) -> User | None: + """NOTE: `request` and `db_session` are not used here, but are included + for the EE version of this function.""" + return user + + +async def optional_user( + request: Request, + user: User | None = Depends(optional_fastapi_current_user), + db_session: Session = Depends(get_session), +) -> User | None: + versioned_fetch_user = fetch_versioned_implementation( + "danswer.auth.users", "optional_user_" + ) + return await versioned_fetch_user(request, user, db_session) + + +async def double_check_user( + user: User | None, optional: bool = DISABLE_AUTH, ) -> User | None: if optional: @@ -307,15 +326,9 @@ async def double_check_user( async def current_user( - request: Request, - user: User | None = Depends(optional_valid_user), - db_session: Session = Depends(get_session), + user: User | None = Depends(optional_user), ) -> User | None: - double_check_user = fetch_versioned_implementation( - "danswer.auth.users", "double_check_user" - ) - user = await double_check_user(request, user, db_session) - return user + return await double_check_user(user) async def current_admin_user(user: User | None = Depends(current_user)) -> User | None: diff --git a/backend/danswer/server/manage/users.py b/backend/danswer/server/manage/users.py index 539d7212f2..635d6dbe56 100644 --- a/backend/danswer/server/manage/users.py +++ b/backend/danswer/server/manage/users.py @@ -11,7 +11,7 @@ from danswer.auth.schemas import UserRole from danswer.auth.users import current_admin_user from danswer.auth.users import current_user -from danswer.auth.users import optional_valid_user +from danswer.auth.users import optional_user from danswer.db.engine import get_session from danswer.db.engine import get_sqlalchemy_async_engine from danswer.db.models import User @@ -57,7 +57,7 @@ async def get_user_role(user: User = Depends(current_user)) -> UserRoleResponse: @router.get("/me") -def verify_user_logged_in(user: User | None = Depends(optional_valid_user)) -> UserInfo: +def verify_user_logged_in(user: User | None = Depends(optional_user)) -> UserInfo: # NOTE: this does not use `current_user` / `current_admin_user` because we don't want # to enforce user verification here - the frontend always wants to get the info about # the current user regardless of if they are currently verified From 89e72783a7105183db193eff944bacd54795d473 Mon Sep 17 00:00:00 2001 From: Weves Date: Fri, 22 Mar 2024 21:28:47 -0700 Subject: [PATCH 07/58] Add some private Persona / Document Set stuff --- backend/danswer/background/celery/celery.py | 2 +- backend/danswer/chat/load_yamls.py | 1 + backend/danswer/db/chat.py | 45 ++- backend/danswer/db/document_set.py | 83 +++- backend/danswer/db/models.py | 168 +++++++- backend/danswer/db/persona.py | 85 ++++ backend/danswer/db/slack_bot_config.py | 1 + .../server/features/document_set/api.py | 23 +- .../server/features/document_set/models.py | 21 + .../danswer/server/features/persona/api.py | 69 +--- .../danswer/server/features/persona/models.py | 10 + .../sets/DocumentSetCreationForm.tsx | 365 +++++++++++------- .../documents/sets/[documentSetId]/page.tsx | 109 ++++++ 
web/src/app/admin/documents/sets/hooks.tsx | 6 +- web/src/app/admin/documents/sets/lib.ts | 30 +- web/src/app/admin/documents/sets/new/page.tsx | 77 ++++ web/src/app/admin/documents/sets/page.tsx | 80 +--- web/src/app/admin/personas/PersonaEditor.tsx | 81 +++- web/src/app/admin/personas/interfaces.ts | 2 + web/src/app/admin/personas/lib.ts | 8 + web/src/app/admin/personas/page.tsx | 2 +- web/src/components/Bubble.tsx | 2 +- web/src/components/icons/icons.tsx | 8 + web/src/lib/constants.ts | 7 + web/src/lib/hooks.ts | 29 +- web/src/lib/types.ts | 16 + 26 files changed, 1052 insertions(+), 278 deletions(-) create mode 100644 backend/danswer/db/persona.py create mode 100644 web/src/app/admin/documents/sets/[documentSetId]/page.tsx create mode 100644 web/src/app/admin/documents/sets/new/page.tsx diff --git a/backend/danswer/background/celery/celery.py b/backend/danswer/background/celery/celery.py index fbd823f224..408f12f3a0 100644 --- a/backend/danswer/background/celery/celery.py +++ b/backend/danswer/background/celery/celery.py @@ -182,7 +182,7 @@ def check_for_document_sets_sync_task() -> None: with Session(get_sqlalchemy_engine()) as db_session: # check if any document sets are not synced document_set_info = fetch_document_sets( - db_session=db_session, include_outdated=True + user_id=None, db_session=db_session, include_outdated=True ) for document_set, _ in document_set_info: if not document_set.is_up_to_date: diff --git a/backend/danswer/chat/load_yamls.py b/backend/danswer/chat/load_yamls.py index d85def58d0..0800abb70a 100644 --- a/backend/danswer/chat/load_yamls.py +++ b/backend/danswer/chat/load_yamls.py @@ -97,6 +97,7 @@ def load_personas_from_yaml( document_sets=doc_sets, default_persona=True, shared=True, + is_public=True, db_session=db_session, ) diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index cc08003197..343912e275 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -20,10 +20,13 @@ from danswer.db.models import ChatSession from danswer.db.models import DocumentSet as DBDocumentSet from danswer.db.models import Persona +from danswer.db.models import Persona__User +from danswer.db.models import Persona__UserGroup from danswer.db.models import Prompt from danswer.db.models import SearchDoc from danswer.db.models import SearchDoc as DBSearchDoc from danswer.db.models import StarterMessage +from danswer.db.models import User__UserGroup from danswer.search.models import RecencyBiasSetting from danswer.search.models import RetrievalDocs from danswer.search.models import SavedSearchDoc @@ -35,11 +38,17 @@ def get_chat_session_by_id( - chat_session_id: int, user_id: UUID | None, db_session: Session + chat_session_id: int, + user_id: UUID | None, + db_session: Session, + include_deleted: bool = False, ) -> ChatSession: - stmt = select(ChatSession).where( - ChatSession.id == chat_session_id, ChatSession.user_id == user_id - ) + stmt = select(ChatSession).where(ChatSession.id == chat_session_id) + + # if user_id is None, assume this is an admin who should be able + # to view all chat sessions + if user_id is not None: + stmt = stmt.where(ChatSession.user_id == user_id) result = db_session.execute(stmt) chat_session = result.scalar_one_or_none() @@ -47,7 +56,7 @@ def get_chat_session_by_id( if not chat_session: raise ValueError("Invalid Chat Session ID provided") - if chat_session.deleted: + if not include_deleted and chat_session.deleted: raise ValueError("Chat session has been deleted") return chat_session @@ -468,6 +477,7 @@ def 
upsert_persona( llm_model_version_override: str | None, starter_messages: list[StarterMessage] | None, shared: bool, + is_public: bool, db_session: Session, persona_id: int | None = None, default_persona: bool = False, @@ -494,6 +504,7 @@ def upsert_persona( persona.llm_model_version_override = llm_model_version_override persona.starter_messages = starter_messages persona.deleted = False # Un-delete if previously deleted + persona.is_public = is_public # Do not delete any associations manually added unless # a new updated list is provided @@ -509,6 +520,7 @@ def upsert_persona( persona = Persona( id=persona_id, user_id=None if shared else user_id, + is_public=is_public, name=name, description=description, num_chunks=num_chunks, @@ -638,9 +650,28 @@ def get_personas( include_slack_bot_personas: bool = False, include_deleted: bool = False, ) -> Sequence[Persona]: - stmt = select(Persona) + stmt = select(Persona).distinct() if user_id is not None: - stmt = stmt.where(or_(Persona.user_id == user_id, Persona.user_id.is_(None))) + # Subquery to find all groups the user belongs to + user_groups_subquery = ( + select(User__UserGroup.user_group_id) + .where(User__UserGroup.user_id == user_id) + .subquery() + ) + + # Include personas where the user is directly related or part of a user group that has access + access_conditions = or_( + Persona.is_public == True, # noqa: E712 + Persona.id.in_( # User has access through list of users with access + select(Persona__User.persona_id).where(Persona__User.user_id == user_id) + ), + Persona.id.in_( # User is part of a group that has access + select(Persona__UserGroup.persona_id).where( + Persona__UserGroup.user_group_id.in_(user_groups_subquery) # type: ignore + ) + ), + ) + stmt = stmt.where(access_conditions) if not include_default: stmt = stmt.where(Persona.default_persona.is_(False)) diff --git a/backend/danswer/db/document_set.py b/backend/danswer/db/document_set.py index 848f508837..c3bab1e741 100644 --- a/backend/danswer/db/document_set.py +++ b/backend/danswer/db/document_set.py @@ -16,6 +16,7 @@ from danswer.db.models import DocumentSet__ConnectorCredentialPair from danswer.server.features.document_set.models import DocumentSetCreationRequest from danswer.server.features.document_set.models import DocumentSetUpdateRequest +from danswer.utils.variable_functionality import fetch_versioned_implementation def _delete_document_set_cc_pairs__no_commit( @@ -41,6 +42,12 @@ def _mark_document_set_cc_pairs_as_outdated__no_commit( row.is_current = False +def delete_document_set_privacy__no_commit( + document_set_id: int, db_session: Session +) -> None: + """No private document sets in Danswer MIT""" + + def get_document_set_by_id( db_session: Session, document_set_id: int ) -> DocumentSetDBModel | None: @@ -67,6 +74,17 @@ def get_document_sets_by_ids( ).all() +def make_doc_set_private( + document_set_id: int, + user_ids: list[UUID] | None, + group_ids: list[int] | None, + db_session: Session, +) -> None: + # May cause error if someone switches down to MIT from EE + if user_ids or group_ids: + raise NotImplementedError("Danswer MIT does not support private Document Sets") + + def insert_document_set( document_set_creation_request: DocumentSetCreationRequest, user_id: UUID | None, @@ -83,6 +101,7 @@ def insert_document_set( name=document_set_creation_request.name, description=document_set_creation_request.description, user_id=user_id, + is_public=document_set_creation_request.is_public, ) db_session.add(new_document_set_row) db_session.flush() # ensure the new 
document set gets assigned an ID @@ -96,6 +115,19 @@ def insert_document_set( for cc_pair_id in document_set_creation_request.cc_pair_ids ] db_session.add_all(ds_cc_pairs) + + versioned_private_doc_set_fn = fetch_versioned_implementation( + "danswer.db.document_set", "make_doc_set_private" + ) + + # Private Document Sets + versioned_private_doc_set_fn( + document_set_id=new_document_set_row.id, + user_ids=document_set_creation_request.users, + group_ids=document_set_creation_request.groups, + db_session=db_session, + ) + db_session.commit() except: db_session.rollback() @@ -130,6 +162,19 @@ def update_document_set( document_set_row.description = document_set_update_request.description document_set_row.is_up_to_date = False + document_set_row.is_public = document_set_update_request.is_public + + versioned_private_doc_set_fn = fetch_versioned_implementation( + "danswer.db.document_set", "make_doc_set_private" + ) + + # Private Document Sets + versioned_private_doc_set_fn( + document_set_id=document_set_row.id, + user_ids=document_set_update_request.users, + group_ids=document_set_update_request.groups, + db_session=db_session, + ) # update the attached CC pairs # first, mark all existing CC pairs as not current @@ -205,6 +250,15 @@ def mark_document_set_as_to_be_deleted( _delete_document_set_cc_pairs__no_commit( db_session=db_session, document_set_id=document_set_id ) + + # delete all private document set information + versioned_delete_private_fn = fetch_versioned_implementation( + "danswer.db.document_set", "delete_document_set_privacy__no_commit" + ) + versioned_delete_private_fn( + document_set_id=document_set_id, db_session=db_session + ) + + # mark the row as needing a sync, it will be deleted there since there # are no more relationships to cc pairs document_set_row.is_up_to_date = False @@ -248,7 +302,7 @@ def mark_cc_pair__document_set_relationships_to_be_deleted__no_commit( def fetch_document_sets( - db_session: Session, include_outdated: bool = False + user_id: UUID | None, db_session: Session, include_outdated: bool = False ) -> list[tuple[DocumentSetDBModel, list[ConnectorCredentialPair]]]: """Returns a list where each element is a tuple of: 1. The document set itself @@ -301,6 +355,31 @@ def fetch_document_sets( ] +def fetch_all_document_sets(db_session: Session) -> Sequence[DocumentSetDBModel]: + """Used for Admin UI where they should have visibility into all document sets""" + return db_session.scalars(select(DocumentSetDBModel)).all() + + +def fetch_user_document_sets( + user_id: UUID | None, db_session: Session +) -> list[tuple[DocumentSetDBModel, list[ConnectorCredentialPair]]]: + # If Auth is turned off, all document sets become visible + # document sets are not permission enforced, only for organizational purposes + # the documents themselves are permission enforced + if user_id is None: + return fetch_document_sets( + user_id=user_id, db_session=db_session, include_outdated=True + ) + + versioned_fetch_doc_sets_fn = fetch_versioned_implementation( + "danswer.db.document_set", "fetch_document_sets" + ) + + return versioned_fetch_doc_sets_fn( + user_id=user_id, db_session=db_session, include_outdated=True + ) + + def fetch_documents_for_document_set( document_set_id: int, db_session: Session, current_only: bool = True ) -> Sequence[Document]: @@ -404,6 +483,8 @@ def check_document_sets_are_public( db_session: Session, document_set_ids: list[int], ) -> bool: + """Checks if any of the CC-Pairs are Non Public (meaning that some document in this document + set is not Public)""" connector_credential_pair_ids = ( db_session.query( DocumentSet__ConnectorCredentialPair.connector_credential_pair_id
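The `fetch_versioned_implementation` calls in this file are the seam between the MIT and Enterprise Edition codebases: the MIT build resolves to the no-op stubs defined here (`make_doc_set_private`, `delete_document_set_privacy__no_commit`), while an EE build can swap in real implementations at runtime. The helper itself lives in `danswer/utils/variable_functionality.py` and is not shown in this diff; a minimal sketch of how such a resolver could work, assuming EE overrides mirror the MIT module paths under an `ee.` prefix and are gated by an environment flag (both of these are assumptions for illustration, not taken from the patch):

```python
import importlib
import os
from typing import Any

# ASSUMPTION: the flag name and the "ee." module prefix are illustrative
# guesses; the real helper in danswer/utils/variable_functionality.py may
# resolve overrides differently.
_EE_ACTIVE = (
    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
)


def fetch_versioned_implementation(module: str, attribute: str) -> Any:
    """Resolve `module.attribute`, preferring an EE override when EE is active."""
    if _EE_ACTIVE:
        try:
            # e.g. "danswer.db.document_set" -> "ee.danswer.db.document_set"
            return getattr(importlib.import_module("ee." + module), attribute)
        except (ModuleNotFoundError, AttributeError):
            pass  # no EE override defined; fall back to the MIT implementation
    return getattr(importlib.import_module(module), attribute)
```

With this indirection, `insert_document_set`, `update_document_set`, and `mark_document_set_as_to_be_deleted` above stay identical in both editions; only the function that the string pair resolves to changes behavior.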
diff --git a/backend/danswer/db/models.py index d0628a795c..cea33f52ad 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -97,6 +97,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base): "ChatSession", back_populates="user" ) prompts: Mapped[List["Prompt"]] = relationship("Prompt", back_populates="user") + # Personas owned by this user personas: Mapped[List["Persona"]] = relationship("Persona", back_populates="user") @@ -141,6 +142,22 @@ class Persona__Prompt(Base): prompt_id: Mapped[int] = mapped_column(ForeignKey("prompt.id"), primary_key=True) +class Persona__User(Base): + __tablename__ = "persona__user" + + persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"), primary_key=True) + user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True) + + +class DocumentSet__User(Base): + __tablename__ = "document_set__user" + + document_set_id: Mapped[int] = mapped_column( + ForeignKey("document_set.id"), primary_key=True + ) + user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True) + + class DocumentSet__ConnectorCredentialPair(Base): __tablename__ = "document_set__connector_credential_pair" @@ -617,7 +634,7 @@ class ChatMessage(Base): document_feedbacks: Mapped[List["DocumentRetrievalFeedback"]] = relationship( "DocumentRetrievalFeedback", back_populates="chat_message" ) - search_docs = relationship( + search_docs: Mapped[list["SearchDoc"]] = relationship( "SearchDoc", secondary="chat_message__search_doc", back_populates="chat_messages", @@ -678,6 +695,9 @@ class DocumentSet(Base): user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True) # Whether changes to the document set have been propagated is_up_to_date: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + # If `False`, then the document set is not visible to users who are not explicitly + # given access to it either via the `users` or `groups` relationships + is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) connector_credential_pairs: Mapped[list[ConnectorCredentialPair]] = relationship( "ConnectorCredentialPair", @@ -690,6 +710,18 @@ class DocumentSet(Base): secondary=Persona__DocumentSet.__table__, back_populates="document_sets", ) + # Other users with access + users: Mapped[list[User]] = relationship( + "User", + secondary=DocumentSet__User.__table__, + viewonly=True, + ) + # EE only + groups: Mapped[list["UserGroup"]] = relationship( + "UserGroup", + secondary="document_set__user_group", + viewonly=True, + ) class Prompt(Base): @@ -767,6 +799,7 @@ class Persona(Base): # where lower value IDs (e.g. created earlier) are displayed first display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=None) deleted: Mapped[bool] = mapped_column(Boolean, default=False) + is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) # These are only defaults, users can select from all if desired prompts: Mapped[list[Prompt]] = relationship( @@ -780,7 +813,20 @@ class Persona(Base): secondary=Persona__DocumentSet.__table__, back_populates="personas", ) + # Owner user: Mapped[User] = relationship("User", back_populates="personas") + # Other users with access + users: Mapped[list[User]] = relationship( + "User", + secondary=Persona__User.__table__, + viewonly=True, + ) + # EE only + groups: Mapped[list["UserGroup"]] = relationship( + "UserGroup", + secondary="persona__user_group", + viewonly=True, + ) # Default personas loaded via yaml cannot have the same name __table_args__ = ( @@ -865,3 +911,123 @@ class PGFileStore(Base): __tablename__ = "file_store" file_name = mapped_column(String, primary_key=True) lobj_oid = mapped_column(Integer, nullable=False) + + +""" +************************************************************************ +Enterprise Edition Models +************************************************************************ + +These models are only used in Enterprise Edition features in Danswer. +They are kept here to simplify the codebase and avoid having different assumptions +on the shape of data being passed around between the MIT and EE versions of Danswer. + +In the MIT version of Danswer, assume these tables are always empty.
+""" + + +class SamlAccount(Base): + __tablename__ = "saml" + + id: Mapped[int] = mapped_column(primary_key=True) + user_id: Mapped[int] = mapped_column(ForeignKey("user.id"), unique=True) + encrypted_cookie: Mapped[str] = mapped_column(Text, unique=True) + expires_at: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True)) + updated_at: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now(), onupdate=func.now() + ) + + user: Mapped[User] = relationship("User") + + +class User__UserGroup(Base): + __tablename__ = "user__user_group" + + user_group_id: Mapped[int] = mapped_column( + ForeignKey("user_group.id"), primary_key=True + ) + user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True) + + +class UserGroup__ConnectorCredentialPair(Base): + __tablename__ = "user_group__connector_credential_pair" + + user_group_id: Mapped[int] = mapped_column( + ForeignKey("user_group.id"), primary_key=True + ) + cc_pair_id: Mapped[int] = mapped_column( + ForeignKey("connector_credential_pair.id"), primary_key=True + ) + # if `True`, then is part of the current state of the UserGroup + # if `False`, then is a part of the prior state of the UserGroup + # rows with `is_current=False` should be deleted when the UserGroup + # is updated and should not exist for a given UserGroup if + # `UserGroup.is_up_to_date == True` + is_current: Mapped[bool] = mapped_column( + Boolean, + default=True, + primary_key=True, + ) + + cc_pair: Mapped[ConnectorCredentialPair] = relationship( + "ConnectorCredentialPair", + ) + + +class Persona__UserGroup(Base): + __tablename__ = "persona__user_group" + + persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"), primary_key=True) + user_group_id: Mapped[int] = mapped_column( + ForeignKey("user_group.id"), primary_key=True + ) + + +class DocumentSet__UserGroup(Base): + __tablename__ = "document_set__user_group" + + document_set_id: Mapped[int] = mapped_column( + ForeignKey("document_set.id"), primary_key=True + ) + user_group_id: Mapped[int] = mapped_column( + ForeignKey("user_group.id"), primary_key=True + ) + + +class UserGroup(Base): + __tablename__ = "user_group" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String, unique=True) + # whether or not changes to the UserGroup have been propagated to Vespa + is_up_to_date: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + # tell the sync job to clean up the group + is_up_for_deletion: Mapped[bool] = mapped_column( + Boolean, nullable=False, default=False + ) + + users: Mapped[list[User]] = relationship( + "User", + secondary=User__UserGroup.__table__, + ) + cc_pairs: Mapped[list[ConnectorCredentialPair]] = relationship( + "ConnectorCredentialPair", + secondary=UserGroup__ConnectorCredentialPair.__table__, + viewonly=True, + ) + cc_pair_relationships: Mapped[ + list[UserGroup__ConnectorCredentialPair] + ] = relationship( + "UserGroup__ConnectorCredentialPair", + viewonly=True, + ) + personas: Mapped[list[Persona]] = relationship( + "Persona", + secondary=Persona__UserGroup.__table__, + viewonly=True, + ) + document_sets: Mapped[list[DocumentSet]] = relationship( + "DocumentSet", + secondary=DocumentSet__UserGroup.__table__, + viewonly=True, + ) diff --git a/backend/danswer/db/persona.py b/backend/danswer/db/persona.py new file mode 100644 index 0000000000..38351b18b0 --- /dev/null +++ b/backend/danswer/db/persona.py @@ -0,0 +1,85 @@ +from uuid import UUID + +from fastapi import HTTPException 
+from sqlalchemy.orm import Session + +from danswer.db.chat import get_prompts_by_ids +from danswer.db.chat import upsert_persona +from danswer.db.document_set import get_document_sets_by_ids +from danswer.db.models import User +from danswer.server.features.persona.models import CreatePersonaRequest +from danswer.server.features.persona.models import PersonaSnapshot +from danswer.utils.logger import setup_logger +from danswer.utils.variable_functionality import fetch_versioned_implementation + +logger = setup_logger() + + +def make_persona_private( + persona_id: int, + user_ids: list[UUID] | None, + group_ids: list[int] | None, + db_session: Session, +) -> None: + # May cause error if someone switches down to MIT from EE + if user_ids or group_ids: + raise NotImplementedError("Danswer MIT does not support private Personas") + + +def create_update_persona( + persona_id: int | None, + create_persona_request: CreatePersonaRequest, + user: User | None, + db_session: Session, +) -> PersonaSnapshot: + user_id = user.id if user is not None else None + + # Permission to actually use these is checked later + document_sets = list( + get_document_sets_by_ids( + document_set_ids=create_persona_request.document_set_ids, + db_session=db_session, + ) + ) + prompts = list( + get_prompts_by_ids( + prompt_ids=create_persona_request.prompt_ids, + db_session=db_session, + ) + ) + + try: + persona = upsert_persona( + persona_id=persona_id, + user_id=user_id, + name=create_persona_request.name, + description=create_persona_request.description, + num_chunks=create_persona_request.num_chunks, + llm_relevance_filter=create_persona_request.llm_relevance_filter, + llm_filter_extraction=create_persona_request.llm_filter_extraction, + recency_bias=create_persona_request.recency_bias, + prompts=prompts, + document_sets=document_sets, + llm_model_version_override=create_persona_request.llm_model_version_override, + starter_messages=create_persona_request.starter_messages, + shared=create_persona_request.shared, + is_public=create_persona_request.is_public, + db_session=db_session, + ) + + versioned_make_persona_private = fetch_versioned_implementation( + "danswer.db.persona", "make_persona_private" + ) + + # Privatize Persona + versioned_make_persona_private( + persona_id=persona.id, + user_ids=create_persona_request.users, + group_ids=create_persona_request.groups, + db_session=db_session, + ) + + except ValueError as e: + logger.exception("Failed to create persona") + raise HTTPException(status_code=400, detail=str(e)) + return PersonaSnapshot.from_model(persona)
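The MIT stub above simply rejects non-empty access lists. For contrast, a hypothetical EE-side override of `make_persona_private` (not part of this patch) could populate the `Persona__User` and `Persona__UserGroup` share tables added in `models.py`; a sketch under that assumption:

```python
# Hypothetical EE override, shown for illustration only.
from uuid import UUID

from sqlalchemy.orm import Session

from danswer.db.models import Persona__User
from danswer.db.models import Persona__UserGroup


def make_persona_private(
    persona_id: int,
    user_ids: list[UUID] | None,
    group_ids: list[int] | None,
    db_session: Session,
) -> None:
    # replace the old access lists with the newly requested ones
    db_session.query(Persona__User).filter(
        Persona__User.persona_id == persona_id
    ).delete(synchronize_session="fetch")
    db_session.query(Persona__UserGroup).filter(
        Persona__UserGroup.persona_id == persona_id
    ).delete(synchronize_session="fetch")

    for user_id in user_ids or []:
        db_session.add(Persona__User(persona_id=persona_id, user_id=user_id))
    for group_id in group_ids or []:
        db_session.add(
            Persona__UserGroup(persona_id=persona_id, user_group_id=group_id)
        )
    # commit is left to the caller, matching create_update_persona above
```

Rows written this way would then satisfy the `Persona__User` / `Persona__UserGroup` access conditions in the updated `get_personas` query in `chat.py`.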
diff --git a/backend/danswer/db/slack_bot_config.py index f2aeae7b31..3e93a76cf6 100644 --- a/backend/danswer/db/slack_bot_config.py +++ b/backend/danswer/db/slack_bot_config.py @@ -62,6 +62,7 @@ def create_slack_bot_persona( llm_model_version_override=None, starter_messages=None, shared=True, + is_public=True, default_persona=False, db_session=db_session, commit=False, diff --git a/backend/danswer/server/features/document_set/api.py index a5d6040925..f939329bf9 100644 --- a/backend/danswer/server/features/document_set/api.py +++ b/backend/danswer/server/features/document_set/api.py @@ -6,7 +6,8 @@ from danswer.auth.users import current_admin_user from danswer.auth.users import current_user from danswer.db.document_set import check_document_sets_are_public -from danswer.db.document_set import fetch_document_sets +from danswer.db.document_set import fetch_all_document_sets +from danswer.db.document_set import fetch_user_document_sets from danswer.db.document_set import insert_document_set from danswer.db.document_set import mark_document_set_as_to_be_deleted from danswer.db.document_set import update_document_set @@ -71,15 +72,28 @@ def delete_document_set( raise HTTPException(status_code=400, detail=str(e)) +@router.get("/admin/document-set") +def list_document_sets_admin( + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> list[DocumentSet]: + return [ + DocumentSet.from_model(ds) + for ds in fetch_all_document_sets(db_session=db_session) + ] + + """Endpoints for non-admins""" @router.get("/document-set") def list_document_sets( - _: User = Depends(current_user), + user: User | None = Depends(current_user), db_session: Session = Depends(get_session), ) -> list[DocumentSet]: - document_set_info = fetch_document_sets(db_session=db_session) + document_set_info = fetch_user_document_sets( + user_id=user.id if user else None, db_session=db_session + ) return [ DocumentSet( id=document_set_db_model.id, @@ -100,6 +114,9 @@ def list_document_sets( for cc_pair in cc_pairs ], is_up_to_date=document_set_db_model.is_up_to_date, + is_public=document_set_db_model.is_public, + users=[user.id for user in document_set_db_model.users], + groups=[group.id for group in document_set_db_model.groups], ) for document_set_db_model, cc_pairs in document_set_info ] diff --git a/backend/danswer/server/features/document_set/models.py b/backend/danswer/server/features/document_set/models.py index 2f5be4587e..05ada42c89 100644 --- a/backend/danswer/server/features/document_set/models.py +++ b/backend/danswer/server/features/document_set/models.py @@ -1,3 +1,5 @@ +from uuid import UUID + from pydantic import BaseModel from danswer.db.models import DocumentSet as DocumentSetDBModel @@ -10,15 +12,27 @@ class DocumentSetCreationRequest(BaseModel): name: str description: str cc_pair_ids: list[int] + is_public: bool + # For Private Document Sets, who should be able to access these + users: list[UUID] | None = None + groups: list[int] | None = None class DocumentSetUpdateRequest(BaseModel): id: int description: str cc_pair_ids: list[int] + is_public: bool + # For Private Document Sets, who should be able to access these + users: list[UUID] + groups: list[int] class CheckDocSetPublicRequest(BaseModel): + """Note that this does not mean that the Document Set itself is viewable by everyone. + Rather, this refers to the CC-Pairs in the Document Set, and whether every CC-Pair is public. + """ + document_set_ids: list[int] @@ -33,6 +47,10 @@ class DocumentSet(BaseModel): cc_pair_descriptors: list[ConnectorCredentialPairDescriptor] is_up_to_date: bool contains_non_public: bool + is_public: bool + # For Private Document Sets, who should be able to access these + users: list[UUID] + groups: list[int] @classmethod def from_model(cls, document_set_model: DocumentSetDBModel) -> "DocumentSet": @@ -60,4 +78,7 @@ def from_model(cls, document_set_model: DocumentSetDBModel) -> "DocumentSet": for cc_pair in document_set_model.connector_credential_pairs ], is_up_to_date=document_set_model.is_up_to_date, + is_public=document_set_model.is_public, + users=[user.id for user in document_set_model.users], + groups=[group.id for group in document_set_model.groups], ) diff --git a/backend/danswer/server/features/persona/api.py index 1606654951..8762f40b51 100644 ---
a/backend/danswer/server/features/persona/api.py +++ b/backend/danswer/server/features/persona/api.py @@ -1,6 +1,5 @@ from fastapi import APIRouter from fastapi import Depends -from fastapi import HTTPException from pydantic import BaseModel from sqlalchemy.orm import Session @@ -9,14 +8,12 @@ from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER from danswer.db.chat import get_persona_by_id from danswer.db.chat import get_personas -from danswer.db.chat import get_prompts_by_ids from danswer.db.chat import mark_persona_as_deleted from danswer.db.chat import update_all_personas_display_priority from danswer.db.chat import update_persona_visibility -from danswer.db.chat import upsert_persona -from danswer.db.document_set import get_document_sets_by_ids from danswer.db.engine import get_session from danswer.db.models import User +from danswer.db.persona import create_update_persona from danswer.llm.utils import get_default_llm_version from danswer.one_shot_answer.qa_block import build_dummy_prompt from danswer.server.features.persona.models import CreatePersonaRequest @@ -31,51 +28,6 @@ basic_router = APIRouter(prefix="/persona") -def create_update_persona( - persona_id: int | None, - create_persona_request: CreatePersonaRequest, - user: User | None, - db_session: Session, -) -> PersonaSnapshot: - user_id = user.id if user is not None else None - - # Permission to actually use these is checked later - document_sets = list( - get_document_sets_by_ids( - document_set_ids=create_persona_request.document_set_ids, - db_session=db_session, - ) - ) - prompts = list( - get_prompts_by_ids( - prompt_ids=create_persona_request.prompt_ids, - db_session=db_session, - ) - ) - - try: - persona = upsert_persona( - persona_id=persona_id, - user_id=user_id, - name=create_persona_request.name, - description=create_persona_request.description, - num_chunks=create_persona_request.num_chunks, - llm_relevance_filter=create_persona_request.llm_relevance_filter, - llm_filter_extraction=create_persona_request.llm_filter_extraction, - recency_bias=create_persona_request.recency_bias, - prompts=prompts, - document_sets=document_sets, - llm_model_version_override=create_persona_request.llm_model_version_override, - starter_messages=create_persona_request.starter_messages, - shared=create_persona_request.shared, - db_session=db_session, - ) - except ValueError as e: - logger.exception("Failed to create persona") - raise HTTPException(status_code=400, detail=str(e)) - return PersonaSnapshot.from_model(persona) - - @admin_router.post("") def create_persona( create_persona_request: CreatePersonaRequest, @@ -153,6 +105,25 @@ def delete_persona( ) +@admin_router.get("") +def list_personas_admin( + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), + include_deleted: bool = False, +) -> list[PersonaSnapshot]: + return [ + PersonaSnapshot.from_model(persona) + for persona in get_personas( + db_session=db_session, + user_id=None, # user_id = None -> give back all personas + include_deleted=include_deleted, + ) + ] + + +"""Endpoints for all""" + + @basic_router.get("") def list_personas( user: User | None = Depends(current_user), diff --git a/backend/danswer/server/features/persona/models.py b/backend/danswer/server/features/persona/models.py index 4a36ad7091..a724ac5f3e 100644 --- a/backend/danswer/server/features/persona/models.py +++ b/backend/danswer/server/features/persona/models.py @@ -1,3 +1,5 @@ +from uuid import UUID + from pydantic import BaseModel from danswer.db.models 
import Persona @@ -13,12 +15,16 @@ class CreatePersonaRequest(BaseModel): shared: bool num_chunks: float llm_relevance_filter: bool + is_public: bool llm_filter_extraction: bool recency_bias: RecencyBiasSetting prompt_ids: list[int] document_set_ids: list[int] llm_model_version_override: str | None = None starter_messages: list[StarterMessage] | None = None + # For Private Personas, who should be able to access these + users: list[UUID] | None = None + groups: list[int] | None = None class PersonaSnapshot(BaseModel): @@ -26,6 +32,7 @@ class PersonaSnapshot(BaseModel): name: str shared: bool is_visible: bool + is_public: bool display_priority: int | None description: str num_chunks: float | None @@ -36,6 +43,7 @@ class PersonaSnapshot(BaseModel): default_persona: bool prompts: list[PromptSnapshot] document_sets: list[DocumentSet] + groups: list[int] @classmethod def from_model(cls, persona: Persona) -> "PersonaSnapshot": @@ -47,6 +55,7 @@ def from_model(cls, persona: Persona) -> "PersonaSnapshot": name=persona.name, shared=persona.user_id is None, is_visible=persona.is_visible, + is_public=persona.is_public, display_priority=persona.display_priority, description=persona.description, num_chunks=persona.num_chunks, @@ -60,6 +69,7 @@ def from_model(cls, persona: Persona) -> "PersonaSnapshot": DocumentSet.from_model(document_set_model) for document_set_model in persona.document_sets ], + groups=[user_group.id for user_group in persona.groups], ) diff --git a/web/src/app/admin/documents/sets/DocumentSetCreationForm.tsx b/web/src/app/admin/documents/sets/DocumentSetCreationForm.tsx index abc3c87359..e8ef546ae6 100644 --- a/web/src/app/admin/documents/sets/DocumentSetCreationForm.tsx +++ b/web/src/app/admin/documents/sets/DocumentSetCreationForm.tsx @@ -2,13 +2,19 @@ import { ArrayHelpers, FieldArray, Form, Formik } from "formik"; import * as Yup from "yup"; import { PopupSpec } from "@/components/admin/connectors/Popup"; import { createDocumentSet, updateDocumentSet } from "./lib"; -import { ConnectorIndexingStatus, DocumentSet } from "@/lib/types"; -import { TextFormField } from "@/components/admin/connectors/Field"; +import { ConnectorIndexingStatus, DocumentSet, UserGroup } from "@/lib/types"; +import { + BooleanFormField, + TextFormField, +} from "@/components/admin/connectors/Field"; import { ConnectorTitle } from "@/components/admin/connectors/ConnectorTitle"; -import { Button } from "@tremor/react"; +import { Button, Divider, Text } from "@tremor/react"; +import { EE_ENABLED } from "@/lib/constants"; +import { FiUsers } from "react-icons/fi"; interface SetCreationPopupProps { ccPairs: ConnectorIndexingStatus[]; + userGroups: UserGroup[] | undefined; onClose: () => void; setPopup: (popupSpec: PopupSpec | null) => void; existingDocumentSet?: DocumentSet; @@ -16,6 +22,7 @@ interface SetCreationPopupProps { export const DocumentSetCreationForm = ({ ccPairs, + userGroups, onClose, setPopup, existingDocumentSet, @@ -24,107 +31,182 @@ export const DocumentSetCreationForm = ({ return (
-
{ + return ccPairDescriptor.id; + } + ) + : ([] as number[]), + is_public: existingDocumentSet ? existingDocumentSet.is_public : true, + users: existingDocumentSet ? existingDocumentSet.users : [], + groups: existingDocumentSet ? existingDocumentSet.groups : [], + }} + validationSchema={Yup.object().shape({ + name: Yup.string().required("Please enter a name for the set"), + description: Yup.string().required( + "Please enter a description for the set" + ), + cc_pair_ids: Yup.array() + .of(Yup.number().required()) + .required("Please select at least one connector"), + })} + onSubmit={async (values, formikHelpers) => { + formikHelpers.setSubmitting(true); + // If the document set is public, then we don't want to send any groups + const processedValues = { + ...values, + groups: values.is_public ? [] : values.groups, + }; + + let response; + if (isUpdate) { + response = await updateDocumentSet({ + id: existingDocumentSet.id, + ...processedValues, + }); + } else { + response = await createDocumentSet(processedValues); + } + formikHelpers.setSubmitting(false); + if (response.ok) { + setPopup({ + message: isUpdate + ? "Successfully updated document set!" + : "Successfully created document set!", + type: "success", + }); + onClose(); + } else { + const errorMsg = await response.text(); + setPopup({ + message: isUpdate + ? `Error updating document set - ${errorMsg}` + : `Error creating document set - ${errorMsg}`, + type: "error", + }); + } + }} > -
event.stopPropagation()} - > - { - return ccPairDescriptor.id; - } - ) - : ([] as number[]), - }} - validationSchema={Yup.object().shape({ - name: Yup.string().required("Please enter a name for the set"), - description: Yup.string().required( - "Please enter a description for the set" - ), - ccPairIds: Yup.array() - .of(Yup.number().required()) - .required("Please select at least one connector"), - })} - onSubmit={async (values, formikHelpers) => { - formikHelpers.setSubmitting(true); - let response; - if (isUpdate) { - response = await updateDocumentSet({ - id: existingDocumentSet.id, - ...values, - }); - } else { - response = await createDocumentSet(values); - } - formikHelpers.setSubmitting(false); - if (response.ok) { - setPopup({ - message: isUpdate - ? "Successfully updated document set!" - : "Successfully created document set!", - type: "success", - }); - onClose(); - } else { - const errorMsg = await response.text(); - setPopup({ - message: isUpdate - ? `Error updating document set - ${errorMsg}` - : `Error creating document set - ${errorMsg}`, - type: "error", - }); - } - }} - > - {({ isSubmitting, values }) => ( -
-

- {isUpdate - ? "Update a Document Set" - : "Create a new Document Set"} -

- - ( + + + + + + +

+ Pick your connectors: +

+

+ All documents indexed by the selected connectors will be a part of + this document set. +

+ ( +
+ {ccPairs.map((ccPair) => { + const ind = values.cc_pair_ids.indexOf(ccPair.cc_pair_id); + let isSelected = ind !== -1; + return ( +
{ + if (isSelected) { + arrayHelpers.remove(ind); + } else { + arrayHelpers.push(ccPair.cc_pair_id); + } + }} + > +
+ +
+
+ ); + })} +
+ )} + /> + + {EE_ENABLED && userGroups && userGroups.length > 0 && ( +
+ + + + If the document set is public, then it will be visible to{" "} + all users. If it is not public, then only users in + the specified groups will be able to see it. + + } /> + +

- Pick your connectors: + Groups with Access

-

- All documents indexed by the selected connectors will be a - part of this document set. -

- ( -
- {ccPairs.map((ccPair) => { - const ind = values.ccPairIds.indexOf(ccPair.cc_pair_id); - let isSelected = ind !== -1; - return ( -
+ + If any groups are specified, then this Document Set will + only be visible to the specified groups. If no groups are + specified, then the Document Set will be visible to all + users. + + ( +
+ {userGroups.map((userGroup) => { + const ind = values.groups.indexOf(userGroup.id); + let isSelected = ind !== -1; + return ( +
{ - if (isSelected) { - arrayHelpers.remove(ind); - } else { - arrayHelpers.push(ccPair.cc_pair_id); - } - }} - > -
- -
-
- ); - })} -
- )} - /> -
- -
- + (isSelected + ? " bg-background-strong" + : " hover:bg-hover") + } + onClick={() => { + if (isSelected) { + arrayHelpers.remove(ind); + } else { + arrayHelpers.push(userGroup.id); + } + }} + > +
+ {" "} + {userGroup.name} +
+
+ ); + })} +
+ )} + /> + + ) : ( + + This Document Set is public, so this does not apply. If you + want to control which user groups see this Document Set, + mark it as non-public! + + )} +
)} -
-
-
+
+ +
+ + )} +
); }; diff --git a/web/src/app/admin/documents/sets/[documentSetId]/page.tsx b/web/src/app/admin/documents/sets/[documentSetId]/page.tsx new file mode 100644 index 0000000000..464be57435 --- /dev/null +++ b/web/src/app/admin/documents/sets/[documentSetId]/page.tsx @@ -0,0 +1,109 @@ +"use client"; + +import { ErrorCallout } from "@/components/ErrorCallout"; +import { useDocumentSets } from "../hooks"; +import { + useConnectorCredentialIndexingStatus, + useUserGroups, +} from "@/lib/hooks"; +import { ThreeDotsLoader } from "@/components/Loading"; +import { AdminPageTitle } from "@/components/admin/Title"; +import { BookmarkIcon } from "@/components/icons/icons"; +import { BackButton } from "@/components/BackButton"; +import { Card } from "@tremor/react"; +import { DocumentSetCreationForm } from "../DocumentSetCreationForm"; +import { useRouter } from "next/navigation"; +import { usePopup } from "@/components/admin/connectors/Popup"; + +function Main({ documentSetId }: { documentSetId: number }) { + const router = useRouter(); + const { popup, setPopup } = usePopup(); + + const { + data: documentSets, + isLoading: isDocumentSetsLoading, + error: documentSetsError, + } = useDocumentSets(); + + const { + data: ccPairs, + isLoading: isCCPairsLoading, + error: ccPairsError, + } = useConnectorCredentialIndexingStatus(); + + // EE only + const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups(); + + if (isDocumentSetsLoading || isCCPairsLoading || userGroupsIsLoading) { + return ; + } + + if (documentSetsError || !documentSets) { + return ( + + ); + } + + if (ccPairsError || !ccPairs) { + return ( + + ); + } + + const documentSet = documentSets.find( + (documentSet) => documentSet.id === documentSetId + ); + if (!documentSet) { + return ( + + ); + } + + return ( +
+ {popup} + + } + title={documentSet.name} + /> + + + { + router.push("/admin/documents/sets"); + }} + setPopup={setPopup} + existingDocumentSet={documentSet} + /> + +
+ ); +} + +export default function Page({ + params, +}: { + params: { documentSetId: string }; +}) { + const documentSetId = parseInt(params.documentSetId); + + return ( +
+ + +
+
+ ); +} diff --git a/web/src/app/admin/documents/sets/hooks.tsx b/web/src/app/admin/documents/sets/hooks.tsx index 179e36385d..608e3f2b2a 100644 --- a/web/src/app/admin/documents/sets/hooks.tsx +++ b/web/src/app/admin/documents/sets/hooks.tsx @@ -2,12 +2,12 @@ import { errorHandlingFetcher } from "@/lib/fetcher"; import { DocumentSet } from "@/lib/types"; import useSWR, { mutate } from "swr"; -export const useDocumentSets = () => { - const url = "/api/manage/document-set"; +export function useDocumentSets() { + const url = "/api/manage/admin/document-set"; const swrResponse = useSWR(url, errorHandlingFetcher); return { ...swrResponse, refreshDocumentSets: () => mutate(url), }; -}; +} diff --git a/web/src/app/admin/documents/sets/lib.ts b/web/src/app/admin/documents/sets/lib.ts index 71ddcf8d90..2184504cc3 100644 --- a/web/src/app/admin/documents/sets/lib.ts +++ b/web/src/app/admin/documents/sets/lib.ts @@ -1,13 +1,19 @@ interface DocumentSetCreationRequest { name: string; description: string; - ccPairIds: number[]; + cc_pair_ids: number[]; + is_public: boolean; + users: string[]; + groups: number[]; } export const createDocumentSet = async ({ name, description, - ccPairIds, + cc_pair_ids, + is_public, + users, + groups, }: DocumentSetCreationRequest) => { return fetch("/api/manage/admin/document-set", { method: "POST", @@ -17,7 +23,10 @@ export const createDocumentSet = async ({ body: JSON.stringify({ name, description, - cc_pair_ids: ccPairIds, + cc_pair_ids, + is_public, + users, + groups, }), }); }; @@ -25,13 +34,19 @@ export const createDocumentSet = async ({ interface DocumentSetUpdateRequest { id: number; description: string; - ccPairIds: number[]; + cc_pair_ids: number[]; + is_public: boolean; + users: string[]; + groups: number[]; } export const updateDocumentSet = async ({ id, description, - ccPairIds, + cc_pair_ids, + is_public, + users, + groups, }: DocumentSetUpdateRequest) => { return fetch("/api/manage/admin/document-set", { method: "PATCH", @@ -41,7 +56,10 @@ export const updateDocumentSet = async ({ body: JSON.stringify({ id, description, - cc_pair_ids: ccPairIds, + cc_pair_ids, + is_public, + users, + groups, }), }); }; diff --git a/web/src/app/admin/documents/sets/new/page.tsx b/web/src/app/admin/documents/sets/new/page.tsx new file mode 100644 index 0000000000..8919d1a959 --- /dev/null +++ b/web/src/app/admin/documents/sets/new/page.tsx @@ -0,0 +1,77 @@ +"use client"; + +import { AdminPageTitle } from "@/components/admin/Title"; +import { BookmarkIcon } from "@/components/icons/icons"; +import { DocumentSetCreationForm } from "../DocumentSetCreationForm"; +import { + useConnectorCredentialIndexingStatus, + useUserGroups, +} from "@/lib/hooks"; +import { ThreeDotsLoader } from "@/components/Loading"; +import { usePopup } from "@/components/admin/connectors/Popup"; +import { Card } from "@tremor/react"; +import { BackButton } from "@/components/BackButton"; +import { ErrorCallout } from "@/components/ErrorCallout"; +import { useRouter } from "next/navigation"; +import { UserGroup } from "@/lib/types"; + +function Main() { + const { popup, setPopup } = usePopup(); + const router = useRouter(); + + const { + data: ccPairs, + isLoading: isCCPairsLoading, + error: ccPairsError, + } = useConnectorCredentialIndexingStatus(); + + // EE only + const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups(); + + if (isCCPairsLoading || userGroupsIsLoading) { + return ; + } + + if (ccPairsError || !ccPairs) { + return ( + + ); + } + + return ( + <> + {popup} + + + { + 
router.push("/admin/documents/sets"); + }} + setPopup={setPopup} + /> + + + ); +} + +const Page = () => { + return ( +
+ + + } + title="New Document Set" + /> + +
+
+ ); +}; + +export default Page; diff --git a/web/src/app/admin/documents/sets/page.tsx b/web/src/app/admin/documents/sets/page.tsx index 777aea092f..fa5486bc77 100644 --- a/web/src/app/admin/documents/sets/page.tsx +++ b/web/src/app/admin/documents/sets/page.tsx @@ -1,14 +1,8 @@ "use client"; -import { LoadingAnimation, ThreeDotsLoader } from "@/components/Loading"; +import { ThreeDotsLoader } from "@/components/Loading"; import { PageSelector } from "@/components/PageSelector"; -import { BasicTable } from "@/components/admin/connectors/BasicTable"; -import { - BookmarkIcon, - EditIcon, - InfoIcon, - TrashIcon, -} from "@/components/icons/icons"; +import { BookmarkIcon, InfoIcon } from "@/components/icons/icons"; import { Table, TableHead, @@ -24,7 +18,6 @@ import { useConnectorCredentialIndexingStatus } from "@/lib/hooks"; import { ConnectorIndexingStatus, DocumentSet } from "@/lib/types"; import { useState } from "react"; import { useDocumentSets } from "./hooks"; -import { DocumentSetCreationForm } from "./DocumentSetCreationForm"; import { ConnectorTitle } from "@/components/admin/connectors/ConnectorTitle"; import { deleteDocumentSet } from "./lib"; import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup"; @@ -37,49 +30,31 @@ import { FiEdit, } from "react-icons/fi"; import { DeleteButton } from "@/components/DeleteButton"; +import Link from "next/link"; +import { useRouter } from "next/navigation"; const numToDisplay = 50; -const EditRow = ({ - documentSet, - ccPairs, - setPopup, - refreshDocumentSets, -}: { - documentSet: DocumentSet; - ccPairs: ConnectorIndexingStatus[]; - setPopup: (popupSpec: PopupSpec | null) => void; - refreshDocumentSets: () => void; -}) => { - const [isEditPopupOpen, setEditPopupOpen] = useState(false); +const EditRow = ({ documentSet }: { documentSet: DocumentSet }) => { + const router = useRouter(); + const [isSyncingTooltipOpen, setIsSyncingTooltipOpen] = useState(false); return (
- {isEditPopupOpen && ( - { - setEditPopupOpen(false); - refreshDocumentSets(); - }} - setPopup={setPopup} - existingDocumentSet={documentSet} - /> - )} {isSyncingTooltipOpen && ( -
+
Cannot update while syncing! Wait for the sync to finish, then try again.
)}
{ if (documentSet.is_up_to_date) { - setEditPopupOpen(true); + router.push(`/admin/documents/sets/${documentSet.id}`); } }} onMouseEnter={() => { @@ -109,7 +84,6 @@ interface DocumentFeedbackTableProps { const DocumentSetTable = ({ documentSets, - ccPairs, refresh, setPopup, }: DocumentFeedbackTableProps) => { @@ -146,12 +120,7 @@ const DocumentSetTable = ({
- +
@@ -237,7 +206,6 @@ const DocumentSetTable = ({ }; const Main = () => { - const [isOpen, setIsOpen] = useState(false); const { popup, setPopup } = usePopup(); const { data: documentSets, @@ -278,14 +246,11 @@ const Main = () => {
- + + +
{documentSets.length > 0 && ( @@ -299,17 +264,6 @@ const Main = () => { /> )} - - {isOpen && ( - { - refreshDocumentSets(); - setIsOpen(false); - }} - setPopup={setPopup} - /> - )}
); }; diff --git a/web/src/app/admin/personas/PersonaEditor.tsx b/web/src/app/admin/personas/PersonaEditor.tsx index 375cd6520f..6ce77edb56 100644 --- a/web/src/app/admin/personas/PersonaEditor.tsx +++ b/web/src/app/admin/personas/PersonaEditor.tsx @@ -1,6 +1,6 @@ "use client"; -import { DocumentSet } from "@/lib/types"; +import { DocumentSet, UserGroup } from "@/lib/types"; import { Button, Divider, Text } from "@tremor/react"; import { ArrayHelpers, @@ -25,6 +25,10 @@ import { } from "@/components/admin/connectors/Field"; import { HidableSection } from "./HidableSection"; import { FiPlus, FiX } from "react-icons/fi"; +import { EE_ENABLED } from "@/lib/constants"; +import { useUserGroups } from "@/lib/hooks"; +import { Bubble } from "@/components/Bubble"; +import { GroupsIcon } from "@/components/icons/icons"; function Label({ children }: { children: string | JSX.Element }) { return ( @@ -50,6 +54,9 @@ export function PersonaEditor({ const router = useRouter(); const { popup, setPopup } = usePopup(); + // EE only + const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups(); + const [finalPrompt, setFinalPrompt] = useState(""); const [finalPromptError, setFinalPromptError] = useState(""); @@ -92,6 +99,7 @@ export function PersonaEditor({ system_prompt: existingPrompt?.system_prompt ?? "", task_prompt: existingPrompt?.task_prompt ?? "", disable_retrieval: (existingPersona?.num_chunks ?? 10) === 0, + is_public: existingPersona?.is_public ?? true, document_set_ids: existingPersona?.document_sets?.map( (documentSet) => documentSet.id @@ -103,6 +111,8 @@ export function PersonaEditor({ llm_model_version_override: existingPersona?.llm_model_version_override ?? null, starter_messages: existingPersona?.starter_messages ?? [], + // EE Only + groups: existingPersona?.groups ?? [], }} validationSchema={Yup.object() .shape({ @@ -113,6 +123,7 @@ export function PersonaEditor({ system_prompt: Yup.string(), task_prompt: Yup.string(), disable_retrieval: Yup.boolean().required(), + is_public: Yup.boolean().required(), document_set_ids: Yup.array().of(Yup.number()), num_chunks: Yup.number().max(20).nullable(), include_citations: Yup.boolean().required(), @@ -125,6 +136,8 @@ export function PersonaEditor({ message: Yup.string().required(), }) ), + // EE Only + groups: Yup.array().of(Yup.number()), }) .test( "system-prompt-or-task-prompt", @@ -163,6 +176,9 @@ export function PersonaEditor({ ? 0 : values.num_chunks || 10; + // don't set groups if marked as public + const groups = values.is_public ? [] : values.groups; + let promptResponse; let personaResponse; if (isUpdate) { @@ -171,11 +187,13 @@ export function PersonaEditor({ existingPromptId: existingPrompt?.id, ...values, num_chunks: numChunks, + groups, }); } else { [promptResponse, personaResponse] = await createPersona({ ...values, num_chunks: numChunks, + groups, }); } @@ -375,6 +393,67 @@ export function PersonaEditor({ )} + {EE_ENABLED && userGroups && ( + <> + + <> + + + {userGroups && + userGroups.length > 0 && + !values.is_public && ( +
+ + Select which User Groups should have access to + this Persona. + +
+ {userGroups.map((userGroup) => { + const isSelected = values.groups.includes( + userGroup.id + ); + return ( + { + if (isSelected) { + setFieldValue( + "groups", + values.groups.filter( + (id) => id !== userGroup.id + ) + ); + } else { + setFieldValue("groups", [ + ...values.groups, + userGroup.id, + ]); + } + }} + > +
+ +
+ {userGroup.name} +
+
+
+ ); + })} +
+
+ )} + +
+ + + )} + {llmOverrideOptions.length > 0 && defaultLLM && ( <> diff --git a/web/src/app/admin/personas/interfaces.ts b/web/src/app/admin/personas/interfaces.ts index 0d79fc744d..0e800b2cc7 100644 --- a/web/src/app/admin/personas/interfaces.ts +++ b/web/src/app/admin/personas/interfaces.ts @@ -23,6 +23,7 @@ export interface Persona { name: string; shared: boolean; is_visible: boolean; + is_public: boolean; display_priority: number | null; description: string; document_sets: DocumentSet[]; @@ -33,4 +34,5 @@ export interface Persona { llm_model_version_override?: string; starter_messages: StarterMessage[] | null; default_persona: boolean; + groups: number[]; } diff --git a/web/src/app/admin/personas/lib.ts b/web/src/app/admin/personas/lib.ts index e186149b1a..49dfbec6b3 100644 --- a/web/src/app/admin/personas/lib.ts +++ b/web/src/app/admin/personas/lib.ts @@ -8,9 +8,11 @@ interface PersonaCreationRequest { document_set_ids: number[]; num_chunks: number | null; include_citations: boolean; + is_public: boolean; llm_relevance_filter: boolean | null; llm_model_version_override: string | null; starter_messages: StarterMessage[] | null; + groups: number[]; } interface PersonaUpdateRequest { @@ -23,9 +25,11 @@ interface PersonaUpdateRequest { document_set_ids: number[]; num_chunks: number | null; include_citations: boolean; + is_public: boolean; llm_relevance_filter: boolean | null; llm_model_version_override: string | null; starter_messages: StarterMessage[] | null; + groups: number[]; } function promptNameFromPersonaName(personaName: string) { @@ -98,6 +102,8 @@ function buildPersonaAPIBody( document_set_ids, num_chunks, llm_relevance_filter, + is_public, + groups, } = creationRequest; return { @@ -107,11 +113,13 @@ function buildPersonaAPIBody( num_chunks, llm_relevance_filter, llm_filter_extraction: false, + is_public, recency_bias: "base_decay", prompt_ids: [promptId], document_set_ids, llm_model_version_override: creationRequest.llm_model_version_override, starter_messages: creationRequest.starter_messages, + groups, }; } diff --git a/web/src/app/admin/personas/page.tsx b/web/src/app/admin/personas/page.tsx index b4e42f579f..1f78d6cfdd 100644 --- a/web/src/app/admin/personas/page.tsx +++ b/web/src/app/admin/personas/page.tsx @@ -9,7 +9,7 @@ import { RobotIcon } from "@/components/icons/icons"; import { AdminPageTitle } from "@/components/admin/Title"; export default async function Page() { - const personaResponse = await fetchSS("/persona"); + const personaResponse = await fetchSS("/admin/persona"); if (!personaResponse.ok) { return ( diff --git a/web/src/components/Bubble.tsx b/web/src/components/Bubble.tsx index 316611f188..4cd1170ea9 100644 --- a/web/src/components/Bubble.tsx +++ b/web/src/components/Bubble.tsx @@ -4,7 +4,7 @@ export function Bubble({ children, }: { isSelected: boolean; - onClick: () => void; + onClick?: () => void; children: string | JSX.Element; }) { return ( diff --git a/web/src/components/icons/icons.tsx b/web/src/components/icons/icons.tsx index e5ae456cfb..e847d1dd38 100644 --- a/web/src/components/icons/icons.tsx +++ b/web/src/components/icons/icons.tsx @@ -36,6 +36,7 @@ import { FiCpu, FiInfo, FiUploadCloud, + FiUsers, } from "react-icons/fi"; import { SiBookstack } from "react-icons/si"; import Image from "next/image"; @@ -89,6 +90,13 @@ export const UsersIcon = ({ return ; }; +export const GroupsIcon = ({ + size = 16, + className = defaultTailwindCSS, +}: IconProps) => { + return ; +}; + export const GearIcon = ({ size = 16, className = defaultTailwindCSS, diff 
--git a/web/src/lib/constants.ts b/web/src/lib/constants.ts index 6d43d9c0f0..67346119c6 100644 --- a/web/src/lib/constants.ts +++ b/web/src/lib/constants.ts @@ -16,3 +16,10 @@ export const GOOGLE_DRIVE_AUTH_IS_ADMIN_COOKIE_NAME = export const SEARCH_TYPE_COOKIE_NAME = "search_type"; export const HEADER_PADDING = "pt-[64px]"; + +// NOTE: since this is a `NEXT_PUBLIC_` variable, it will be set at +// build-time +// TODO: consider moving this to an API call so that the api_server +// can be the single source of truth +export const EE_ENABLED = + process.env.NEXT_PUBLIC_EE_ENABLED?.toLowerCase() === "true"; diff --git a/web/src/lib/hooks.ts b/web/src/lib/hooks.ts index d3e75ad63d..8d65f365cb 100644 --- a/web/src/lib/hooks.ts +++ b/web/src/lib/hooks.ts @@ -3,9 +3,11 @@ import { Credential, DocumentBoostStatus, Tag, + User, + UserGroup, } from "@/lib/types"; import useSWR, { mutate, useSWRConfig } from "swr"; -import { fetcher } from "./fetcher"; +import { errorHandlingFetcher, fetcher } from "./fetcher"; import { useState } from "react"; import { DateRangePickerValue } from "@tremor/react"; import { SourceMetadata } from "./search/interfaces"; @@ -95,3 +97,28 @@ export function useFilters() { setSelectedTags, }; } + +export const useUsers = () => { + const url = "/api/manage/users"; + const swrResponse = useSWR(url, errorHandlingFetcher); + + return { + ...swrResponse, + refreshUsers: () => mutate(url), + }; +}; + +/* +EE Only APIs +*/ + +const USER_GROUP_URL = "/api/manage/admin/user-group"; + +export const useUserGroups = () => { + const swrResponse = useSWR(USER_GROUP_URL, errorHandlingFetcher); + + return { + ...swrResponse, + refreshUserGroups: () => mutate(USER_GROUP_URL), + }; +}; diff --git a/web/src/lib/types.ts b/web/src/lib/types.ts index 1f79bcb3d1..5c46af9d9d 100644 --- a/web/src/lib/types.ts +++ b/web/src/lib/types.ts @@ -42,6 +42,7 @@ export type ValidStatuses = | "in_progress" | "not_started"; export type TaskStatus = "PENDING" | "STARTED" | "SUCCESS" | "FAILURE"; +export type Feedback = "like" | "dislike"; export interface DocumentBoostStatus { document_id: string; @@ -348,6 +349,9 @@ export interface DocumentSet { description: string; cc_pair_descriptors: CCPairDescriptor[]; is_up_to_date: boolean; + is_public: boolean; + users: string[]; + groups: number[]; } export interface Tag { @@ -384,3 +388,15 @@ export interface SlackBotTokens { bot_token: string; app_token: string; } + +/* EE Only Types */ +export interface UserGroup { + id: number; + name: string; + users: User[]; + cc_pairs: CCPairDescriptor[]; + document_sets: DocumentSet[]; + personas: Persona[]; + is_up_to_date: boolean; + is_up_for_deletion: boolean; +} From aaa7b26a4db26ce4263dde5ec4901625831d8461 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Fri, 22 Mar 2024 23:01:05 -0700 Subject: [PATCH 08/58] Remove All Enums from Postgres (#1247) --- .../776b3bbe9092_remove_remaining_enums.py | 71 +++++++++++++++++++ backend/danswer/db/models.py | 32 ++++++--- 2 files changed, 93 insertions(+), 10 deletions(-) create mode 100644 backend/alembic/versions/776b3bbe9092_remove_remaining_enums.py diff --git a/backend/alembic/versions/776b3bbe9092_remove_remaining_enums.py b/backend/alembic/versions/776b3bbe9092_remove_remaining_enums.py new file mode 100644 index 0000000000..272335ca07 --- /dev/null +++ b/backend/alembic/versions/776b3bbe9092_remove_remaining_enums.py @@ -0,0 +1,71 @@ +"""Remove Remaining Enums + +Revision ID: 776b3bbe9092 +Revises: 4738e4b3bae1 +Create Date: 2024-03-22 21:34:27.629444 + +""" +from
alembic import op +import sqlalchemy as sa + +from danswer.db.models import IndexModelStatus +from danswer.search.models import RecencyBiasSetting +from danswer.search.models import SearchType + +# revision identifiers, used by Alembic. +revision = "776b3bbe9092" +down_revision = "4738e4b3bae1" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.alter_column( + "persona", + "search_type", + type_=sa.String, + existing_type=sa.Enum(SearchType, native_enum=False), + existing_nullable=False, + ) + op.alter_column( + "persona", + "recency_bias", + type_=sa.String, + existing_type=sa.Enum(RecencyBiasSetting, native_enum=False), + existing_nullable=False, + ) + + # Because the indexmodelstatus enum does not have a mapping to a string type, + # we need this workaround instead of directly changing the column type + op.add_column("embedding_model", sa.Column("temp_status", sa.String)) + op.execute("UPDATE embedding_model SET temp_status = status::text") + op.drop_column("embedding_model", "status") + op.alter_column("embedding_model", "temp_status", new_column_name="status") + + op.execute("DROP TYPE IF EXISTS searchtype") + op.execute("DROP TYPE IF EXISTS recencybiassetting") + op.execute("DROP TYPE IF EXISTS indexmodelstatus") + + +def downgrade() -> None: + op.alter_column( + "persona", + "search_type", + type_=sa.Enum(SearchType, native_enum=False), + existing_type=sa.String(length=50), + existing_nullable=False, + ) + op.alter_column( + "persona", + "recency_bias", + type_=sa.Enum(RecencyBiasSetting, native_enum=False), + existing_type=sa.String(length=50), + existing_nullable=False, + ) + op.alter_column( + "embedding_model", + "status", + type_=sa.Enum(IndexModelStatus, native_enum=False), + existing_type=sa.String(length=50), + existing_nullable=False, + )
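The core trade-off this migration codifies: with `native_enum=False`, SQLAlchemy stores enum values in a plain VARCHAR column and validates them in Python, whereas a native PostgreSQL ENUM bakes the allowed members into a database type, so adding or renaming a member later requires DDL (`ALTER TYPE`). A small illustrative sketch (the table and enum names here are invented, not from this migration):

```python
import enum

import sqlalchemy as sa


class Mood(enum.Enum):
    HAPPY = "happy"
    SAD = "sad"


metadata = sa.MetaData()
example = sa.Table(
    "example",
    metadata,
    # native: emits CREATE TYPE mood AS ENUM ('HAPPY', 'SAD') on PostgreSQL;
    # extending Mood later requires an ALTER TYPE migration
    sa.Column("native", sa.Enum(Mood)),
    # non-native: emits a plain VARCHAR; values are checked by SQLAlchemy
    # in Python, so extending Mood is a code-only change
    sa.Column("plain", sa.Enum(Mood, native_enum=False)),
)
```

This is also why `upgrade()` can finish by dropping the old native types outright once the columns have been rewritten as strings.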
diff --git a/backend/danswer/db/models.py index cea33f52ad..abe189c45e 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -242,7 +242,7 @@ class ConnectorCredentialPair(Base): DateTime(timezone=True), default=None ) last_attempt_status: Mapped[IndexingStatus | None] = mapped_column( - Enum(IndexingStatus) + Enum(IndexingStatus, native_enum=False) ) total_docs_indexed: Mapped[int] = mapped_column(Integer, default=0) @@ -309,7 +309,9 @@ class Tag(Base): id: Mapped[int] = mapped_column(primary_key=True) tag_key: Mapped[str] = mapped_column(String) tag_value: Mapped[str] = mapped_column(String) - source: Mapped[DocumentSource] = mapped_column(Enum(DocumentSource)) + source: Mapped[DocumentSource] = mapped_column( + Enum(DocumentSource, native_enum=False) + ) documents = relationship( "Document", @@ -396,7 +398,9 @@ class EmbeddingModel(Base): normalize: Mapped[bool] = mapped_column(Boolean) query_prefix: Mapped[str] = mapped_column(String) passage_prefix: Mapped[str] = mapped_column(String) - status: Mapped[IndexModelStatus] = mapped_column(Enum(IndexModelStatus)) + status: Mapped[IndexModelStatus] = mapped_column( + Enum(IndexModelStatus, native_enum=False) + ) index_name: Mapped[str] = mapped_column(String) index_attempts: Mapped[List["IndexAttempt"]] = relationship( @@ -441,7 +445,9 @@ class IndexAttempt(Base): # This is only for attempts that are explicitly marked as from the start via # the run once API from_beginning: Mapped[bool] = mapped_column(Boolean) - status: Mapped[IndexingStatus] = mapped_column(Enum(IndexingStatus)) + status: Mapped[IndexingStatus] = mapped_column( + Enum(IndexingStatus, native_enum=False) + ) # The two below may be slightly out of sync if user switches Embedding Model new_docs_indexed: Mapped[int | None] = mapped_column(Integer, default=0) total_docs_indexed: Mapped[int | None] = mapped_column(Integer, default=0) @@ -544,7 +550,9 @@ class SearchDoc(Base): link: Mapped[str | None] = mapped_column(String, nullable=True) blurb: Mapped[str] = mapped_column(String) boost: Mapped[int] = mapped_column(Integer) - source_type: Mapped[DocumentSource] = mapped_column(Enum(DocumentSource)) + source_type: Mapped[DocumentSource] = mapped_column( + Enum(DocumentSource, native_enum=False) + ) hidden: Mapped[bool] = mapped_column(Boolean) doc_metadata: Mapped[dict[str, str | list[str]]] = mapped_column(postgresql.JSONB()) score: Mapped[float] = mapped_column(Float) @@ -617,7 +625,9 @@ class ChatMessage(Base): # If prompt is None, then token_count is 0 as this message won't be passed into # the LLM's context (not included in the history of messages) token_count: Mapped[int] = mapped_column(Integer) - message_type: Mapped[MessageType] = mapped_column(Enum(MessageType)) + message_type: Mapped[MessageType] = mapped_column( + Enum(MessageType, native_enum=False) + ) # Maps the citation numbers to a SearchDoc id citations: Mapped[dict[int, int]] = mapped_column(postgresql.JSONB(), nullable=True) # Only applies for LLM @@ -656,7 +666,7 @@ class DocumentRetrievalFeedback(Base): document_rank: Mapped[int] = mapped_column(Integer) clicked: Mapped[bool] = mapped_column(Boolean, default=False) feedback: Mapped[SearchFeedbackType | None] = mapped_column( - Enum(SearchFeedbackType), nullable=True + Enum(SearchFeedbackType, native_enum=False), nullable=True ) chat_message: Mapped[ChatMessage] = relationship( @@ -768,7 +778,7 @@ class Persona(Base): description: Mapped[str] = mapped_column(String) # Currently stored but unused, all flows use hybrid search_type: Mapped[SearchType] = mapped_column( - Enum(SearchType), default=SearchType.HYBRID + Enum(SearchType, native_enum=False), default=SearchType.HYBRID ) # Number of chunks to pass to the LLM for generation. num_chunks: Mapped[float | None] = mapped_column(Float, nullable=True) @@ -778,7 +788,9 @@ class Persona(Base): # Enables using LLM to extract time and source type filters # Can also be admin disabled globally llm_filter_extraction: Mapped[bool] = mapped_column(Boolean) - recency_bias: Mapped[RecencyBiasSetting] = mapped_column(Enum(RecencyBiasSetting)) + recency_bias: Mapped[RecencyBiasSetting] = mapped_column( + Enum(RecencyBiasSetting, native_enum=False) + ) # Allows the Persona to specify a different LLM version than is controlled # globally via env variables.
For flexibility, validity is not currently enforced # NOTE: only is applied on the actual response generation - is not used for things like @@ -891,7 +903,7 @@ class TaskQueueState(Base): # For any job type, this would be the same task_name: Mapped[str] = mapped_column(String) # Note that if the task dies, this won't necessarily be marked FAILED correctly - status: Mapped[TaskStatus] = mapped_column(Enum(TaskStatus)) + status: Mapped[TaskStatus] = mapped_column(Enum(TaskStatus, native_enum=False)) start_time: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True) ) From 920d059da553dfd7e45eb4076c1b35f41292e8c3 Mon Sep 17 00:00:00 2001 From: Arthur De Kimpe Date: Sat, 23 Mar 2024 22:04:26 +0100 Subject: [PATCH 09/58] Bugfix: Support more Confluence Cloud hostname (*.jira.com) (#1244) --- backend/danswer/connectors/confluence/connector.py | 2 +- web/src/app/admin/connectors/confluence/page.tsx | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/danswer/connectors/confluence/connector.py b/backend/danswer/connectors/confluence/connector.py index 9b25524d6a..013948351b 100644 --- a/backend/danswer/connectors/confluence/connector.py +++ b/backend/danswer/connectors/confluence/connector.py @@ -75,7 +75,7 @@ def _extract_confluence_keys_from_datacenter_url(wiki_url: str) -> tuple[str, st def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, bool]: - is_confluence_cloud = ".atlassian.net/wiki/spaces/" in wiki_url + is_confluence_cloud = ".atlassian.net/wiki/spaces/" in wiki_url or ".jira.com/wiki/spaces/" in wiki_url try: if is_confluence_cloud: diff --git a/web/src/app/admin/connectors/confluence/page.tsx b/web/src/app/admin/connectors/confluence/page.tsx index 2e1ded53a8..32c50c0a72 100644 --- a/web/src/app/admin/connectors/confluence/page.tsx +++ b/web/src/app/admin/connectors/confluence/page.tsx @@ -43,7 +43,7 @@ const extractSpaceFromDataCenterUrl = (wikiUrl: string): string => { // Copied from the `extract_confluence_keys_from_url` function const extractSpaceFromUrl = (wikiUrl: string): string | null => { try { - if (wikiUrl.includes(".atlassian.net/wiki/spaces/")) { + if (wikiUrl.includes(".atlassian.net/wiki/spaces/") || wikiUrl.includes(".jira.com/wiki/spaces/")) { return extractSpaceFromCloudUrl(wikiUrl); } return extractSpaceFromDataCenterUrl(wikiUrl); From b8f767adf2a8d52b3ccf49c35921a8b8eeea106e Mon Sep 17 00:00:00 2001 From: Weves Date: Sat, 23 Mar 2024 14:54:23 -0700 Subject: [PATCH 10/58] Fix persona client side error --- web/src/app/admin/personas/PersonaTable.tsx | 6 +++++- web/src/app/admin/personas/[personaId]/page.tsx | 3 --- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/web/src/app/admin/personas/PersonaTable.tsx b/web/src/app/admin/personas/PersonaTable.tsx index bce8e29851..f2aef72277 100644 --- a/web/src/app/admin/personas/PersonaTable.tsx +++ b/web/src/app/admin/personas/PersonaTable.tsx @@ -79,7 +79,11 @@ export function PersonasTable({ personas }: { personas: Persona[] }) { {!persona.default_persona && ( router.push(`/admin/personas/${persona.id}`)} + onClick={() => + router.push( + `/admin/personas/${persona.id}?u=${Date.now()}` + ) + } /> )}

diff --git a/web/src/app/admin/personas/[personaId]/page.tsx b/web/src/app/admin/personas/[personaId]/page.tsx index 0b52131927..e9102d219b 100644 --- a/web/src/app/admin/personas/[personaId]/page.tsx +++ b/web/src/app/admin/personas/[personaId]/page.tsx @@ -6,7 +6,6 @@ import { DocumentSet } from "@/lib/types"; import { BackButton } from "@/components/BackButton"; import { Card, Title } from "@tremor/react"; import { DeletePersonaButton } from "./DeletePersonaButton"; -import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh"; export default async function Page({ params, @@ -68,8 +67,6 @@ export default async function Page({ return (

-      <InstantSSRAutoRefresh />
-

Edit Persona

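The `?u=${Date.now()}` appended in PersonaTable.tsx above looks like a cache-busting measure: the Next.js client-side router can replay a cached payload for a route it has already visited, and a unique throwaway query parameter forces a fresh navigation instead. A sketch of the pattern in isolation follows; the component name and props are hypothetical, and only the route shape comes from the diff:

"use client";

import { useRouter } from "next/navigation";

// Hypothetical illustration of the navigation used in PersonaTable.tsx.
function EditPersonaButton({ personaId }: { personaId: number }) {
  const router = useRouter();
  return (
    <button
      onClick={() =>
        // `u` is never read by the destination page; it only makes each URL
        // unique so the router cannot serve a stale cached entry.
        router.push(`/admin/personas/${personaId}?u=${Date.now()}`)
      }
    >
      Edit
    </button>
  );
}
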
From 12e8fd852c0b3556c448599fedec88aca9191511 Mon Sep 17 00:00:00 2001 From: Arnaud Ritti <77437157+arnaud-ritti@users.noreply.github.com> Date: Mon, 25 Mar 2024 02:37:27 +0100 Subject: [PATCH 11/58] feat: add Helm chart (#1186) --- deployment/helm/.gitignore | 3 + deployment/helm/.helmignore | 23 ++ deployment/helm/Chart.lock | 6 + deployment/helm/Chart.yaml | 24 ++ deployment/helm/templates/NOTES.txt | 22 + deployment/helm/templates/_helpers.tpl | 62 +++ deployment/helm/templates/api-deployment.yaml | 110 +++++ deployment/helm/templates/api-hpa.yaml | 32 ++ deployment/helm/templates/api-service.yaml | 15 + .../helm/templates/background-deployment.yaml | 100 +++++ deployment/helm/templates/background-hpa.yaml | 32 ++ deployment/helm/templates/configmap.yaml | 11 + deployment/helm/templates/connector-pvc.yaml | 19 + deployment/helm/templates/dynamic-pvc.yaml | 19 + deployment/helm/templates/ingress.yaml | 60 +++ deployment/helm/templates/secret.yaml | 10 + deployment/helm/templates/serviceaccount.yaml | 13 + .../helm/templates/tests/test-connection.yaml | 15 + deployment/helm/templates/vespa-service.yaml | 23 ++ .../helm/templates/vespa-statefulset.yaml | 83 ++++ .../helm/templates/webserver-deployment.yaml | 93 +++++ deployment/helm/templates/webserver-hpa.yaml | 32 ++ .../helm/templates/webserver-service.yaml | 15 + deployment/helm/values.yaml | 377 ++++++++++++++++++ 24 files changed, 1199 insertions(+) create mode 100644 deployment/helm/.gitignore create mode 100644 deployment/helm/.helmignore create mode 100644 deployment/helm/Chart.lock create mode 100644 deployment/helm/Chart.yaml create mode 100644 deployment/helm/templates/NOTES.txt create mode 100644 deployment/helm/templates/_helpers.tpl create mode 100644 deployment/helm/templates/api-deployment.yaml create mode 100644 deployment/helm/templates/api-hpa.yaml create mode 100644 deployment/helm/templates/api-service.yaml create mode 100644 deployment/helm/templates/background-deployment.yaml create mode 100644 deployment/helm/templates/background-hpa.yaml create mode 100755 deployment/helm/templates/configmap.yaml create mode 100644 deployment/helm/templates/connector-pvc.yaml create mode 100644 deployment/helm/templates/dynamic-pvc.yaml create mode 100644 deployment/helm/templates/ingress.yaml create mode 100755 deployment/helm/templates/secret.yaml create mode 100644 deployment/helm/templates/serviceaccount.yaml create mode 100644 deployment/helm/templates/tests/test-connection.yaml create mode 100644 deployment/helm/templates/vespa-service.yaml create mode 100644 deployment/helm/templates/vespa-statefulset.yaml create mode 100644 deployment/helm/templates/webserver-deployment.yaml create mode 100644 deployment/helm/templates/webserver-hpa.yaml create mode 100644 deployment/helm/templates/webserver-service.yaml create mode 100644 deployment/helm/values.yaml diff --git a/deployment/helm/.gitignore b/deployment/helm/.gitignore new file mode 100644 index 0000000000..b442275d6b --- /dev/null +++ b/deployment/helm/.gitignore @@ -0,0 +1,3 @@ +### Helm ### +# Chart dependencies +**/charts/*.tgz diff --git a/deployment/helm/.helmignore b/deployment/helm/.helmignore new file mode 100644 index 0000000000..0e8a0eb36f --- /dev/null +++ b/deployment/helm/.helmignore @@ -0,0 +1,23 @@ +# Patterns to ignore when building packages. +# This supports shell glob matching, relative path matching, and +# negation (prefixed with !). Only one pattern per line. 
+.DS_Store +# Common VCS dirs +.git/ +.gitignore +.bzr/ +.bzrignore +.hg/ +.hgignore +.svn/ +# Common backup files +*.swp +*.bak +*.tmp +*.orig +*~ +# Various IDEs +.project +.idea/ +*.tmproj +.vscode/ diff --git a/deployment/helm/Chart.lock b/deployment/helm/Chart.lock new file mode 100644 index 0000000000..7486bf317f --- /dev/null +++ b/deployment/helm/Chart.lock @@ -0,0 +1,6 @@ +dependencies: +- name: postgresql + repository: https://charts.bitnami.com/bitnami + version: 14.1.0 +digest: sha256:526d286ca7143959104d8a7f9b196706efdbd89dcc37943a1b54016f224d4b4d +generated: "2024-02-16T12:21:42.36744+01:00" diff --git a/deployment/helm/Chart.yaml b/deployment/helm/Chart.yaml new file mode 100644 index 0000000000..a36131be12 --- /dev/null +++ b/deployment/helm/Chart.yaml @@ -0,0 +1,24 @@ +apiVersion: v2 +name: danswer-stack +description: A Helm chart for Kubernetes +home: https://www.danswer.ai/ +sources: + - "https://github.com/danswer-ai/danswer" +type: application +version: 0.1.0 +appVersion: "v0.3.42" +annotations: + category: Productivity + licenses: MIT + images: | + - name: webserver + image: docker.io/danswer/danswer-web-server:v0.3.42 + - name: background + image: docker.io/danswer/danswer-backend:v0.3.42 + - name: vespa + image: vespaengine/vespa:8.277.17 +dependencies: + - name: postgresql + version: "14.1.0" + repository: https://charts.bitnami.com/bitnami + condition: postgresql.enabled \ No newline at end of file diff --git a/deployment/helm/templates/NOTES.txt b/deployment/helm/templates/NOTES.txt new file mode 100644 index 0000000000..41703407b6 --- /dev/null +++ b/deployment/helm/templates/NOTES.txt @@ -0,0 +1,22 @@ +1. Get the application URL by running these commands: +{{- if .Values.ingress.enabled }} +{{- range $host := .Values.ingress.hosts }} + {{- range .paths }} + http{{ if $.Values.ingress.tls }}s{{ end }}://{{ $host.host }}{{ .path }} + {{- end }} +{{- end }} +{{- else if contains "NodePort" .Values.webserver.service.type }} + export NODE_PORT=$(kubectl get --namespace {{ .Release.Namespace }} -o jsonpath="{.spec.ports[0].nodePort}" services {{ include "danswer-stack.fullname" . }}) + export NODE_IP=$(kubectl get nodes --namespace {{ .Release.Namespace }} -o jsonpath="{.items[0].status.addresses[0].address}") + echo http://$NODE_IP:$NODE_PORT +{{- else if contains "LoadBalancer" .Values.webserver.service.type }} + NOTE: It may take a few minutes for the LoadBalancer IP to be available. + You can watch the status of by running 'kubectl get --namespace {{ .Release.Namespace }} svc -w {{ include "danswer-stack.fullname" . }}' + export SERVICE_IP=$(kubectl get svc --namespace {{ .Release.Namespace }} {{ include "danswer-stack.fullname" . }} --template "{{"{{ range (index .status.loadBalancer.ingress 0) }}{{.}}{{ end }}"}}") + echo http://$SERVICE_IP:{{ .Values.webserver.service.port }} +{{- else if contains "ClusterIP" .Values.webserver.service.type }} + export POD_NAME=$(kubectl get pods --namespace {{ .Release.Namespace }} -l "app.kubernetes.io/name={{ include "danswer-stack.name" . 
}},app.kubernetes.io/instance={{ .Release.Name }}" -o jsonpath="{.items[0].metadata.name}") + export CONTAINER_PORT=$(kubectl get pod --namespace {{ .Release.Namespace }} $POD_NAME -o jsonpath="{.spec.containers[0].ports[0].containerPort}") + echo "Visit http://127.0.0.1:8080 to use your application" + kubectl --namespace {{ .Release.Namespace }} port-forward $POD_NAME 8080:$CONTAINER_PORT +{{- end }} diff --git a/deployment/helm/templates/_helpers.tpl b/deployment/helm/templates/_helpers.tpl new file mode 100644 index 0000000000..4e6672fd67 --- /dev/null +++ b/deployment/helm/templates/_helpers.tpl @@ -0,0 +1,62 @@ +{{/* +Expand the name of the chart. +*/}} +{{- define "danswer-stack.name" -}} +{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }} +{{- end }} + +{{/* +Create a default fully qualified app name. +We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec). +If release name contains chart name it will be used as a full name. +*/}} +{{- define "danswer-stack.fullname" -}} +{{- if .Values.fullnameOverride }} +{{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }} +{{- else }} +{{- $name := default .Chart.Name .Values.nameOverride }} +{{- if contains $name .Release.Name }} +{{- .Release.Name | trunc 63 | trimSuffix "-" }} +{{- else }} +{{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }} +{{- end }} +{{- end }} +{{- end }} + +{{/* +Create chart name and version as used by the chart label. +*/}} +{{- define "danswer-stack.chart" -}} +{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }} +{{- end }} + +{{/* +Common labels +*/}} +{{- define "danswer-stack.labels" -}} +helm.sh/chart: {{ include "danswer-stack.chart" . }} +{{ include "danswer-stack.selectorLabels" . }} +{{- if .Chart.AppVersion }} +app.kubernetes.io/version: {{ .Chart.AppVersion | quote }} +{{- end }} +app.kubernetes.io/managed-by: {{ .Release.Service }} +{{- end }} + +{{/* +Selector labels +*/}} +{{- define "danswer-stack.selectorLabels" -}} +app.kubernetes.io/name: {{ include "danswer-stack.name" . }} +app.kubernetes.io/instance: {{ .Release.Name }} +{{- end }} + +{{/* +Create the name of the service account to use +*/}} +{{- define "danswer-stack.serviceAccountName" -}} +{{- if .Values.serviceAccount.create }} +{{- default (include "danswer-stack.fullname" .) .Values.serviceAccount.name }} +{{- else }} +{{- default "default" .Values.serviceAccount.name }} +{{- end }} +{{- end }} diff --git a/deployment/helm/templates/api-deployment.yaml b/deployment/helm/templates/api-deployment.yaml new file mode 100644 index 0000000000..8c40f3408c --- /dev/null +++ b/deployment/helm/templates/api-deployment.yaml @@ -0,0 +1,110 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{ include "danswer-stack.fullname" . }}-api + labels: + {{- include "danswer-stack.labels" . | nindent 4 }} +spec: + {{- if not .Values.api.autoscaling.enabled }} + replicas: {{ .Values.api.replicaCount }} + {{- end }} + selector: + matchLabels: + {{- include "danswer-stack.selectorLabels" . | nindent 6 }} + template: + metadata: + {{- with .Values.api.podAnnotations }} + annotations: + {{- toYaml . | nindent 8 }} + {{- end }} + labels: + {{- include "danswer-stack.labels" . | nindent 8 }} + {{- with .Values.api.podLabels }} + {{- toYaml . | nindent 8 }} + {{- end }} + spec: + {{- with .Values.imagePullSecrets }} + imagePullSecrets: + {{- toYaml . 
| nindent 8 }} + {{- end }} + serviceAccountName: {{ include "danswer-stack.serviceAccountName" . }} + securityContext: + {{- toYaml .Values.api.podSecurityContext | nindent 8 }} + containers: + - name: {{ .Chart.Name }} + securityContext: + {{- toYaml .Values.api.securityContext | nindent 12 }} + image: "{{ .Values.api.image.repository }}:{{ .Values.api.image.tag | default .Chart.AppVersion }}" + imagePullPolicy: {{ .Values.api.image.pullPolicy }} + command: + - "/bin/sh" + - "-c" + - | + alembic upgrade head && + echo "Starting Danswer Api Server" && + uvicorn danswer.main:app --host 0.0.0.0 --port 8080 + ports: + - name: api-server-port + containerPort: {{ .Values.api.service.port }} + protocol: TCP + resources: + {{- toYaml .Values.api.resources | nindent 12 }} + envFrom: + - configMapRef: + name: {{ include "danswer-stack.fullname" . }} + env: + - name: INTERNAL_URL + value: {{ (list "http://" (include "danswer-stack.fullname" .) "-api:" .Values.api.service.port | join "") | quote }} + - name: VESPA_HOST + value: {{ (list (include "danswer-stack.fullname" .) "vespa" | join "-") }} + {{- if .Values.postgresql.enabled }} + - name: POSTGRES_HOST + value: {{ (list .Release.Name "postgresql" | join "-") }} + - name: POSTGRES_DB + value: {{ .Values.postgresql.auth.database }} + - name: POSTGRES_USER + value: {{ .Values.postgresql.auth.username }} + - name: POSTGRES_PASSWORD + valueFrom: + secretKeyRef: + name: {{ (list .Release.Name "postgresql" | join "-") }} + key: password + {{- end }} + volumeMounts: + - name: dynamic-storage + mountPath: /home/storage + - name: connector-storage + mountPath: /home/file_connector_storage + {{- if .Values.api.volumeMounts }} + {{- .Values.api.volumeMounts | toYaml | nindent 12}} + {{- end }} + volumes: + - name: dynamic-storage + {{- if .Values.persistence.dynamic.enabled }} + persistentVolumeClaim: + claimName: {{ .Values.persistence.dynamic.existingClaim | default (list (include "danswer-stack.fullname" .) "dynamic" | join "-") }} + {{- else }} + emptyDir: { } + {{- end }} + - name: connector-storage + {{- if .Values.persistence.connector.enabled }} + persistentVolumeClaim: + claimName: {{ .Values.persistence.connector.existingClaim | default (list (include "danswer-stack.fullname" .) "connector" | join "-") }} + {{- else }} + emptyDir: { } + {{- end }} + {{- if .Values.api.volumes }} + {{- .Values.api.volumes | toYaml | nindent 8}} + {{- end }} + {{- with .Values.api.nodeSelector }} + nodeSelector: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.api.affinity }} + affinity: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.api.tolerations }} + tolerations: + {{- toYaml . | nindent 8 }} + {{- end }} diff --git a/deployment/helm/templates/api-hpa.yaml b/deployment/helm/templates/api-hpa.yaml new file mode 100644 index 0000000000..378c39715a --- /dev/null +++ b/deployment/helm/templates/api-hpa.yaml @@ -0,0 +1,32 @@ +{{- if .Values.api.autoscaling.enabled }} +apiVersion: autoscaling/v2 +kind: HorizontalPodAutoscaler +metadata: + name: {{ include "danswer-stack.fullname" . }}-api + labels: + {{- include "danswer-stack.labels" . | nindent 4 }} +spec: + scaleTargetRef: + apiVersion: apps/v1 + kind: Deployment + name: {{ include "danswer-stack.fullname" . 
}} + minReplicas: {{ .Values.api.autoscaling.minReplicas }} + maxReplicas: {{ .Values.api.autoscaling.maxReplicas }} + metrics: + {{- if .Values.api.autoscaling.targetCPUUtilizationPercentage }} + - type: Resource + resource: + name: cpu + target: + type: Utilization + averageUtilization: {{ .Values.api.autoscaling.targetCPUUtilizationPercentage }} + {{- end }} + {{- if .Values.api.autoscaling.targetMemoryUtilizationPercentage }} + - type: Resource + resource: + name: memory + target: + type: Utilization + averageUtilization: {{ .Values.api.autoscaling.targetMemoryUtilizationPercentage }} + {{- end }} +{{- end }} diff --git a/deployment/helm/templates/api-service.yaml b/deployment/helm/templates/api-service.yaml new file mode 100644 index 0000000000..f4e4e0be69 --- /dev/null +++ b/deployment/helm/templates/api-service.yaml @@ -0,0 +1,15 @@ +apiVersion: v1 +kind: Service +metadata: + name: {{ include "danswer-stack.fullname" . }}-api + labels: + {{- include "danswer-stack.labels" . | nindent 4 }} +spec: + type: {{ .Values.api.service.type }} + ports: + - port: {{ .Values.api.service.port }} + targetPort: api-server-port + protocol: TCP + name: api-server-port + selector: + {{- include "danswer-stack.selectorLabels" . | nindent 4 }} diff --git a/deployment/helm/templates/background-deployment.yaml b/deployment/helm/templates/background-deployment.yaml new file mode 100644 index 0000000000..59cfc52462 --- /dev/null +++ b/deployment/helm/templates/background-deployment.yaml @@ -0,0 +1,100 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{ include "danswer-stack.fullname" . }}-background + labels: + {{- include "danswer-stack.labels" . | nindent 4 }} +spec: + {{- if not .Values.background.autoscaling.enabled }} + replicas: {{ .Values.background.replicaCount }} + {{- end }} + selector: + matchLabels: + {{- include "danswer-stack.selectorLabels" . | nindent 6 }} + template: + metadata: + {{- with .Values.background.podAnnotations }} + annotations: + {{- toYaml . | nindent 8 }} + {{- end }} + labels: + {{- include "danswer-stack.labels" . | nindent 8 }} + {{- with .Values.background.podLabels }} + {{- toYaml . | nindent 8 }} + {{- end }} + spec: + {{- with .Values.imagePullSecrets }} + imagePullSecrets: + {{- toYaml . | nindent 8 }} + {{- end }} + serviceAccountName: {{ include "danswer-stack.serviceAccountName" . }} + securityContext: + {{- toYaml .Values.background.podSecurityContext | nindent 8 }} + containers: + - name: {{ .Chart.Name }} + securityContext: + {{- toYaml .Values.background.securityContext | nindent 12 }} + image: "{{ .Values.background.image.repository }}:{{ .Values.background.image.tag | default .Chart.AppVersion }}" + imagePullPolicy: {{ .Values.background.image.pullPolicy }} + command: ["/usr/bin/supervisord"] + resources: + {{- toYaml .Values.background.resources | nindent 12 }} + envFrom: + - configMapRef: + name: {{ include "danswer-stack.fullname" . }} + env: + - name: INTERNAL_URL + value: {{ (list "http://" (include "danswer-stack.fullname" .) "-api:" .Values.api.service.port | join "") | quote }} + - name: VESPA_HOST + value: {{ (list (include "danswer-stack.fullname" .) 
"vespa" | join "-") }} + {{- if .Values.postgresql.enabled }} + - name: POSTGRES_HOST + value: {{ (list .Release.Name "postgresql" | join "-") }} + - name: POSTGRES_DB + value: {{ .Values.postgresql.auth.database }} + - name: POSTGRES_USER + value: {{ .Values.postgresql.auth.username }} + - name: POSTGRES_PASSWORD + valueFrom: + secretKeyRef: + name: {{ (list .Release.Name "postgresql" | join "-") }} + key: password + {{- end }} + volumeMounts: + - name: dynamic-storage + mountPath: /home/storage + - name: connector-storage + mountPath: /home/file_connector_storage + {{- if .Values.background.volumeMounts }} + {{- .Values.background.volumeMounts | toYaml | nindent 12}} + {{- end }} + volumes: + - name: dynamic-storage + {{- if .Values.persistence.dynamic.enabled }} + persistentVolumeClaim: + claimName: {{ .Values.persistence.dynamic.existingClaim | default (list (include "danswer-stack.fullname" .) "dynamic" | join "-") }} + {{- else }} + emptyDir: { } + {{- end }} + - name: connector-storage + {{- if .Values.persistence.connector.enabled }} + persistentVolumeClaim: + claimName: {{ .Values.persistence.connector.existingClaim | default (list (include "danswer-stack.fullname" .) "connector" | join "-") }} + {{- else }} + emptyDir: { } + {{- end }} + {{- if .Values.background.volumes }} + {{- .Values.background.volumes | toYaml | nindent 8}} + {{- end }} + {{- with .Values.background.nodeSelector }} + nodeSelector: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.background.affinity }} + affinity: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.background.tolerations }} + tolerations: + {{- toYaml . | nindent 8 }} + {{- end }} diff --git a/deployment/helm/templates/background-hpa.yaml b/deployment/helm/templates/background-hpa.yaml new file mode 100644 index 0000000000..009daf10f0 --- /dev/null +++ b/deployment/helm/templates/background-hpa.yaml @@ -0,0 +1,32 @@ +{{- if .Values.background.autoscaling.enabled }} +apiVersion: autoscaling/v2 +kind: HorizontalPodAutoscaler +metadata: + name: {{ include "danswer-stack.fullname" . }}-background + labels: + {{- include "danswer-stack.labels" . | nindent 4 }} +spec: + scaleTargetRef: + apiVersion: apps/v1 + kind: Deployment + name: {{ include "danswer-stack.fullname" . }} + minReplicas: {{ .Values.background.autoscaling.minReplicas }} + maxReplicas: {{ .Values.background.autoscaling.maxReplicas }} + metrics: + {{- if .Values.background.autoscaling.targetCPUUtilizationPercentage }} + - type: Resource + resource: + name: cpu + target: + type: Utilization + averageUtilization: {{ .Values.background.autoscaling.targetCPUUtilizationPercentage }} + {{- end }} + {{- if .Values.background.autoscaling.targetMemoryUtilizationPercentage }} + - type: Resource + resource: + name: memory + target: + type: Utilization + averageUtilization: {{ .Values.background.autoscaling.targetMemoryUtilizationPercentage }} + {{- end }} +{{- end }} diff --git a/deployment/helm/templates/configmap.yaml b/deployment/helm/templates/configmap.yaml new file mode 100755 index 0000000000..a393977986 --- /dev/null +++ b/deployment/helm/templates/configmap.yaml @@ -0,0 +1,11 @@ +apiVersion: v1 +kind: ConfigMap +metadata: + name: {{ include "danswer-stack.fullname" . }} + labels: + {{- include "danswer-stack.labels" . 
| nindent 4 }} +data: +{{- range $key, $value := .Values.config }} + {{ $key }}: |- + {{- $value | nindent 4 }} +{{- end }} diff --git a/deployment/helm/templates/connector-pvc.yaml b/deployment/helm/templates/connector-pvc.yaml new file mode 100644 index 0000000000..41c41c3cff --- /dev/null +++ b/deployment/helm/templates/connector-pvc.yaml @@ -0,0 +1,19 @@ +{{- if and .Values.persistence.connector.enabled (not .Values.persistence.connector.existingClaim)}} +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: {{ include "danswer-stack.fullname" . }}-connector + labels: + {{- include "danswer-stack.labels" . | nindent 4 }} +spec: + accessModes: + {{- range .Values.persistence.connector.accessModes }} + - {{ . | quote }} + {{- end }} + resources: + requests: + storage: {{ .Values.persistence.connector.size | quote }} + {{- with .Values.persistence.connector.storageClassName }} + storageClassName: {{ . }} + {{- end }} +{{- end }} \ No newline at end of file diff --git a/deployment/helm/templates/dynamic-pvc.yaml b/deployment/helm/templates/dynamic-pvc.yaml new file mode 100644 index 0000000000..703b33acb5 --- /dev/null +++ b/deployment/helm/templates/dynamic-pvc.yaml @@ -0,0 +1,19 @@ +{{- if and .Values.persistence.dynamic.enabled (not .Values.persistence.dynamic.existingClaim)}} +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: {{ include "danswer-stack.fullname" . }}-dynamic + labels: + {{- include "danswer-stack.labels" . | nindent 4 }} +spec: + accessModes: + {{- range .Values.persistence.dynamic.accessModes }} + - {{ . | quote }} + {{- end }} + resources: + requests: + storage: {{ .Values.persistence.dynamic.size | quote }} + {{- with .Values.persistence.dynamic.storageClassName }} + storageClassName: {{ . }} + {{- end }} +{{- end }} \ No newline at end of file diff --git a/deployment/helm/templates/ingress.yaml b/deployment/helm/templates/ingress.yaml new file mode 100644 index 0000000000..cfbef35dd7 --- /dev/null +++ b/deployment/helm/templates/ingress.yaml @@ -0,0 +1,60 @@ +{{- if .Values.ingress.enabled -}} +{{- $fullName := include "danswer-stack.fullname" . -}} +{{- if and .Values.ingress.className (not (semverCompare ">=1.18-0" .Capabilities.KubeVersion.GitVersion)) }} + {{- if not (hasKey .Values.ingress.annotations "kubernetes.io/ingress.class") }} + {{- $_ := set .Values.ingress.annotations "kubernetes.io/ingress.class" .Values.ingress.className}} + {{- end }} +{{- end }} +{{- if semverCompare ">=1.19-0" .Capabilities.KubeVersion.GitVersion -}} +apiVersion: networking.k8s.io/v1 +{{- else if semverCompare ">=1.14-0" .Capabilities.KubeVersion.GitVersion -}} +apiVersion: networking.k8s.io/v1beta1 +{{- else -}} +apiVersion: extensions/v1beta1 +{{- end }} +kind: Ingress +metadata: + name: {{ $fullName }} + labels: + {{- include "danswer-stack.labels" . | nindent 4 }} + {{- with .Values.ingress.annotations }} + annotations: + {{- toYaml . | nindent 4 }} + {{- end }} +spec: + {{- if and .Values.ingress.className (semverCompare ">=1.18-0" .Capabilities.KubeVersion.GitVersion) }} + ingressClassName: {{ .Values.ingress.className }} + {{- end }} + {{- if .Values.ingress.tls }} + tls: + {{- range .Values.ingress.tls }} + - hosts: + {{- range .hosts }} + - {{ . 
| quote }} + {{- end }} + secretName: {{ .secretName }} + {{- end }} + {{- end }} + rules: + {{- range .Values.ingress.hosts }} + - host: {{ .host | quote }} + http: + paths: + {{- range .paths }} + - path: {{ .path }} + {{- if and .pathType (semverCompare ">=1.18-0" $.Capabilities.KubeVersion.GitVersion) }} + pathType: {{ .pathType }} + {{- end }} + backend: + {{- if semverCompare ">=1.19-0" $.Capabilities.KubeVersion.GitVersion }} + service: + name: {{ (list $fullName .service) | join "-" }} + port: + number: {{ .servicePort }} + {{- else }} + serviceName: {{ (list $fullName .service) | join "-" }} + servicePort: {{ .servicePort }} + {{- end }} + {{- end }} + {{- end }} +{{- end }} diff --git a/deployment/helm/templates/secret.yaml b/deployment/helm/templates/secret.yaml new file mode 100755 index 0000000000..58bfba87d9 --- /dev/null +++ b/deployment/helm/templates/secret.yaml @@ -0,0 +1,10 @@ +apiVersion: v1 +kind: Secret +metadata: + name: {{ include "danswer-stack.fullname" . }} + labels: + {{- include "danswer-stack.labels" . | nindent 4 }} +data: +{{- range $key, $value := .Values.secrets }} + {{ $key }}: '{{ $value | b64enc }}' +{{- end }} diff --git a/deployment/helm/templates/serviceaccount.yaml b/deployment/helm/templates/serviceaccount.yaml new file mode 100644 index 0000000000..afd351217b --- /dev/null +++ b/deployment/helm/templates/serviceaccount.yaml @@ -0,0 +1,13 @@ +{{- if .Values.serviceAccount.create -}} +apiVersion: v1 +kind: ServiceAccount +metadata: + name: {{ include "danswer-stack.serviceAccountName" . }} + labels: + {{- include "danswer-stack.labels" . | nindent 4 }} + {{- with .Values.serviceAccount.annotations }} + annotations: + {{- toYaml . | nindent 4 }} + {{- end }} +automountServiceAccountToken: {{ .Values.serviceAccount.automount }} +{{- end }} diff --git a/deployment/helm/templates/tests/test-connection.yaml b/deployment/helm/templates/tests/test-connection.yaml new file mode 100644 index 0000000000..60fbd1054c --- /dev/null +++ b/deployment/helm/templates/tests/test-connection.yaml @@ -0,0 +1,15 @@ +apiVersion: v1 +kind: Pod +metadata: + name: "{{ include "danswer-stack.fullname" . }}-test-connection" + labels: + {{- include "danswer-stack.labels" . | nindent 4 }} + annotations: + "helm.sh/hook": test +spec: + containers: + - name: wget + image: busybox + command: ['wget'] + args: ['{{ include "danswer-stack.fullname" . }}:{{ .Values.webserver.service.port }}'] + restartPolicy: Never diff --git a/deployment/helm/templates/vespa-service.yaml b/deployment/helm/templates/vespa-service.yaml new file mode 100644 index 0000000000..01216a2897 --- /dev/null +++ b/deployment/helm/templates/vespa-service.yaml @@ -0,0 +1,23 @@ +apiVersion: v1 +kind: Service +metadata: + name: {{ include "danswer-stack.fullname" . }}-vespa + labels: + {{- include "danswer-stack.labels" . | nindent 4 }} +spec: + type: ClusterIP + ports: + - name: vespa-tenant-port + protocol: TCP + port: 19070 + targetPort: 19070 + - name: vespa-tenant-port-2 + protocol: TCP + port: 19071 + targetPort: 19071 + - name: vespa-port + protocol: TCP + port: 8080 + targetPort: 8080 + selector: + {{- include "danswer-stack.selectorLabels" . | nindent 4 }} diff --git a/deployment/helm/templates/vespa-statefulset.yaml b/deployment/helm/templates/vespa-statefulset.yaml new file mode 100644 index 0000000000..674b52bc44 --- /dev/null +++ b/deployment/helm/templates/vespa-statefulset.yaml @@ -0,0 +1,83 @@ +apiVersion: apps/v1 +kind: StatefulSet +metadata: + name: {{ include "danswer-stack.fullname" . 
}}-vespa + labels: + {{- include "danswer-stack.labels" . | nindent 4 }} +spec: + replicas: {{ .Values.vespa.replicaCount }} + selector: + matchLabels: + {{- include "danswer-stack.selectorLabels" . | nindent 6 }} + template: + metadata: + {{- with .Values.vespa.podAnnotations }} + annotations: + {{- toYaml . | nindent 8 }} + {{- end }} + labels: + {{- include "danswer-stack.labels" . | nindent 8 }} + {{- with .Values.vespa.podLabels }} + {{- toYaml . | nindent 8 }} + {{- end }} + spec: + {{- with .Values.imagePullSecrets }} + imagePullSecrets: + {{- toYaml . | nindent 8 }} + {{- end }} + serviceAccountName: {{ include "danswer-stack.serviceAccountName" . }} + securityContext: + {{- toYaml .Values.vespa.podSecurityContext | nindent 8 }} + containers: + - name: {{ .Chart.Name }} + securityContext: + {{- toYaml .Values.vespa.securityContext | nindent 12 }} + image: "{{ .Values.vespa.image.repository }}:{{ .Values.vespa.image.tag }}" + imagePullPolicy: {{ .Values.vespa.image.pullPolicy }} + ports: + - containerPort: 19070 + - containerPort: 19071 + - containerPort: 8081 + livenessProbe: + httpGet: + path: /state/v1/health + port: 19071 + scheme: HTTP + readinessProbe: + httpGet: + path: /state/v1/health + port: 19071 + scheme: HTTP + resources: + {{- toYaml .Values.vespa.resources | nindent 12 }} + volumeMounts: + - name: vespa-storage + mountPath: /opt/vespa/var/ + {{- with .Values.vespa.nodeSelector }} + nodeSelector: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.vespa.affinity }} + affinity: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.vespa.tolerations }} + tolerations: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- if .Values.persistence.vespa.enabled }} + volumeClaimTemplates: + - metadata: + name: vespa-storage + spec: + accessModes: + {{- range .Values.persistence.vespa.accessModes }} + - {{ . | quote }} + {{- end }} + resources: + requests: + storage: {{ .Values.persistence.vespa.size | quote }} + {{- with .Values.persistence.vespa.storageClassName }} + storageClassName: {{ . }} + {{- end }} + {{- end }} \ No newline at end of file diff --git a/deployment/helm/templates/webserver-deployment.yaml b/deployment/helm/templates/webserver-deployment.yaml new file mode 100644 index 0000000000..c679e6e0a2 --- /dev/null +++ b/deployment/helm/templates/webserver-deployment.yaml @@ -0,0 +1,93 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{ include "danswer-stack.fullname" . }}-webserver + labels: + {{- include "danswer-stack.labels" . | nindent 4 }} +spec: + {{- if not .Values.webserver.autoscaling.enabled }} + replicas: {{ .Values.webserver.replicaCount }} + {{- end }} + selector: + matchLabels: + {{- include "danswer-stack.selectorLabels" . | nindent 6 }} + template: + metadata: + {{- with .Values.webserver.podAnnotations }} + annotations: + {{- toYaml . | nindent 8 }} + {{- end }} + labels: + {{- include "danswer-stack.labels" . | nindent 8 }} + {{- with .Values.webserver.podLabels }} + {{- toYaml . | nindent 8 }} + {{- end }} + spec: + {{- with .Values.imagePullSecrets }} + imagePullSecrets: + {{- toYaml . | nindent 8 }} + {{- end }} + serviceAccountName: {{ include "danswer-stack.serviceAccountName" . 
}} + securityContext: + {{- toYaml .Values.webserver.podSecurityContext | nindent 8 }} + containers: + - name: {{ .Chart.Name }} + securityContext: + {{- toYaml .Values.webserver.securityContext | nindent 12 }} + image: "{{ .Values.webserver.image.repository }}:{{ .Values.webserver.image.tag | default .Chart.AppVersion }}" + imagePullPolicy: {{ .Values.webserver.image.pullPolicy }} + ports: + - name: http + containerPort: {{ .Values.webserver.service.port }} + protocol: TCP + livenessProbe: + httpGet: + path: / + port: http + readinessProbe: + httpGet: + path: / + port: http + resources: + {{- toYaml .Values.webserver.resources | nindent 12 }} + envFrom: + - configMapRef: + name: {{ include "danswer-stack.fullname" . }} + env: + - name: INTERNAL_URL + value: {{ (list "http://" (include "danswer-stack.fullname" .) "-api:" .Values.api.service.port | join "") | quote }} + - name: VESPA_HOST + value: {{ (list (include "danswer-stack.fullname" .) "vespa" | join "-") }} + {{- if .Values.postgresql.enabled }} + - name: POSTGRES_HOST + value: {{ (list .Release.Name "postgresql" | join "-") }} + - name: POSTGRES_DB + value: {{ .Values.postgresql.auth.database }} + - name: POSTGRES_USER + value: {{ .Values.postgresql.auth.username }} + - name: POSTGRES_PASSWORD + valueFrom: + secretKeyRef: + name: {{ (list .Release.Name "postgresql" | join "-") }} + key: password + {{- end }} + {{- with .Values.webserver.volumeMounts }} + volumeMounts: + {{- toYaml . | nindent 12 }} + {{- end }} + {{- with .Values.webserver.volumes }} + volumes: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.webserver.nodeSelector }} + nodeSelector: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.webserver.affinity }} + affinity: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.webserver.tolerations }} + tolerations: + {{- toYaml . | nindent 8 }} + {{- end }} diff --git a/deployment/helm/templates/webserver-hpa.yaml b/deployment/helm/templates/webserver-hpa.yaml new file mode 100644 index 0000000000..b46820a7fa --- /dev/null +++ b/deployment/helm/templates/webserver-hpa.yaml @@ -0,0 +1,32 @@ +{{- if .Values.webserver.autoscaling.enabled }} +apiVersion: autoscaling/v2 +kind: HorizontalPodAutoscaler +metadata: + name: {{ include "danswer-stack.fullname" . }}-webserver + labels: + {{- include "danswer-stack.labels" . | nindent 4 }} +spec: + scaleTargetRef: + apiVersion: apps/v1 + kind: Deployment + name: {{ include "danswer-stack.fullname" . }} + minReplicas: {{ .Values.webserver.autoscaling.minReplicas }} + maxReplicas: {{ .Values.webserver.autoscaling.maxReplicas }} + metrics: + {{- if .Values.webserver.autoscaling.targetCPUUtilizationPercentage }} + - type: Resource + resource: + name: cpu + target: + type: Utilization + averageUtilization: {{ .Values.webserver.autoscaling.targetCPUUtilizationPercentage }} + {{- end }} + {{- if .Values.webserver.autoscaling.targetMemoryUtilizationPercentage }} + - type: Resource + resource: + name: memory + target: + type: Utilization + averageUtilization: {{ .Values.webserver.autoscaling.targetMemoryUtilizationPercentage }} + {{- end }} +{{- end }} diff --git a/deployment/helm/templates/webserver-service.yaml b/deployment/helm/templates/webserver-service.yaml new file mode 100644 index 0000000000..776b65f8f9 --- /dev/null +++ b/deployment/helm/templates/webserver-service.yaml @@ -0,0 +1,15 @@ +apiVersion: v1 +kind: Service +metadata: + name: {{ include "danswer-stack.fullname" . }}-webserver + labels: + {{- include "danswer-stack.labels" . 
| nindent 4 }} +spec: + type: {{ .Values.webserver.service.type }} + ports: + - port: {{ .Values.webserver.service.port }} + targetPort: http + protocol: TCP + name: http + selector: + {{- include "danswer-stack.selectorLabels" . | nindent 4 }} diff --git a/deployment/helm/values.yaml b/deployment/helm/values.yaml new file mode 100644 index 0000000000..8d994b55ff --- /dev/null +++ b/deployment/helm/values.yaml @@ -0,0 +1,377 @@ +# Default values for danswer-stack. +# This is a YAML-formatted file. +# Declare variables to be passed into your templates. + +imagePullSecrets: [] +nameOverride: "" +fullnameOverride: "" + +serviceAccount: + # Specifies whether a service account should be created + create: true + # Automatically mount a ServiceAccount's API credentials? + automount: true + # Annotations to add to the service account + annotations: {} + # The name of the service account to use. + # If not set and create is true, a name is generated using the fullname template + name: "" + + +webserver: + replicaCount: 1 + image: + repository: danswer/danswer-web-server + pullPolicy: IfNotPresent + # Overrides the image tag whose default is the chart appVersion. + tag: "" + + podAnnotations: {} + podLabels: {} + + podSecurityContext: {} + # fsGroup: 2000 + + securityContext: {} + # capabilities: + # drop: + # - ALL + # readOnlyRootFilesystem: true + # runAsNonRoot: true + # runAsUser: 1000 + + service: + type: ClusterIP + port: 3000 + + resources: {} + # We usually recommend not to specify default resources and to leave this as a conscious + # choice for the user. This also increases chances charts run on environments with little + # resources, such as Minikube. If you do want to specify resources, uncomment the following + # lines, adjust them as necessary, and remove the curly braces after 'resources:'. + # limits: + # cpu: 100m + # memory: 128Mi + # requests: + # cpu: 100m + # memory: 128Mi + + autoscaling: + enabled: false + minReplicas: 1 + maxReplicas: 100 + targetCPUUtilizationPercentage: 80 + # targetMemoryUtilizationPercentage: 80 + + # Additional volumes on the output Deployment definition. + volumes: [] + # - name: foo + # secret: + # secretName: mysecret + # optional: false + + # Additional volumeMounts on the output Deployment definition. + volumeMounts: [] + # - name: foo + # mountPath: "/etc/foo" + # readOnly: true + + nodeSelector: {} + tolerations: [] + affinity: {} + +api: + replicaCount: 1 + image: + repository: danswer/danswer-backend + pullPolicy: IfNotPresent + # Overrides the image tag whose default is the chart appVersion. + tag: "" + + podAnnotations: {} + podLabels: + scope: danswer-backend + + podSecurityContext: {} + # fsGroup: 2000 + + securityContext: {} + # capabilities: + # drop: + # - ALL + # readOnlyRootFilesystem: true + # runAsNonRoot: true + # runAsUser: 1000 + + service: + type: ClusterIP + port: 8080 + + resources: + # We usually recommend not to specify default resources and to leave this as a conscious + # choice for the user. This also increases chances charts run on environments with little + # resources, such as Minikube. If you do want to specify resources, uncomment the following + # lines, adjust them as necessary, and remove the curly braces after 'resources:'. + requests: + cpu: 1500m + memory: 2Gi + # limits: + # cpu: 100m + # memory: 128Mi + + autoscaling: + enabled: false + minReplicas: 1 + maxReplicas: 100 + targetCPUUtilizationPercentage: 80 + # targetMemoryUtilizationPercentage: 80 + + # Additional volumes on the output Deployment definition. 
+ volumes: [] + # - name: foo + # secret: + # secretName: mysecret + # optional: false + + # Additional volumeMounts on the output Deployment definition. + volumeMounts: [] + # - name: foo + # mountPath: "/etc/foo" + # readOnly: true + + nodeSelector: {} + tolerations: [] + affinity: + podAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + - labelSelector: + matchExpressions: + - key: scope + operator: In + values: + - danswer-backend + topologyKey: "kubernetes.io/hostname" + +background: + replicaCount: 1 + image: + repository: danswer/danswer-backend + pullPolicy: IfNotPresent + # Overrides the image tag whose default is the chart appVersion. + tag: "" + podAnnotations: {} + podLabels: + scope: danswer-backend + + podSecurityContext: {} + # fsGroup: 2000 + + securityContext: {} + # capabilities: + # drop: + # - ALL + # readOnlyRootFilesystem: true + # runAsNonRoot: true + # runAsUser: 1000 + + resources: + # We usually recommend not to specify default resources and to leave this as a conscious + # choice for the user. This also increases chances charts run on environments with little + # resources, such as Minikube. If you do want to specify resources, uncomment the following + # lines, adjust them as necessary, and remove the curly braces after 'resources:'. + requests: + cpu: 2500m + memory: 5Gi + # limits: + # cpu: 100m + # memory: 128Mi + + autoscaling: + enabled: false + minReplicas: 1 + maxReplicas: 100 + targetCPUUtilizationPercentage: 80 + # targetMemoryUtilizationPercentage: 80 + + # Additional volumes on the output Deployment definition. + volumes: [] + # - name: foo + # secret: + # secretName: mysecret + # optional: false + + # Additional volumeMounts on the output Deployment definition. + volumeMounts: [] + # - name: foo + # mountPath: "/etc/foo" + # readOnly: true + + nodeSelector: {} + tolerations: [] + affinity: + podAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + - labelSelector: + matchExpressions: + - key: scope + operator: In + values: + - danswer-backend + topologyKey: "kubernetes.io/hostname" + +vespa: + replicaCount: 1 + image: + repository: vespaengine/vespa + pullPolicy: IfNotPresent + tag: "8.277.17" + podAnnotations: {} + podLabels: {} + + podSecurityContext: {} + # fsGroup: 2000 + + securityContext: + privileged: true + runAsUser: 0 + # capabilities: + # drop: + # - ALL + # readOnlyRootFilesystem: true + # runAsNonRoot: true + # runAsUser: 1000 + + resources: + # We usually recommend not to specify default resources and to leave this as a conscious + # choice for the user. This also increases chances charts run on environments with little + # resources, such as Minikube. If you do want to specify resources, uncomment the following + # lines, adjust them as necessary, and remove the curly braces after 'resources:'. 
+ requests: + cpu: 2500m + memory: 5Gi + # limits: + # cpu: 100m + # memory: 128Mi + + nodeSelector: {} + tolerations: [] + affinity: {} + + +#ingress: +# enabled: false +# className: "" +# annotations: {} +# # kubernetes.io/ingress.class: nginx +# # kubernetes.io/tls-acme: "true" +# hosts: +# - host: chart-example.local +# paths: +# - path: / +# pathType: ImplementationSpecific +# tls: [] +# # - secretName: chart-example-tls +# # hosts: +# # - chart-example.local + +persistence: + vespa: + enabled: true + existingClaim: "" + storageClassName: "" + accessModes: + - ReadWriteOnce + size: 1Gi + connector: + enabled: true + existingClaim: "" + storageClassName: "" + accessModes: + - ReadWriteOnce + size: 1Gi + dynamic: + enabled: true + existingClaim: "" + storageClassName: "" + accessModes: + - ReadWriteOnce + size: 1Gi + +postgresql: + enabled: false + auth: + postgresPassword: "" + username: danswer + password: danswer + database: danswer + +config: + # Auth Setting, also check the secrets file + #AUTH_TYPE: "disabled" # Change this for production uses unless Danswer is only accessible behind VPN + #SESSION_EXPIRE_TIME_SECONDS: "86400" # 1 Day Default + #VALID_EMAIL_DOMAINS: "" # Can be something like danswer.ai, as an extra double-check + #SMTP_SERVER: "" # For sending verification emails, if unspecified then defaults to 'smtp.gmail.com' + #SMTP_PORT: "" # For sending verification emails, if unspecified then defaults to '587' + #SMTP_USER: "" # 'your-email@company.com' + #SMTP_PASS: "" # 'your-gmail-password' + #EMAIL_FROM: "" # 'your-email@company.com' SMTP_USER missing used instead + # Gen AI Settings + #GEN_AI_MODEL_PROVIDER: "openai" + #GEN_AI_MODEL_VERSION: "gpt-4" # "gpt-3.5-turbo-0125" # Use GPT-4 if you have it + #FAST_GEN_AI_MODEL_VERSION: "gpt-3.5-turbo-0125" + #GEN_AI_API_KEY: "" + #GEN_AI_API_ENDPOINT: "" + #GEN_AI_API_VERSION: "" + #GEN_AI_LLM_PROVIDER_TYPE: "" + #GEN_AI_MAX_TOKENS: "" + #QA_TIMEOUT: "60" + #MAX_CHUNKS_FED_TO_CHAT: "" + #DISABLE_LLM_FILTER_EXTRACTION: "" + #DISABLE_LLM_CHUNK_FILTER: "" + #DISABLE_LLM_CHOOSE_SEARCH: "" + # Query Options + #DOC_TIME_DECAY: "" + #HYBRID_ALPHA: "" + #EDIT_KEYWORD_QUERY: "" + #MULTILINGUAL_QUERY_EXPANSION: "" + #QA_PROMPT_OVERRIDE: "" + # Don't change the NLP models unless you know what you're doing + #DOCUMENT_ENCODER_MODEL: "" + #NORMALIZE_EMBEDDINGS: "" + #ASYM_QUERY_PREFIX: "" + #ASYM_PASSAGE_PREFIX: "" + #ENABLE_RERANKING_REAL_TIME_FLOW: "" + #ENABLE_RERANKING_ASYNC_FLOW: "" + #MODEL_SERVER_HOST: "" + #MODEL_SERVER_PORT: "" + #INDEXING_MODEL_SERVER_HOST: "" + #MIN_THREADS_ML_MODELS: "" + # Indexing Configs + #NUM_INDEXING_WORKERS: "" + #DASK_JOB_CLIENT_ENABLED: "" + #CONTINUE_ON_CONNECTOR_FAILURE: "" + #EXPERIMENTAL_CHECKPOINTING_ENABLED: "" + #CONFLUENCE_CONNECTOR_LABELS_TO_SKIP: "" + #GONG_CONNECTOR_START_TIME: "" + #NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP: "" + # DanswerBot SlackBot Configs + #DANSWER_BOT_SLACK_APP_TOKEN: "" + #DANSWER_BOT_SLACK_BOT_TOKEN: "" + #DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER: "" + #DANSWER_BOT_DISPLAY_ERROR_MSGS: "" + #DANSWER_BOT_RESPOND_EVERY_CHANNEL: "" + #DANSWER_BOT_DISABLE_COT: "" # Currently unused + #NOTIFY_SLACKBOT_NO_ANSWER: "" + # Logging + # Optional Telemetry, please keep it on (nothing sensitive is collected)? 
<3 + # https://docs.danswer.dev/more/telemetry + #DISABLE_TELEMETRY: "" + #LOG_LEVEL: "" + #LOG_ALL_MODEL_INTERACTIONS: "" + #LOG_VESPA_TIMING_INFORMATION: "" + # Shared or Non-backend Related + #INTERNAL_URL: "http://api-server-service:80" # for web server + WEB_DOMAIN: "http://localhost:3000" # for web server and api server + # Other Services + #POSTGRES_HOST: "relational-db-service" + #VESPA_HOST: "document-index-service" \ No newline at end of file From b28b3cfa377708b10ec6ed310725ddc41f48014e Mon Sep 17 00:00:00 2001 From: Arslan Date: Mon, 25 Mar 2024 02:54:38 +0100 Subject: [PATCH 12/58] Making searching docs as a default option (#904) --- backend/danswer/prompts/chat_prompts.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/backend/danswer/prompts/chat_prompts.py b/backend/danswer/prompts/chat_prompts.py index ec69639481..0d4cf9ec3a 100644 --- a/backend/danswer/prompts/chat_prompts.py +++ b/backend/danswer/prompts/chat_prompts.py @@ -49,20 +49,21 @@ SKIP_SEARCH = "Skip Search" YES_SEARCH = "Yes Search" AGGRESSIVE_SEARCH_TEMPLATE = f""" -Given the conversation history and a follow up query, determine if the system should call \ +You are an expert of a critical system. Given the conversation history and a follow up query, determine if the system should call \ an external search tool to better answer the latest user input. +Your default response is {YES_SEARCH}. +If you are even slightly unsure, respond with {YES_SEARCH}. + Respond "{SKIP_SEARCH}" if either: -- There is sufficient information in chat history to FULLY and ACCURATELY answer the query AND \ -additional information or details would provide little or no value. +- There is sufficient information in chat history to FULLY and ACCURATELY answer the query - The query is some form of request that does not require additional information to handle. - +- You are 100% absolutely sure about the question and there is no ambiguity in the answer or question. Conversation History: {GENERAL_SEP_PAT} {{chat_history}} {GENERAL_SEP_PAT} -If you are unsure, respond with {YES_SEARCH}. Respond with EXACTLY and ONLY "{YES_SEARCH}" or "{SKIP_SEARCH}" Follow Up Input: From d3674b02e6fcc86012c418ad599c2cd54e194676 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Sun, 24 Mar 2024 19:01:38 -0700 Subject: [PATCH 13/58] Add Llama2 Prompt Option (#1254) --- backend/danswer/prompts/chat_prompts.py | 34 ++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/backend/danswer/prompts/chat_prompts.py b/backend/danswer/prompts/chat_prompts.py index 0d4cf9ec3a..e0b20243bc 100644 --- a/backend/danswer/prompts/chat_prompts.py +++ b/backend/danswer/prompts/chat_prompts.py @@ -48,17 +48,43 @@ # consider doing COT for this and keep it brief, but likely only small gains. SKIP_SEARCH = "Skip Search" YES_SEARCH = "Yes Search" + AGGRESSIVE_SEARCH_TEMPLATE = f""" -You are an expert of a critical system. Given the conversation history and a follow up query, determine if the system should call \ +Given the conversation history and a follow up query, determine if the system should call \ an external search tool to better answer the latest user input. +Your default response is {YES_SEARCH}. + +Respond "{SKIP_SEARCH}" if either: +- There is sufficient information in chat history to FULLY and ACCURATELY answer the query AND \ +additional information or details would provide little or no value. +- The query is some form of request that does not require additional information to handle. 
+ +Conversation History: +{GENERAL_SEP_PAT} +{{chat_history}} +{GENERAL_SEP_PAT} + +If you are at all unsure, respond with {YES_SEARCH}. +Respond with EXACTLY and ONLY "{YES_SEARCH}" or "{SKIP_SEARCH}" + +Follow Up Input: +{{final_query}} +""".strip() + + +# TODO, templatize this so users don't need to make code changes to use this +AGGRESSIVE_SEARCH_TEMPLATE_LLAMA2 = f""" +You are an expert of a critical system. Given the conversation history and a follow up query, \ +determine if the system should call an external search tool to better answer the latest user input. Your default response is {YES_SEARCH}. If you are even slightly unsure, respond with {YES_SEARCH}. -Respond "{SKIP_SEARCH}" if either: -- There is sufficient information in chat history to FULLY and ACCURATELY answer the query +Respond "{SKIP_SEARCH}" if any of these are true: +- There is sufficient information in chat history to FULLY and ACCURATELY answer the query. - The query is some form of request that does not require additional information to handle. -- You are 100% absolutely sure about the question and there is no ambiguity in the answer or question. +- You are absolutely sure about the question and there is no ambiguity in the answer or question. + Conversation History: {GENERAL_SEP_PAT} {{chat_history}} From bd1df9649b420f42cb8a5e2ca601a76b95feab45 Mon Sep 17 00:00:00 2001 From: Matthew Holland Date: Sun, 24 Mar 2024 19:04:40 -0700 Subject: [PATCH 14/58] Added check for internet connection (#1214) --- backend/danswer/connectors/web/connector.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/backend/danswer/connectors/web/connector.py b/backend/danswer/connectors/web/connector.py index 4415a88fab..ba7f352a5f 100644 --- a/backend/danswer/connectors/web/connector.py +++ b/backend/danswer/connectors/web/connector.py @@ -1,4 +1,5 @@ import io +import socket from enum import Enum from typing import Any from typing import cast @@ -41,6 +42,15 @@ class WEB_CONNECTOR_VALID_SETTINGS(str, Enum): # Given a file upload where every line is a URL, parse all the URLs provided UPLOAD = "upload" +def check_internet_connection() -> bool: + dns_servers = [("1.1.1.1", 53), ("8.8.8.8", 53)] + for server in dns_servers: + try: + socket.create_connection(server, timeout=3) + return True + except OSError: # try the next server + continue + raise Exception("Unable to contact DNS server - check your internet connection") def is_valid_url(url: str) -> bool: try: @@ -173,6 +183,7 @@ def load_from_state(self) -> GenerateDocumentsOutput: base_url = to_visit[0] # For the recursive case doc_batch: list[Document] = [] + check_internet_connection() playwright, context = start_playwright() restart_playwright = False while to_visit: From 49263ed146e26cdef33c285c83b3a1fe400088ed Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Sun, 24 Mar 2024 19:07:57 -0700 Subject: [PATCH 15/58] Linting (#1255) --- backend/danswer/connectors/confluence/connector.py | 5 ++++- backend/danswer/connectors/web/connector.py | 10 ++++++---- web/src/app/admin/connectors/confluence/page.tsx | 5 ++++- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/backend/danswer/connectors/confluence/connector.py b/backend/danswer/connectors/confluence/connector.py index 013948351b..f9f5e7c3bb 100644 --- a/backend/danswer/connectors/confluence/connector.py +++ b/backend/danswer/connectors/confluence/connector.py @@ -75,7 +75,10 @@ def _extract_confluence_keys_from_datacenter_url(wiki_url: str) -> tuple[str, st def extract_confluence_keys_from_url(wiki_url: str) 
-> tuple[str, str, bool]: - is_confluence_cloud = ".atlassian.net/wiki/spaces/" in wiki_url or ".jira.com/wiki/spaces/" in wiki_url + is_confluence_cloud = ( + ".atlassian.net/wiki/spaces/" in wiki_url + or ".jira.com/wiki/spaces/" in wiki_url + ) try: if is_confluence_cloud: diff --git a/backend/danswer/connectors/web/connector.py b/backend/danswer/connectors/web/connector.py index ba7f352a5f..38f30a28ed 100644 --- a/backend/danswer/connectors/web/connector.py +++ b/backend/danswer/connectors/web/connector.py @@ -42,16 +42,18 @@ class WEB_CONNECTOR_VALID_SETTINGS(str, Enum): # Given a file upload where every line is a URL, parse all the URLs provided UPLOAD = "upload" -def check_internet_connection() -> bool: + +def check_internet_connection() -> None: dns_servers = [("1.1.1.1", 53), ("8.8.8.8", 53)] for server in dns_servers: try: socket.create_connection(server, timeout=3) - return True - except OSError: # try the next server - continue + return + except OSError: + continue raise Exception("Unable to contact DNS server - check your internet connection") + def is_valid_url(url: str) -> bool: try: result = urlparse(url) diff --git a/web/src/app/admin/connectors/confluence/page.tsx b/web/src/app/admin/connectors/confluence/page.tsx index 32c50c0a72..649d8853eb 100644 --- a/web/src/app/admin/connectors/confluence/page.tsx +++ b/web/src/app/admin/connectors/confluence/page.tsx @@ -43,7 +43,10 @@ const extractSpaceFromDataCenterUrl = (wikiUrl: string): string => { // Copied from the `extract_confluence_keys_from_url` function const extractSpaceFromUrl = (wikiUrl: string): string | null => { try { - if (wikiUrl.includes(".atlassian.net/wiki/spaces/") || wikiUrl.includes(".jira.com/wiki/spaces/")) { + if ( + wikiUrl.includes(".atlassian.net/wiki/spaces/") || + wikiUrl.includes(".jira.com/wiki/spaces/") + ) { return extractSpaceFromCloudUrl(wikiUrl); } return extractSpaceFromDataCenterUrl(wikiUrl); From 3107edc921d9cf821e979506ed915bae74a007ad Mon Sep 17 00:00:00 2001 From: Johannes Vass Date: Mon, 25 Mar 2024 03:31:07 +0100 Subject: [PATCH 16/58] Do not obtain DB session via Depends() (#1238) Endpoints that use Depends(get_session) with a StreamingResponse have the problem that Depends() releases the session again after the endpoint function returns. At that point, the streaming response is not finished yet but still holds a reference to the session and uses it. However, there is no cleanup of the session after the answer stream finishes which leads to the connections accumulating in state "idle in transaction". 
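The fix, applied below, is to have the streaming generator create and release its own session, tying the session's lifetime to the stream itself rather than to FastAPI's dependency teardown. A minimal self-contained sketch of the pattern (the DSN and the SELECT are placeholders; in the patch the context manager is built by `get_session_context_manager()` in `danswer/db/engine.py`):

import contextlib
from collections.abc import Generator, Iterator

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

app = FastAPI()
engine = create_engine("postgresql://user:pass@localhost/danswer")  # placeholder DSN


def get_session() -> Generator[Session, None, None]:
    with Session(engine, expire_on_commit=False) as session:
        yield session


def stream_packets(query: str) -> Iterator[str]:
    # The generator owns its session: it is opened when iteration starts and
    # released when the stream is exhausted, instead of being torn down by the
    # dependency system while the client is still reading.
    with contextlib.contextmanager(get_session)() as db_session:
        row = db_session.execute(text("SELECT 1")).scalar()  # stand-in for the real answer pipeline
        yield f"{query}: {row}\n"


@app.post("/answer")
def answer_endpoint(query: str) -> StreamingResponse:
    # Deliberately no `db_session: Session = Depends(get_session)` parameter here.
    return StreamingResponse(stream_packets(query), media_type="application/json")

In the patch, `stream_search_answer` adopts exactly this shape via `get_session_context_manager()`.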
This was due to a breaking change in FastAPI 0.106.0 https://fastapi.tiangolo.com/release-notes/#01060 Co-authored-by: Johannes Vass --- backend/danswer/db/engine.py | 6 ++++++ .../one_shot_answer/answer_question.py | 21 ++++++++++--------- .../server/query_and_chat/query_backend.py | 2 -- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/backend/danswer/db/engine.py b/backend/danswer/db/engine.py index 6803c51aa3..22f1193fe8 100644 --- a/backend/danswer/db/engine.py +++ b/backend/danswer/db/engine.py @@ -1,6 +1,8 @@ +import contextlib from collections.abc import AsyncGenerator from collections.abc import Generator from datetime import datetime +from typing import ContextManager from ddtrace import tracer from sqlalchemy import text @@ -70,6 +72,10 @@ def get_sqlalchemy_async_engine() -> AsyncEngine: return _ASYNC_ENGINE +def get_session_context_manager() -> ContextManager: + return contextlib.contextmanager(get_session)() + + def get_session() -> Generator[Session, None, None]: with tracer.trace("db.get_session"): with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session: diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py index 4f4a931aef..529180a799 100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -34,6 +34,7 @@ from danswer.db.chat import get_prompt_by_id from danswer.db.chat import translate_db_message_to_chat_message_detail from danswer.db.embedding_model import get_current_db_embedding_model +from danswer.db.engine import get_session_context_manager from danswer.db.models import Prompt from danswer.db.models import User from danswer.document_index.factory import get_default_document_index @@ -418,17 +419,17 @@ def stream_search_answer( user: User | None, max_document_tokens: int | None, max_history_tokens: int | None, - db_session: Session, ) -> Iterator[str]: - objects = stream_answer_objects( - query_req=query_req, - user=user, - max_document_tokens=max_document_tokens, - max_history_tokens=max_history_tokens, - db_session=db_session, - ) - for obj in objects: - yield get_json_line(obj.dict()) + with get_session_context_manager() as session: + objects = stream_answer_objects( + query_req=query_req, + user=user, + max_document_tokens=max_document_tokens, + max_history_tokens=max_history_tokens, + db_session=session, + ) + for obj in objects: + yield get_json_line(obj.dict()) def get_search_answer( diff --git a/backend/danswer/server/query_and_chat/query_backend.py b/backend/danswer/server/query_and_chat/query_backend.py index 0f0e540c6b..6d8529486f 100644 --- a/backend/danswer/server/query_and_chat/query_backend.py +++ b/backend/danswer/server/query_and_chat/query_backend.py @@ -148,7 +148,6 @@ def stream_query_validation( def get_answer_with_quote( query_request: DirectQARequest, user: User = Depends(current_user), - db_session: Session = Depends(get_session), ) -> StreamingResponse: query = query_request.messages[0].message logger.info(f"Received query for one shot answer with quotes: {query}") @@ -157,6 +156,5 @@ def get_answer_with_quote( user=user, max_document_tokens=None, max_history_tokens=0, - db_session=db_session, ) return StreamingResponse(packets, media_type="application/json") From 7a861ecec49db677219abf7f1ae7cd658031087b Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Sun, 24 Mar 2024 19:40:06 -0700 Subject: [PATCH 17/58] Session Dependency for Chat Streaming (#1256) --- 
backend/danswer/chat/process_message.py | 17 +++++++++-------- .../server/query_and_chat/chat_backend.py | 7 +------ 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 5ebf8ab158..aafe5d000f 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -36,6 +36,7 @@ from danswer.db.chat import translate_db_message_to_chat_message_detail from danswer.db.chat import translate_db_search_doc_to_server_search_doc from danswer.db.embedding_model import get_current_db_embedding_model +from danswer.db.engine import get_session_context_manager from danswer.db.models import ChatMessage from danswer.db.models import Persona from danswer.db.models import SearchDoc as DbSearchDoc @@ -582,12 +583,12 @@ def stream_chat_message_objects( def stream_chat_message( new_msg_req: CreateChatMessageRequest, user: User | None, - db_session: Session, ) -> Iterator[str]: - objects = stream_chat_message_objects( - new_msg_req=new_msg_req, - user=user, - db_session=db_session, - ) - for obj in objects: - yield get_json_line(obj.dict()) + with get_session_context_manager() as db_session: + objects = stream_chat_message_objects( + new_msg_req=new_msg_req, + user=user, + db_session=db_session, + ) + for obj in objects: + yield get_json_line(obj.dict()) diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index 66c69fa876..a8076659c6 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -162,7 +162,6 @@ def delete_chat_session_by_id( def handle_new_chat_message( chat_message_req: CreateChatMessageRequest, user: User | None = Depends(current_user), - db_session: Session = Depends(get_session), ) -> StreamingResponse: """This endpoint is both used for all the following purposes: - Sending a new message in the session @@ -176,11 +175,7 @@ def handle_new_chat_message( if not chat_message_req.message and chat_message_req.prompt_id is not None: raise HTTPException(status_code=400, detail="Empty chat message is invalid") - packets = stream_chat_message( - new_msg_req=chat_message_req, - user=user, - db_session=db_session, - ) + packets = stream_chat_message(new_msg_req=chat_message_req, user=user) return StreamingResponse(packets, media_type="application/json") From 1ba74ee4df0717ec3484e3a87ff469c6ca2f11ea Mon Sep 17 00:00:00 2001 From: Weves Date: Sat, 23 Mar 2024 20:12:23 -0700 Subject: [PATCH 18/58] Refactor search pipeline --- .../776b3bbe9092_remove_remaining_enums.py | 2 +- backend/danswer/chat/load_yamls.py | 2 +- backend/danswer/chat/models.py | 4 +- backend/danswer/chat/process_message.py | 68 +- backend/danswer/db/chat.py | 2 +- backend/danswer/db/models.py | 4 +- backend/danswer/db/slack_bot_config.py | 2 +- backend/danswer/document_index/vespa/index.py | 4 +- .../one_shot_answer/answer_question.py | 64 +- backend/danswer/search/enums.py | 30 + backend/danswer/search/models.py | 54 +- backend/danswer/search/pipeline.py | 152 +++++ .../search/postprocessing/postprocessing.py | 222 ++++++ .../{ => preprocessing}/access_filters.py | 0 .../{ => preprocessing}/danswer_helper.py | 4 +- .../preprocessing.py} | 64 +- .../danswer/search/retrieval/search_runner.py | 256 +++++++ backend/danswer/search/search_runner.py | 645 ------------------ backend/danswer/search/utils.py | 29 + backend/danswer/server/documents/document.py | 2 +- 
.../danswer/server/features/persona/models.py | 2 +- backend/danswer/server/gpts/api.py | 36 +- .../server/query_and_chat/query_backend.py | 6 +- .../regression/search_quality/eval_search.py | 44 +- 24 files changed, 828 insertions(+), 870 deletions(-) create mode 100644 backend/danswer/search/enums.py create mode 100644 backend/danswer/search/pipeline.py create mode 100644 backend/danswer/search/postprocessing/postprocessing.py rename backend/danswer/search/{ => preprocessing}/access_filters.py (100%) rename backend/danswer/search/{ => preprocessing}/danswer_helper.py (96%) rename backend/danswer/search/{request_preprocessing.py => preprocessing/preprocessing.py} (76%) create mode 100644 backend/danswer/search/retrieval/search_runner.py delete mode 100644 backend/danswer/search/search_runner.py create mode 100644 backend/danswer/search/utils.py diff --git a/backend/alembic/versions/776b3bbe9092_remove_remaining_enums.py b/backend/alembic/versions/776b3bbe9092_remove_remaining_enums.py index 272335ca07..1e2e7cd3c1 100644 --- a/backend/alembic/versions/776b3bbe9092_remove_remaining_enums.py +++ b/backend/alembic/versions/776b3bbe9092_remove_remaining_enums.py @@ -9,7 +9,7 @@ import sqlalchemy as sa from danswer.db.models import IndexModelStatus -from danswer.search.models import RecencyBiasSetting +from danswer.search.enums import RecencyBiasSetting from danswer.search.models import SearchType # revision identifiers, used by Alembic. diff --git a/backend/danswer/chat/load_yamls.py b/backend/danswer/chat/load_yamls.py index 0800abb70a..ccc7544374 100644 --- a/backend/danswer/chat/load_yamls.py +++ b/backend/danswer/chat/load_yamls.py @@ -13,7 +13,7 @@ from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import DocumentSet as DocumentSetDBModel from danswer.db.models import Prompt as PromptDBModel -from danswer.search.models import RecencyBiasSetting +from danswer.search.enums import RecencyBiasSetting def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None: diff --git a/backend/danswer/chat/models.py b/backend/danswer/chat/models.py index de3f7e4f01..47d554de77 100644 --- a/backend/danswer/chat/models.py +++ b/backend/danswer/chat/models.py @@ -5,10 +5,10 @@ from pydantic import BaseModel from danswer.configs.constants import DocumentSource -from danswer.search.models import QueryFlow +from danswer.search.enums import QueryFlow +from danswer.search.enums import SearchType from danswer.search.models import RetrievalDocs from danswer.search.models import SearchResponse -from danswer.search.models import SearchType class LlmDoc(BaseModel): diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index aafe5d000f..9cd78c963b 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -53,11 +53,10 @@ from danswer.llm.utils import translate_history_to_basemessages from danswer.prompts.prompt_utils import build_doc_context_str from danswer.search.models import OptionalSearchSetting -from danswer.search.models import RetrievalDetails -from danswer.search.request_preprocessing import retrieval_preprocessing -from danswer.search.search_runner import chunks_to_search_docs -from danswer.search.search_runner import full_chunk_search_generator -from danswer.search.search_runner import inference_documents_from_ids +from danswer.search.models import SearchRequest +from danswer.search.pipeline import SearchPipeline +from danswer.search.retrieval.search_runner import inference_documents_from_ids 
+from danswer.search.utils import chunks_to_search_docs from danswer.secondary_llm_flows.choose_search import check_if_need_search from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase from danswer.server.query_and_chat.models import ChatMessageDetail @@ -377,37 +376,25 @@ def stream_chat_message_objects( else query_override ) - ( - retrieval_request, - predicted_search_type, - predicted_flow, - ) = retrieval_preprocessing( - query=rephrased_query, - retrieval_details=cast(RetrievalDetails, retrieval_options), - persona=persona, + search_pipeline = SearchPipeline( + search_request=SearchRequest( + query=rephrased_query, + human_selected_filters=retrieval_options.filters + if retrieval_options + else None, + persona=persona, + offset=retrieval_options.offset if retrieval_options else None, + limit=retrieval_options.limit if retrieval_options else None, + ), user=user, db_session=db_session, ) - documents_generator = full_chunk_search_generator( - search_query=retrieval_request, - document_index=document_index, - db_session=db_session, - ) - time_cutoff = retrieval_request.filters.time_cutoff - recency_bias_multiplier = retrieval_request.recency_bias_multiplier - run_llm_chunk_filter = not retrieval_request.skip_llm_chunk_filter - - # First fetch and return the top chunks to the UI so the user can - # immediately see some results - top_chunks = cast(list[InferenceChunk], next(documents_generator)) + top_chunks = search_pipeline.reranked_docs + top_docs = chunks_to_search_docs(top_chunks) # Get ranking of the documents for citation purposes later - doc_id_to_rank_map = map_document_id_order( - cast(list[InferenceChunk | LlmDoc], top_chunks) - ) - - top_docs = chunks_to_search_docs(top_chunks) + doc_id_to_rank_map = map_document_id_order(top_chunks) reference_db_search_docs = [ create_db_search_doc(server_search_doc=top_doc, db_session=db_session) @@ -422,24 +409,17 @@ def stream_chat_message_objects( initial_response = QADocsResponse( rephrased_query=rephrased_query, top_documents=response_docs, - predicted_flow=predicted_flow, - predicted_search=predicted_search_type, - applied_source_filters=retrieval_request.filters.source_type, - applied_time_cutoff=time_cutoff, - recency_bias_multiplier=recency_bias_multiplier, + predicted_flow=search_pipeline.predicted_flow, + predicted_search=search_pipeline.predicted_search_type, + applied_source_filters=search_pipeline.search_query.filters.source_type, + applied_time_cutoff=search_pipeline.search_query.filters.time_cutoff, + recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier, ) yield initial_response - # Get the final ordering of chunks for the LLM call - llm_chunk_selection = cast(list[bool], next(documents_generator)) - # Yield the list of LLM selected chunks for showing the LLM selected icons in the UI llm_relevance_filtering_response = LLMRelevanceFilterResponse( - relevant_chunk_indices=[ - index for index, value in enumerate(llm_chunk_selection) if value - ] - if run_llm_chunk_filter - else [] + relevant_chunk_indices=search_pipeline.relevant_chunk_indicies ) yield llm_relevance_filtering_response @@ -467,7 +447,7 @@ def stream_chat_message_objects( ) llm_chunks_indices = get_chunks_for_qa( chunks=top_chunks, - llm_chunk_selection=llm_chunk_selection, + llm_chunk_selection=search_pipeline.chunk_relevance_list, token_limit=chunk_token_limit, llm_tokenizer=llm_tokenizer, ) diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index 343912e275..6dfa02c2f9 100644 --- 
a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -27,7 +27,7 @@ from danswer.db.models import SearchDoc as DBSearchDoc from danswer.db.models import StarterMessage from danswer.db.models import User__UserGroup -from danswer.search.models import RecencyBiasSetting +from danswer.search.enums import RecencyBiasSetting from danswer.search.models import RetrievalDocs from danswer.search.models import SavedSearchDoc from danswer.search.models import SearchDoc as ServerSearchDoc diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index abe189c45e..faafd2aedf 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -36,8 +36,8 @@ from danswer.configs.constants import SearchFeedbackType from danswer.connectors.models import InputType from danswer.dynamic_configs.interface import JSON_ro -from danswer.search.models import RecencyBiasSetting -from danswer.search.models import SearchType +from danswer.search.enums import RecencyBiasSetting +from danswer.search.enums import SearchType class IndexingStatus(str, PyEnum): diff --git a/backend/danswer/db/slack_bot_config.py b/backend/danswer/db/slack_bot_config.py index 3e93a76cf6..c3b463e35d 100644 --- a/backend/danswer/db/slack_bot_config.py +++ b/backend/danswer/db/slack_bot_config.py @@ -12,7 +12,7 @@ from danswer.db.models import Persona__DocumentSet from danswer.db.models import SlackBotConfig from danswer.db.models import SlackBotResponseType -from danswer.search.models import RecencyBiasSetting +from danswer.search.enums import RecencyBiasSetting def _build_persona_name(channel_names: list[str]) -> str: diff --git a/backend/danswer/document_index/vespa/index.py b/backend/danswer/document_index/vespa/index.py index 178aadf3ee..9f78f05c20 100644 --- a/backend/danswer/document_index/vespa/index.py +++ b/backend/danswer/document_index/vespa/index.py @@ -64,8 +64,8 @@ from danswer.indexing.models import DocMetadataAwareIndexChunk from danswer.indexing.models import InferenceChunk from danswer.search.models import IndexFilters -from danswer.search.search_runner import query_processing -from danswer.search.search_runner import remove_stop_words_and_punctuation +from danswer.search.retrieval.search_runner import query_processing +from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation from danswer.utils.batching import batch_generator from danswer.utils.logger import setup_logger from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py index 529180a799..db5ef6f0f9 100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -1,7 +1,6 @@ import itertools from collections.abc import Callable from collections.abc import Iterator -from typing import cast from langchain.schema.messages import BaseMessage from langchain.schema.messages import HumanMessage @@ -33,11 +32,9 @@ from danswer.db.chat import get_persona_by_id from danswer.db.chat import get_prompt_by_id from danswer.db.chat import translate_db_message_to_chat_message_detail -from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.engine import get_session_context_manager from danswer.db.models import Prompt from danswer.db.models import User -from danswer.document_index.factory import get_default_document_index from danswer.indexing.models import InferenceChunk from 
danswer.llm.factory import get_default_llm from danswer.llm.utils import get_default_llm_token_encode @@ -55,9 +52,9 @@ from danswer.search.models import RerankMetricsContainer from danswer.search.models import RetrievalMetricsContainer from danswer.search.models import SavedSearchDoc -from danswer.search.request_preprocessing import retrieval_preprocessing -from danswer.search.search_runner import chunks_to_search_docs -from danswer.search.search_runner import full_chunk_search_generator +from danswer.search.models import SearchRequest +from danswer.search.pipeline import SearchPipeline +from danswer.search.utils import chunks_to_search_docs from danswer.secondary_llm_flows.answer_validation import get_answer_validity from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase from danswer.server.query_and_chat.models import ChatMessageDetail @@ -221,12 +218,6 @@ def stream_answer_objects( llm_tokenizer = get_default_llm_token_encode() - embedding_model = get_current_db_embedding_model(db_session) - - document_index = get_default_document_index( - primary_index_name=embedding_model.index_name, secondary_index_name=None - ) - # Create a chat session which will just store the root message, the query, and the AI response root_message = get_or_create_root_message( chat_session_id=chat_session.id, db_session=db_session @@ -244,33 +235,23 @@ def stream_answer_objects( # In chat flow it's given back along with the documents yield QueryRephrase(rephrased_query=rephrased_query) - ( - retrieval_request, - predicted_search_type, - predicted_flow, - ) = retrieval_preprocessing( - query=rephrased_query, - retrieval_details=query_req.retrieval_options, - persona=chat_session.persona, + search_pipeline = SearchPipeline( + search_request=SearchRequest( + query=rephrased_query, + human_selected_filters=query_req.retrieval_options.filters, + persona=chat_session.persona, + offset=query_req.retrieval_options.offset, + limit=query_req.retrieval_options.limit, + ), user=user, db_session=db_session, bypass_acl=bypass_acl, - ) - - documents_generator = full_chunk_search_generator( - search_query=retrieval_request, - document_index=document_index, - db_session=db_session, retrieval_metrics_callback=retrieval_metrics_callback, rerank_metrics_callback=rerank_metrics_callback, ) - applied_time_cutoff = retrieval_request.filters.time_cutoff - recency_bias_multiplier = retrieval_request.recency_bias_multiplier - run_llm_chunk_filter = not retrieval_request.skip_llm_chunk_filter # First fetch and return the top chunks so the user can immediately see some results - top_chunks = cast(list[InferenceChunk], next(documents_generator)) - + top_chunks = search_pipeline.reranked_docs top_docs = chunks_to_search_docs(top_chunks) fake_saved_docs = [SavedSearchDoc.from_search_doc(doc) for doc in top_docs] @@ -278,24 +259,17 @@ def stream_answer_objects( initial_response = QADocsResponse( rephrased_query=rephrased_query, top_documents=fake_saved_docs, - predicted_flow=predicted_flow, - predicted_search=predicted_search_type, - applied_source_filters=retrieval_request.filters.source_type, - applied_time_cutoff=applied_time_cutoff, - recency_bias_multiplier=recency_bias_multiplier, + predicted_flow=search_pipeline.predicted_flow, + predicted_search=search_pipeline.predicted_search_type, + applied_source_filters=search_pipeline.search_query.filters.source_type, + applied_time_cutoff=search_pipeline.search_query.filters.time_cutoff, + 
recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier, ) yield initial_response - # Get the final ordering of chunks for the LLM call - llm_chunk_selection = cast(list[bool], next(documents_generator)) - # Yield the list of LLM selected chunks for showing the LLM selected icons in the UI llm_relevance_filtering_response = LLMRelevanceFilterResponse( - relevant_chunk_indices=[ - index for index, value in enumerate(llm_chunk_selection) if value - ] - if run_llm_chunk_filter - else [] + relevant_chunk_indices=search_pipeline.relevant_chunk_indicies ) yield llm_relevance_filtering_response @@ -317,7 +291,7 @@ def stream_answer_objects( llm_chunks_indices = get_chunks_for_qa( chunks=top_chunks, - llm_chunk_selection=llm_chunk_selection, + llm_chunk_selection=search_pipeline.chunk_relevance_list, token_limit=chunk_token_limit, ) llm_chunks = [top_chunks[i] for i in llm_chunks_indices] diff --git a/backend/danswer/search/enums.py b/backend/danswer/search/enums.py new file mode 100644 index 0000000000..9ba44ada2c --- /dev/null +++ b/backend/danswer/search/enums.py @@ -0,0 +1,30 @@ +"""NOTE: this needs to be separate from models.py because of circular imports. +Both search/models.py and db/models.py import enums from this file AND +search/models.py imports from db/models.py.""" +from enum import Enum + + +class OptionalSearchSetting(str, Enum): + ALWAYS = "always" + NEVER = "never" + # Determine whether to run search based on history and latest query + AUTO = "auto" + + +class RecencyBiasSetting(str, Enum): + FAVOR_RECENT = "favor_recent" # 2x decay rate + BASE_DECAY = "base_decay" + NO_DECAY = "no_decay" + # Determine based on query if to use base_decay or favor_recent + AUTO = "auto" + + +class SearchType(str, Enum): + KEYWORD = "keyword" + SEMANTIC = "semantic" + HYBRID = "hybrid" + + +class QueryFlow(str, Enum): + SEARCH = "search" + QUESTION_ANSWER = "question-answer" diff --git a/backend/danswer/search/models.py b/backend/danswer/search/models.py index db3dc31f83..d2ad74c34e 100644 --- a/backend/danswer/search/models.py +++ b/backend/danswer/search/models.py @@ -1,46 +1,24 @@ from datetime import datetime -from enum import Enum from typing import Any from pydantic import BaseModel from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER +from danswer.configs.chat_configs import HYBRID_ALPHA from danswer.configs.chat_configs import NUM_RERANKED_RESULTS from danswer.configs.chat_configs import NUM_RETURNED_HITS from danswer.configs.constants import DocumentSource from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW +from danswer.db.models import Persona +from danswer.search.enums import OptionalSearchSetting +from danswer.search.enums import SearchType + MAX_METRICS_CONTENT = ( 200 # Just need enough characters to identify where in the doc the chunk is ) -class OptionalSearchSetting(str, Enum): - ALWAYS = "always" - NEVER = "never" - # Determine whether to run search based on history and latest query - AUTO = "auto" - - -class RecencyBiasSetting(str, Enum): - FAVOR_RECENT = "favor_recent" # 2x decay rate - BASE_DECAY = "base_decay" - NO_DECAY = "no_decay" - # Determine based on query if to use base_decay or favor_recent - AUTO = "auto" - - -class SearchType(str, Enum): - KEYWORD = "keyword" - SEMANTIC = "semantic" - HYBRID = "hybrid" - - -class QueryFlow(str, Enum): - SEARCH = "search" - QUESTION_ANSWER = "question-answer" - - class Tag(BaseModel): tag_key: str tag_value: str @@ -64,6 +42,28 @@ class ChunkMetric(BaseModel): score: float 
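The `SearchRequest` model added just below becomes the single entry point into search, folding together what was previously threaded through as a raw query string, a `RetrievalDetails` object, and a `Persona`. Combined with the `SearchPipeline` introduced later in this patch, typical consumption looks roughly like this (the query text is a placeholder and `user=None` stands in for a real authenticated user):

    from danswer.db.engine import get_session_context_manager
    from danswer.search.models import SearchRequest
    from danswer.search.pipeline import SearchPipeline

    with get_session_context_manager() as db_session:
        pipeline = SearchPipeline(
            search_request=SearchRequest(query="example search query"),
            user=None,
            db_session=db_session,
        )
        # Each property lazily runs and caches its stage, in order:
        # preprocessing -> retrieval -> postprocessing.
        top_chunks = pipeline.reranked_docs
        relevant = pipeline.relevant_chunk_indicies  # spelling as in the patch

Under the hood, `reranked_docs` and `relevant_chunk_indicies` consume the two yields of `search_postprocessing` (first the reordered chunks, then the relevant chunk IDs), replacing the pair of `next(...)` calls on `full_chunk_search_generator` that the call sites above delete.
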
+class SearchRequest(BaseModel): + """Input to the SearchPipeline.""" + + query: str + search_type: SearchType = SearchType.HYBRID + + human_selected_filters: BaseFilters | None = None + enable_auto_detect_filters: bool | None = None + persona: Persona | None = None + + # if None, no offset / limit + offset: int | None = None + limit: int | None = None + + recency_bias_multiplier: float = 1.0 + hybrid_alpha: float = HYBRID_ALPHA + skip_rerank: bool = True + + class Config: + arbitrary_types_allowed = True + + class SearchQuery(BaseModel): query: str filters: IndexFilters diff --git a/backend/danswer/search/pipeline.py b/backend/danswer/search/pipeline.py new file mode 100644 index 0000000000..972f510db9 --- /dev/null +++ b/backend/danswer/search/pipeline.py @@ -0,0 +1,152 @@ +from collections.abc import Callable +from typing import cast + +from sqlalchemy.orm import Session + +from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION +from danswer.db.embedding_model import get_current_db_embedding_model +from danswer.db.models import User +from danswer.document_index.factory import get_default_document_index +from danswer.indexing.models import InferenceChunk +from danswer.search.enums import QueryFlow +from danswer.search.enums import SearchType +from danswer.search.models import RerankMetricsContainer +from danswer.search.models import RetrievalMetricsContainer +from danswer.search.models import SearchQuery +from danswer.search.models import SearchRequest +from danswer.search.postprocessing.postprocessing import search_postprocessing +from danswer.search.preprocessing.preprocessing import retrieval_preprocessing +from danswer.search.retrieval.search_runner import retrieve_chunks + + +class SearchPipeline: + def __init__( + self, + search_request: SearchRequest, + user: User | None, + db_session: Session, + bypass_acl: bool = False, # NOTE: VERY DANGEROUS, USE WITH CAUTION + retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] + | None = None, + rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, + ): + self.search_request = search_request + self.user = user + self.db_session = db_session + self.bypass_acl = bypass_acl + self.retrieval_metrics_callback = retrieval_metrics_callback + self.rerank_metrics_callback = rerank_metrics_callback + + self.embedding_model = get_current_db_embedding_model(db_session) + self.document_index = get_default_document_index( + primary_index_name=self.embedding_model.index_name, + secondary_index_name=None, + ) + + self._search_query: SearchQuery | None = None + self._predicted_search_type: SearchType | None = None + self._predicted_flow: QueryFlow | None = None + + self._retrieved_docs: list[InferenceChunk] | None = None + self._reranked_docs: list[InferenceChunk] | None = None + self._relevant_chunk_indicies: list[int] | None = None + + """Pre-processing""" + + def _run_preprocessing(self) -> None: + ( + final_search_query, + predicted_search_type, + predicted_flow, + ) = retrieval_preprocessing( + search_request=self.search_request, + user=self.user, + db_session=self.db_session, + bypass_acl=self.bypass_acl, + ) + self._predicted_search_type = predicted_search_type + self._predicted_flow = predicted_flow + self._search_query = final_search_query + + @property + def search_query(self) -> SearchQuery: + if self._search_query is not None: + return self._search_query + + self._run_preprocessing() + return cast(SearchQuery, self._search_query) + + @property + def predicted_search_type(self) -> 
SearchType: + if self._predicted_search_type is not None: + return self._predicted_search_type + + self._run_preprocessing() + return cast(SearchType, self._predicted_search_type) + + @property + def predicted_flow(self) -> QueryFlow: + if self._predicted_flow is not None: + return self._predicted_flow + + self._run_preprocessing() + return cast(QueryFlow, self._predicted_flow) + + """Retrieval""" + + @property + def retrieved_docs(self) -> list[InferenceChunk]: + if self._retrieved_docs is not None: + return self._retrieved_docs + + self._retrieved_docs = retrieve_chunks( + query=self.search_query, + document_index=self.document_index, + db_session=self.db_session, + hybrid_alpha=self.search_request.hybrid_alpha, + multilingual_expansion_str=MULTILINGUAL_QUERY_EXPANSION, + retrieval_metrics_callback=self.retrieval_metrics_callback, + ) + + # self._retrieved_docs = chunks_to_search_docs(retrieved_chunks) + return cast(list[InferenceChunk], self._retrieved_docs) + + """Post-Processing""" + + def _run_postprocessing(self) -> None: + postprocessing_generator = search_postprocessing( + search_query=self.search_query, + retrieved_chunks=self.retrieved_docs, + rerank_metrics_callback=self.rerank_metrics_callback, + ) + self._reranked_docs = cast(list[InferenceChunk], next(postprocessing_generator)) + + relevant_chunk_ids = cast(list[str], next(postprocessing_generator)) + self._relevant_chunk_indicies = [ + ind + for ind, chunk in enumerate(self._reranked_docs) + if chunk.unique_id in relevant_chunk_ids + ] + + @property + def reranked_docs(self) -> list[InferenceChunk]: + if self._reranked_docs is not None: + return self._reranked_docs + + self._run_postprocessing() + return cast(list[InferenceChunk], self._reranked_docs) + + @property + def relevant_chunk_indicies(self) -> list[int]: + if self._relevant_chunk_indicies is not None: + return self._relevant_chunk_indicies + + self._run_postprocessing() + return cast(list[int], self._relevant_chunk_indicies) + + @property + def chunk_relevance_list(self) -> list[bool]: + return [ + True if ind in self.relevant_chunk_indicies else False + for ind in range(len(self.reranked_docs)) + ] diff --git a/backend/danswer/search/postprocessing/postprocessing.py b/backend/danswer/search/postprocessing/postprocessing.py new file mode 100644 index 0000000000..e1cee4bd6d --- /dev/null +++ b/backend/danswer/search/postprocessing/postprocessing.py @@ -0,0 +1,222 @@ +from collections.abc import Callable +from collections.abc import Generator +from typing import cast + +import numpy + +from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX +from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN +from danswer.document_index.document_index_utils import ( + translate_boost_count_to_multiplier, +) +from danswer.indexing.models import InferenceChunk +from danswer.search.models import ChunkMetric +from danswer.search.models import MAX_METRICS_CONTENT +from danswer.search.models import RerankMetricsContainer +from danswer.search.models import SearchQuery +from danswer.search.models import SearchType +from danswer.search.search_nlp_models import CrossEncoderEnsembleModel +from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_chunks +from danswer.utils.logger import setup_logger +from danswer.utils.threadpool_concurrency import FunctionCall +from danswer.utils.threadpool_concurrency import run_functions_in_parallel +from danswer.utils.timing import log_function_time + + +logger = setup_logger() + + +def 
_log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None: + top_links = [ + c.source_links[0] if c.source_links is not None else "No Link" for c in chunks + ] + logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}") + + +def should_rerank(query: SearchQuery) -> bool: + # Don't re-rank for keyword search + return query.search_type != SearchType.KEYWORD and not query.skip_rerank + + +def should_apply_llm_based_relevance_filter(query: SearchQuery) -> bool: + return not query.skip_llm_chunk_filter + + +@log_function_time(print_only=True) +def semantic_reranking( + query: str, + chunks: list[InferenceChunk], + model_min: int = CROSS_ENCODER_RANGE_MIN, + model_max: int = CROSS_ENCODER_RANGE_MAX, + rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, +) -> tuple[list[InferenceChunk], list[int]]: + """Reranks chunks based on cross-encoder models. Additionally provides the original indices + of the chunks in their new sorted order. + + Note: this updates the chunks in place, it updates the chunk scores which came from retrieval + """ + cross_encoders = CrossEncoderEnsembleModel() + passages = [chunk.content for chunk in chunks] + sim_scores_floats = cross_encoders.predict(query=query, passages=passages) + + sim_scores = [numpy.array(scores) for scores in sim_scores_floats] + + raw_sim_scores = cast(numpy.ndarray, sum(sim_scores) / len(sim_scores)) + + cross_models_min = numpy.min(sim_scores) + + shifted_sim_scores = sum( + [enc_n_scores - cross_models_min for enc_n_scores in sim_scores] + ) / len(sim_scores) + + boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks] + recency_multiplier = [chunk.recency_bias for chunk in chunks] + boosted_sim_scores = shifted_sim_scores * boosts * recency_multiplier + normalized_b_s_scores = (boosted_sim_scores + cross_models_min - model_min) / ( + model_max - model_min + ) + orig_indices = [i for i in range(len(normalized_b_s_scores))] + scored_results = list( + zip(normalized_b_s_scores, raw_sim_scores, chunks, orig_indices) + ) + scored_results.sort(key=lambda x: x[0], reverse=True) + ranked_sim_scores, ranked_raw_scores, ranked_chunks, ranked_indices = zip( + *scored_results + ) + + logger.debug( + f"Reranked (Boosted + Time Weighted) similarity scores: {ranked_sim_scores}" + ) + + # Assign new chunk scores based on reranking + for ind, chunk in enumerate(ranked_chunks): + chunk.score = ranked_sim_scores[ind] + + if rerank_metrics_callback is not None: + chunk_metrics = [ + ChunkMetric( + document_id=chunk.document_id, + chunk_content_start=chunk.content[:MAX_METRICS_CONTENT], + first_link=chunk.source_links[0] if chunk.source_links else None, + score=chunk.score if chunk.score is not None else 0, + ) + for chunk in ranked_chunks + ] + + rerank_metrics_callback( + RerankMetricsContainer( + metrics=chunk_metrics, raw_similarity_scores=ranked_raw_scores # type: ignore + ) + ) + + return list(ranked_chunks), list(ranked_indices) + + +def rerank_chunks( + query: SearchQuery, + chunks_to_rerank: list[InferenceChunk], + rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, +) -> list[InferenceChunk]: + ranked_chunks, _ = semantic_reranking( + query=query.query, + chunks=chunks_to_rerank[: query.num_rerank], + rerank_metrics_callback=rerank_metrics_callback, + ) + lower_chunks = chunks_to_rerank[query.num_rerank :] + # Scores from rerank cannot be meaningfully combined with scores without rerank + for lower_chunk in lower_chunks: + lower_chunk.score 
= None + ranked_chunks.extend(lower_chunks) + return ranked_chunks + + +@log_function_time(print_only=True) +def filter_chunks( + query: SearchQuery, + chunks_to_filter: list[InferenceChunk], +) -> list[str]: + """Filters chunks based on whether the LLM thought they were relevant to the query. + + Returns a list of the unique chunk IDs that were marked as relevant""" + chunks_to_filter = chunks_to_filter[: query.max_llm_filter_chunks] + llm_chunk_selection = llm_batch_eval_chunks( + query=query.query, + chunk_contents=[chunk.content for chunk in chunks_to_filter], + ) + return [ + chunk.unique_id + for ind, chunk in enumerate(chunks_to_filter) + if llm_chunk_selection[ind] + ] + + +def search_postprocessing( + search_query: SearchQuery, + retrieved_chunks: list[InferenceChunk], + rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, +) -> Generator[list[InferenceChunk] | list[str], None, None]: + post_processing_tasks: list[FunctionCall] = [] + + rerank_task_id = None + if should_rerank(search_query): + post_processing_tasks.append( + FunctionCall( + rerank_chunks, + ( + search_query, + retrieved_chunks, + rerank_metrics_callback, + ), + ) + ) + rerank_task_id = post_processing_tasks[-1].result_id + else: + final_chunks = retrieved_chunks + # NOTE: if we don't rerank, we can return the chunks immediately + # since we know this is the final order + _log_top_chunk_links(search_query.search_type.value, final_chunks) + yield final_chunks + chunks_yielded = True + + llm_filter_task_id = None + if should_apply_llm_based_relevance_filter(search_query): + post_processing_tasks.append( + FunctionCall( + filter_chunks, + (search_query, retrieved_chunks[: search_query.max_llm_filter_chunks]), + ) + ) + llm_filter_task_id = post_processing_tasks[-1].result_id + + post_processing_results = ( + run_functions_in_parallel(post_processing_tasks) + if post_processing_tasks + else {} + ) + reranked_chunks = cast( + list[InferenceChunk] | None, + post_processing_results.get(str(rerank_task_id)) if rerank_task_id else None, + ) + if reranked_chunks: + if chunks_yielded: + logger.error( + "Trying to yield re-ranked chunks, but chunks were already yielded. This should never happen." 
+ ) + else: + _log_top_chunk_links(search_query.search_type.value, reranked_chunks) + yield reranked_chunks + + llm_chunk_selection = cast( + list[str] | None, + post_processing_results.get(str(llm_filter_task_id)) + if llm_filter_task_id + else None, + ) + if llm_chunk_selection is not None: + yield [ + chunk.unique_id + for chunk in reranked_chunks or retrieved_chunks + if chunk.unique_id in llm_chunk_selection + ] + else: + yield [] diff --git a/backend/danswer/search/access_filters.py b/backend/danswer/search/preprocessing/access_filters.py similarity index 100% rename from backend/danswer/search/access_filters.py rename to backend/danswer/search/preprocessing/access_filters.py diff --git a/backend/danswer/search/danswer_helper.py b/backend/danswer/search/preprocessing/danswer_helper.py similarity index 96% rename from backend/danswer/search/danswer_helper.py rename to backend/danswer/search/preprocessing/danswer_helper.py index d5dbeb8a3e..88e465dacb 100644 --- a/backend/danswer/search/danswer_helper.py +++ b/backend/danswer/search/preprocessing/danswer_helper.py @@ -1,10 +1,10 @@ from typing import TYPE_CHECKING -from danswer.search.models import QueryFlow +from danswer.search.enums import QueryFlow from danswer.search.models import SearchType +from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation from danswer.search.search_nlp_models import get_default_tokenizer from danswer.search.search_nlp_models import IntentModel -from danswer.search.search_runner import remove_stop_words_and_punctuation from danswer.server.query_and_chat.models import HelperResponse from danswer.utils.logger import setup_logger diff --git a/backend/danswer/search/request_preprocessing.py b/backend/danswer/search/preprocessing/preprocessing.py similarity index 76% rename from backend/danswer/search/request_preprocessing.py rename to backend/danswer/search/preprocessing/preprocessing.py index e74618d395..f35afe4389 100644 --- a/backend/danswer/search/request_preprocessing.py +++ b/backend/danswer/search/preprocessing/preprocessing.py @@ -5,19 +5,16 @@ from danswer.configs.chat_configs import DISABLE_LLM_FILTER_EXTRACTION from danswer.configs.chat_configs import FAVOR_RECENT_DECAY_MULTIPLIER from danswer.configs.chat_configs import NUM_RETURNED_HITS -from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW -from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW -from danswer.db.models import Persona from danswer.db.models import User -from danswer.search.access_filters import build_access_filters_for_user -from danswer.search.danswer_helper import query_intent +from danswer.search.enums import QueryFlow +from danswer.search.enums import RecencyBiasSetting from danswer.search.models import BaseFilters from danswer.search.models import IndexFilters -from danswer.search.models import QueryFlow -from danswer.search.models import RecencyBiasSetting -from danswer.search.models import RetrievalDetails from danswer.search.models import SearchQuery +from danswer.search.models import SearchRequest from danswer.search.models import SearchType +from danswer.search.preprocessing.access_filters import build_access_filters_for_user +from danswer.search.preprocessing.danswer_helper import query_intent from danswer.secondary_llm_flows.source_filter import extract_source_filter from danswer.secondary_llm_flows.time_filter import extract_time_filter from danswer.utils.logger import setup_logger @@ -31,15 +28,12 @@ @log_function_time(print_only=True) def 
retrieval_preprocessing( - query: str, - retrieval_details: RetrievalDetails, - persona: Persona, + search_request: SearchRequest, user: User | None, db_session: Session, bypass_acl: bool = False, include_query_intent: bool = True, - skip_rerank_realtime: bool = not ENABLE_RERANKING_REAL_TIME_FLOW, - skip_rerank_non_realtime: bool = not ENABLE_RERANKING_ASYNC_FLOW, + enable_auto_detect_filters: bool = False, disable_llm_filter_extraction: bool = DISABLE_LLM_FILTER_EXTRACTION, disable_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER, base_recency_decay: float = BASE_RECENCY_DECAY, @@ -50,8 +44,12 @@ def retrieval_preprocessing( Then any filters or settings as part of the query are used Then defaults to Persona settings if not specified by the query """ + query = search_request.query + limit = search_request.limit + offset = search_request.offset + persona = search_request.persona - preset_filters = retrieval_details.filters or BaseFilters() + preset_filters = search_request.human_selected_filters or BaseFilters() if persona and persona.document_sets and preset_filters.document_set is None: preset_filters.document_set = [ document_set.name for document_set in persona.document_sets @@ -65,16 +63,20 @@ def retrieval_preprocessing( if disable_llm_filter_extraction: auto_detect_time_filter = False auto_detect_source_filter = False - elif retrieval_details.enable_auto_detect_filters is False: + elif enable_auto_detect_filters is False: logger.debug("Retrieval details disables auto detect filters") auto_detect_time_filter = False auto_detect_source_filter = False - elif persona.llm_filter_extraction is False: + elif persona and persona.llm_filter_extraction is False: logger.debug("Persona disables auto detect filters") auto_detect_time_filter = False auto_detect_source_filter = False - if time_filter is not None and persona.recency_bias != RecencyBiasSetting.AUTO: + if ( + time_filter is not None + and persona + and persona.recency_bias != RecencyBiasSetting.AUTO + ): auto_detect_time_filter = False logger.debug("Not extract time filter - already provided") if source_filter is not None: @@ -138,24 +140,18 @@ def retrieval_preprocessing( access_control_list=user_acl_filters, ) - # Tranformer-based re-ranking to run at same time as LLM chunk relevance filter - # This one is only set globally, not via query or Persona settings - skip_reranking = ( - skip_rerank_realtime - if retrieval_details.real_time - else skip_rerank_non_realtime - ) - - llm_chunk_filter = persona.llm_relevance_filter + llm_chunk_filter = False + if persona: + llm_chunk_filter = persona.llm_relevance_filter if disable_llm_chunk_filter: llm_chunk_filter = False # Decays at 1 / (1 + (multiplier * num years)) - if persona.recency_bias == RecencyBiasSetting.NO_DECAY: + if persona and persona.recency_bias == RecencyBiasSetting.NO_DECAY: recency_bias_multiplier = 0.0 - elif persona.recency_bias == RecencyBiasSetting.BASE_DECAY: + elif persona and persona.recency_bias == RecencyBiasSetting.BASE_DECAY: recency_bias_multiplier = base_recency_decay - elif persona.recency_bias == RecencyBiasSetting.FAVOR_RECENT: + elif persona and persona.recency_bias == RecencyBiasSetting.FAVOR_RECENT: recency_bias_multiplier = base_recency_decay * favor_recent_decay_multiplier else: if predicted_favor_recent: @@ -166,14 +162,12 @@ def retrieval_preprocessing( return ( SearchQuery( query=query, - search_type=persona.search_type, + search_type=persona.search_type if persona else SearchType.HYBRID, filters=final_filters, 
recency_bias_multiplier=recency_bias_multiplier, - num_hits=retrieval_details.limit - if retrieval_details.limit is not None - else NUM_RETURNED_HITS, - offset=retrieval_details.offset or 0, - skip_rerank=skip_reranking, + num_hits=limit if limit is not None else NUM_RETURNED_HITS, + offset=offset or 0, + skip_rerank=search_request.skip_rerank, skip_llm_chunk_filter=not llm_chunk_filter, ), predicted_search_type, diff --git a/backend/danswer/search/retrieval/search_runner.py b/backend/danswer/search/retrieval/search_runner.py new file mode 100644 index 0000000000..3dff76d96e --- /dev/null +++ b/backend/danswer/search/retrieval/search_runner.py @@ -0,0 +1,256 @@ +import string +from collections.abc import Callable + +from nltk.corpus import stopwords # type:ignore +from nltk.stem import WordNetLemmatizer # type:ignore +from nltk.tokenize import word_tokenize # type:ignore +from sqlalchemy.orm import Session + +from danswer.chat.models import LlmDoc +from danswer.configs.app_configs import MODEL_SERVER_HOST +from danswer.configs.app_configs import MODEL_SERVER_PORT +from danswer.configs.chat_configs import HYBRID_ALPHA +from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION +from danswer.db.embedding_model import get_current_db_embedding_model +from danswer.document_index.interfaces import DocumentIndex +from danswer.indexing.models import InferenceChunk +from danswer.search.models import ChunkMetric +from danswer.search.models import IndexFilters +from danswer.search.models import MAX_METRICS_CONTENT +from danswer.search.models import RetrievalMetricsContainer +from danswer.search.models import SearchQuery +from danswer.search.models import SearchType +from danswer.search.search_nlp_models import EmbeddingModel +from danswer.search.search_nlp_models import EmbedTextType +from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion +from danswer.utils.logger import setup_logger +from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel +from danswer.utils.timing import log_function_time + + +logger = setup_logger() + + +def lemmatize_text(text: str) -> list[str]: + lemmatizer = WordNetLemmatizer() + word_tokens = word_tokenize(text) + return [lemmatizer.lemmatize(word) for word in word_tokens] + + +def remove_stop_words_and_punctuation(text: str) -> list[str]: + stop_words = set(stopwords.words("english")) + word_tokens = word_tokenize(text) + text_trimmed = [ + word + for word in word_tokens + if (word.casefold() not in stop_words and word not in string.punctuation) + ] + return text_trimmed or word_tokens + + +def query_processing( + query: str, +) -> str: + query = " ".join(remove_stop_words_and_punctuation(query)) + query = " ".join(lemmatize_text(query)) + return query + + +def combine_retrieval_results( + chunk_sets: list[list[InferenceChunk]], +) -> list[InferenceChunk]: + all_chunks = [chunk for chunk_set in chunk_sets for chunk in chunk_set] + + unique_chunks: dict[tuple[str, int], InferenceChunk] = {} + for chunk in all_chunks: + key = (chunk.document_id, chunk.chunk_id) + if key not in unique_chunks: + unique_chunks[key] = chunk + continue + + stored_chunk_score = unique_chunks[key].score or 0 + this_chunk_score = chunk.score or 0 + if stored_chunk_score < this_chunk_score: + unique_chunks[key] = chunk + + sorted_chunks = sorted( + unique_chunks.values(), key=lambda x: x.score or 0, reverse=True + ) + + return sorted_chunks + + +@log_function_time(print_only=True) +def doc_index_retrieval( + query: SearchQuery, + 
document_index: DocumentIndex, + db_session: Session, + hybrid_alpha: float = HYBRID_ALPHA, +) -> list[InferenceChunk]: + if query.search_type == SearchType.KEYWORD: + top_chunks = document_index.keyword_retrieval( + query=query.query, + filters=query.filters, + time_decay_multiplier=query.recency_bias_multiplier, + num_to_retrieve=query.num_hits, + ) + else: + db_embedding_model = get_current_db_embedding_model(db_session) + + model = EmbeddingModel( + model_name=db_embedding_model.model_name, + query_prefix=db_embedding_model.query_prefix, + passage_prefix=db_embedding_model.passage_prefix, + normalize=db_embedding_model.normalize, + # The below are globally set, this flow always uses the indexing one + server_host=MODEL_SERVER_HOST, + server_port=MODEL_SERVER_PORT, + ) + + query_embedding = model.encode([query.query], text_type=EmbedTextType.QUERY)[0] + + if query.search_type == SearchType.SEMANTIC: + top_chunks = document_index.semantic_retrieval( + query=query.query, + query_embedding=query_embedding, + filters=query.filters, + time_decay_multiplier=query.recency_bias_multiplier, + num_to_retrieve=query.num_hits, + ) + + elif query.search_type == SearchType.HYBRID: + top_chunks = document_index.hybrid_retrieval( + query=query.query, + query_embedding=query_embedding, + filters=query.filters, + time_decay_multiplier=query.recency_bias_multiplier, + num_to_retrieve=query.num_hits, + offset=query.offset, + hybrid_alpha=hybrid_alpha, + ) + + else: + raise RuntimeError("Invalid Search Flow") + + return top_chunks + + +def _simplify_text(text: str) -> str: + return "".join( + char for char in text if char not in string.punctuation and not char.isspace() + ).lower() + + +def retrieve_chunks( + query: SearchQuery, + document_index: DocumentIndex, + db_session: Session, + hybrid_alpha: float = HYBRID_ALPHA, # Only applicable to hybrid search + multilingual_expansion_str: str | None = MULTILINGUAL_QUERY_EXPANSION, + retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] + | None = None, +) -> list[InferenceChunk]: + """Returns a list of the best chunks from an initial keyword/semantic/ hybrid search.""" + # Don't do query expansion on complex queries, rephrasings likely would not work well + if not multilingual_expansion_str or "\n" in query.query or "\r" in query.query: + top_chunks = doc_index_retrieval( + query=query, + document_index=document_index, + db_session=db_session, + hybrid_alpha=hybrid_alpha, + ) + else: + simplified_queries = set() + run_queries: list[tuple[Callable, tuple]] = [] + + # Currently only uses query expansion on multilingual use cases + query_rephrases = multilingual_query_expansion( + query.query, multilingual_expansion_str + ) + # Just to be extra sure, add the original query. 
+ query_rephrases.append(query.query) + for rephrase in set(query_rephrases): + # Sometimes the model rephrases the query in the same language with minor changes + # Avoid doing an extra search with the minor changes as this biases the results + simplified_rephrase = _simplify_text(rephrase) + if simplified_rephrase in simplified_queries: + continue + simplified_queries.add(simplified_rephrase) + + q_copy = query.copy(update={"query": rephrase}, deep=True) + run_queries.append( + ( + doc_index_retrieval, + (q_copy, document_index, db_session, hybrid_alpha), + ) + ) + parallel_search_results = run_functions_tuples_in_parallel(run_queries) + top_chunks = combine_retrieval_results(parallel_search_results) + + if not top_chunks: + logger.info( + f"{query.search_type.value.capitalize()} search returned no results " + f"with filters: {query.filters}" + ) + return [] + + if retrieval_metrics_callback is not None: + chunk_metrics = [ + ChunkMetric( + document_id=chunk.document_id, + chunk_content_start=chunk.content[:MAX_METRICS_CONTENT], + first_link=chunk.source_links[0] if chunk.source_links else None, + score=chunk.score if chunk.score is not None else 0, + ) + for chunk in top_chunks + ] + retrieval_metrics_callback( + RetrievalMetricsContainer( + search_type=query.search_type, metrics=chunk_metrics + ) + ) + + return top_chunks + + +def combine_inference_chunks(inf_chunks: list[InferenceChunk]) -> LlmDoc: + if not inf_chunks: + raise ValueError("Cannot combine empty list of chunks") + + # Use the first link of the document + first_chunk = inf_chunks[0] + chunk_texts = [chunk.content for chunk in inf_chunks] + return LlmDoc( + document_id=first_chunk.document_id, + content="\n".join(chunk_texts), + semantic_identifier=first_chunk.semantic_identifier, + source_type=first_chunk.source_type, + metadata=first_chunk.metadata, + updated_at=first_chunk.updated_at, + link=first_chunk.source_links[0] if first_chunk.source_links else None, + ) + + +def inference_documents_from_ids( + doc_identifiers: list[tuple[str, int]], + document_index: DocumentIndex, +) -> list[LlmDoc]: + # Currently only fetches whole docs + doc_ids_set = set(doc_id for doc_id, chunk_id in doc_identifiers) + + # No need for ACL here because the doc ids were validated beforehand + filters = IndexFilters(access_control_list=None) + + functions_with_args: list[tuple[Callable, tuple]] = [ + (document_index.id_based_retrieval, (doc_id, None, filters)) + for doc_id in doc_ids_set + ] + + parallel_results = run_functions_tuples_in_parallel( + functions_with_args, allow_failures=True + ) + + # Any failures to retrieve would give a None, drop the Nones and empty lists + inference_chunks_sets = [res for res in parallel_results if res] + + return [combine_inference_chunks(chunk_set) for chunk_set in inference_chunks_sets] diff --git a/backend/danswer/search/search_runner.py b/backend/danswer/search/search_runner.py deleted file mode 100644 index 18bfa1a3c1..0000000000 --- a/backend/danswer/search/search_runner.py +++ /dev/null @@ -1,645 +0,0 @@ -import string -from collections.abc import Callable -from collections.abc import Iterator -from typing import cast - -import numpy -from nltk.corpus import stopwords # type:ignore -from nltk.stem import WordNetLemmatizer # type:ignore -from nltk.tokenize import word_tokenize # type:ignore -from sqlalchemy.orm import Session - -from danswer.chat.models import LlmDoc -from danswer.configs.app_configs import MODEL_SERVER_HOST -from danswer.configs.app_configs import MODEL_SERVER_PORT -from 
danswer.configs.chat_configs import HYBRID_ALPHA -from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION -from danswer.configs.chat_configs import NUM_RERANKED_RESULTS -from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX -from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN -from danswer.configs.model_configs import SIM_SCORE_RANGE_HIGH -from danswer.configs.model_configs import SIM_SCORE_RANGE_LOW -from danswer.db.embedding_model import get_current_db_embedding_model -from danswer.document_index.document_index_utils import ( - translate_boost_count_to_multiplier, -) -from danswer.document_index.interfaces import DocumentIndex -from danswer.indexing.models import InferenceChunk -from danswer.search.models import ChunkMetric -from danswer.search.models import IndexFilters -from danswer.search.models import MAX_METRICS_CONTENT -from danswer.search.models import RerankMetricsContainer -from danswer.search.models import RetrievalMetricsContainer -from danswer.search.models import SearchDoc -from danswer.search.models import SearchQuery -from danswer.search.models import SearchType -from danswer.search.search_nlp_models import CrossEncoderEnsembleModel -from danswer.search.search_nlp_models import EmbeddingModel -from danswer.search.search_nlp_models import EmbedTextType -from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_chunks -from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion -from danswer.utils.logger import setup_logger -from danswer.utils.threadpool_concurrency import FunctionCall -from danswer.utils.threadpool_concurrency import run_functions_in_parallel -from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel -from danswer.utils.timing import log_function_time - - -logger = setup_logger() - - -def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None: - top_links = [ - c.source_links[0] if c.source_links is not None else "No Link" for c in chunks - ] - logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}") - - -def lemmatize_text(text: str) -> list[str]: - lemmatizer = WordNetLemmatizer() - word_tokens = word_tokenize(text) - return [lemmatizer.lemmatize(word) for word in word_tokens] - - -def remove_stop_words_and_punctuation(text: str) -> list[str]: - stop_words = set(stopwords.words("english")) - word_tokens = word_tokenize(text) - text_trimmed = [ - word - for word in word_tokens - if (word.casefold() not in stop_words and word not in string.punctuation) - ] - return text_trimmed or word_tokens - - -def query_processing( - query: str, -) -> str: - query = " ".join(remove_stop_words_and_punctuation(query)) - query = " ".join(lemmatize_text(query)) - return query - - -def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc]: - search_docs = ( - [ - SearchDoc( - document_id=chunk.document_id, - chunk_ind=chunk.chunk_id, - semantic_identifier=chunk.semantic_identifier or "Unknown", - link=chunk.source_links.get(0) if chunk.source_links else None, - blurb=chunk.blurb, - source_type=chunk.source_type, - boost=chunk.boost, - hidden=chunk.hidden, - metadata=chunk.metadata, - score=chunk.score, - match_highlights=chunk.match_highlights, - updated_at=chunk.updated_at, - primary_owners=chunk.primary_owners, - secondary_owners=chunk.secondary_owners, - ) - for chunk in chunks - ] - if chunks - else [] - ) - return search_docs - - -def combine_retrieval_results( - chunk_sets: list[list[InferenceChunk]], -) 
-> list[InferenceChunk]: - all_chunks = [chunk for chunk_set in chunk_sets for chunk in chunk_set] - - unique_chunks: dict[tuple[str, int], InferenceChunk] = {} - for chunk in all_chunks: - key = (chunk.document_id, chunk.chunk_id) - if key not in unique_chunks: - unique_chunks[key] = chunk - continue - - stored_chunk_score = unique_chunks[key].score or 0 - this_chunk_score = chunk.score or 0 - if stored_chunk_score < this_chunk_score: - unique_chunks[key] = chunk - - sorted_chunks = sorted( - unique_chunks.values(), key=lambda x: x.score or 0, reverse=True - ) - - return sorted_chunks - - -@log_function_time(print_only=True) -def doc_index_retrieval( - query: SearchQuery, - document_index: DocumentIndex, - db_session: Session, - hybrid_alpha: float = HYBRID_ALPHA, -) -> list[InferenceChunk]: - if query.search_type == SearchType.KEYWORD: - top_chunks = document_index.keyword_retrieval( - query=query.query, - filters=query.filters, - time_decay_multiplier=query.recency_bias_multiplier, - num_to_retrieve=query.num_hits, - ) - else: - db_embedding_model = get_current_db_embedding_model(db_session) - - model = EmbeddingModel( - model_name=db_embedding_model.model_name, - query_prefix=db_embedding_model.query_prefix, - passage_prefix=db_embedding_model.passage_prefix, - normalize=db_embedding_model.normalize, - # The below are globally set, this flow always uses the indexing one - server_host=MODEL_SERVER_HOST, - server_port=MODEL_SERVER_PORT, - ) - - query_embedding = model.encode([query.query], text_type=EmbedTextType.QUERY)[0] - - if query.search_type == SearchType.SEMANTIC: - top_chunks = document_index.semantic_retrieval( - query=query.query, - query_embedding=query_embedding, - filters=query.filters, - time_decay_multiplier=query.recency_bias_multiplier, - num_to_retrieve=query.num_hits, - ) - - elif query.search_type == SearchType.HYBRID: - top_chunks = document_index.hybrid_retrieval( - query=query.query, - query_embedding=query_embedding, - filters=query.filters, - time_decay_multiplier=query.recency_bias_multiplier, - num_to_retrieve=query.num_hits, - offset=query.offset, - hybrid_alpha=hybrid_alpha, - ) - - else: - raise RuntimeError("Invalid Search Flow") - - return top_chunks - - -@log_function_time(print_only=True) -def semantic_reranking( - query: str, - chunks: list[InferenceChunk], - rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, - model_min: int = CROSS_ENCODER_RANGE_MIN, - model_max: int = CROSS_ENCODER_RANGE_MAX, -) -> tuple[list[InferenceChunk], list[int]]: - """Reranks chunks based on cross-encoder models. Additionally provides the original indices - of the chunks in their new sorted order. 
- - Note: this updates the chunks in place, it updates the chunk scores which came from retrieval - """ - cross_encoders = CrossEncoderEnsembleModel() - passages = [chunk.content for chunk in chunks] - sim_scores_floats = cross_encoders.predict(query=query, passages=passages) - - sim_scores = [numpy.array(scores) for scores in sim_scores_floats] - - raw_sim_scores = cast(numpy.ndarray, sum(sim_scores) / len(sim_scores)) - - cross_models_min = numpy.min(sim_scores) - - shifted_sim_scores = sum( - [enc_n_scores - cross_models_min for enc_n_scores in sim_scores] - ) / len(sim_scores) - - boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks] - recency_multiplier = [chunk.recency_bias for chunk in chunks] - boosted_sim_scores = shifted_sim_scores * boosts * recency_multiplier - normalized_b_s_scores = (boosted_sim_scores + cross_models_min - model_min) / ( - model_max - model_min - ) - orig_indices = [i for i in range(len(normalized_b_s_scores))] - scored_results = list( - zip(normalized_b_s_scores, raw_sim_scores, chunks, orig_indices) - ) - scored_results.sort(key=lambda x: x[0], reverse=True) - ranked_sim_scores, ranked_raw_scores, ranked_chunks, ranked_indices = zip( - *scored_results - ) - - logger.debug( - f"Reranked (Boosted + Time Weighted) similarity scores: {ranked_sim_scores}" - ) - - # Assign new chunk scores based on reranking - for ind, chunk in enumerate(ranked_chunks): - chunk.score = ranked_sim_scores[ind] - - if rerank_metrics_callback is not None: - chunk_metrics = [ - ChunkMetric( - document_id=chunk.document_id, - chunk_content_start=chunk.content[:MAX_METRICS_CONTENT], - first_link=chunk.source_links[0] if chunk.source_links else None, - score=chunk.score if chunk.score is not None else 0, - ) - for chunk in ranked_chunks - ] - - rerank_metrics_callback( - RerankMetricsContainer( - metrics=chunk_metrics, raw_similarity_scores=ranked_raw_scores # type: ignore - ) - ) - - return list(ranked_chunks), list(ranked_indices) - - -def apply_boost_legacy( - chunks: list[InferenceChunk], - norm_min: float = SIM_SCORE_RANGE_LOW, - norm_max: float = SIM_SCORE_RANGE_HIGH, -) -> list[InferenceChunk]: - scores = [chunk.score or 0 for chunk in chunks] - boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks] - - logger.debug(f"Raw similarity scores: {scores}") - - score_min = min(scores) - score_max = max(scores) - score_range = score_max - score_min - - if score_range != 0: - boosted_scores = [ - ((score - score_min) / score_range) * boost - for score, boost in zip(scores, boosts) - ] - unnormed_boosted_scores = [ - score * score_range + score_min for score in boosted_scores - ] - else: - unnormed_boosted_scores = [ - score * boost for score, boost in zip(scores, boosts) - ] - - norm_min = min(norm_min, min(scores)) - norm_max = max(norm_max, max(scores)) - # This should never be 0 unless user has done some weird/wrong settings - norm_range = norm_max - norm_min - - # For score display purposes - if norm_range != 0: - re_normed_scores = [ - ((score - norm_min) / norm_range) for score in unnormed_boosted_scores - ] - else: - re_normed_scores = unnormed_boosted_scores - - rescored_chunks = list(zip(re_normed_scores, chunks)) - rescored_chunks.sort(key=lambda x: x[0], reverse=True) - sorted_boosted_scores, boost_sorted_chunks = zip(*rescored_chunks) - - final_chunks = list(boost_sorted_chunks) - final_scores = list(sorted_boosted_scores) - for ind, chunk in enumerate(final_chunks): - chunk.score = final_scores[ind] - - logger.debug(f"Boost 
sorted similary scores: {list(final_scores)}") - - return final_chunks - - -def apply_boost( - chunks: list[InferenceChunk], - # Need the range of values to not be too spread out for applying boost - # therefore norm across only the top few results - norm_cutoff: int = NUM_RERANKED_RESULTS, - norm_min: float = SIM_SCORE_RANGE_LOW, - norm_max: float = SIM_SCORE_RANGE_HIGH, -) -> list[InferenceChunk]: - scores = [chunk.score or 0.0 for chunk in chunks] - logger.debug(f"Raw similarity scores: {scores}") - - boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks] - recency_multiplier = [chunk.recency_bias for chunk in chunks] - - norm_min = min(norm_min, min(scores[:norm_cutoff])) - norm_max = max(norm_max, max(scores[:norm_cutoff])) - # This should never be 0 unless user has done some weird/wrong settings - norm_range = norm_max - norm_min - - boosted_scores = [ - max(0, (score - norm_min) * boost * recency / norm_range) - for score, boost, recency in zip(scores, boosts, recency_multiplier) - ] - - rescored_chunks = list(zip(boosted_scores, chunks)) - rescored_chunks.sort(key=lambda x: x[0], reverse=True) - sorted_boosted_scores, boost_sorted_chunks = zip(*rescored_chunks) - - final_chunks = list(boost_sorted_chunks) - final_scores = list(sorted_boosted_scores) - for ind, chunk in enumerate(final_chunks): - chunk.score = final_scores[ind] - - logger.debug( - f"Boosted + Time Weighted sorted similarity scores: {list(final_scores)}" - ) - - return final_chunks - - -def _simplify_text(text: str) -> str: - return "".join( - char for char in text if char not in string.punctuation and not char.isspace() - ).lower() - - -def retrieve_chunks( - query: SearchQuery, - document_index: DocumentIndex, - db_session: Session, - hybrid_alpha: float = HYBRID_ALPHA, # Only applicable to hybrid search - multilingual_expansion_str: str | None = MULTILINGUAL_QUERY_EXPANSION, - retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] - | None = None, -) -> list[InferenceChunk]: - """Returns a list of the best chunks from an initial keyword/semantic/ hybrid search.""" - # Don't do query expansion on complex queries, rephrasings likely would not work well - if not multilingual_expansion_str or "\n" in query.query or "\r" in query.query: - top_chunks = doc_index_retrieval( - query=query, - document_index=document_index, - db_session=db_session, - hybrid_alpha=hybrid_alpha, - ) - else: - simplified_queries = set() - run_queries: list[tuple[Callable, tuple]] = [] - - # Currently only uses query expansion on multilingual use cases - query_rephrases = multilingual_query_expansion( - query.query, multilingual_expansion_str - ) - # Just to be extra sure, add the original query. 
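
The dedup condition just below relies on the _simplify_text helper defined above in this hunk; here is a runnable restatement (copied under another name purely for illustration) showing why rephrasings that differ only in casing, whitespace, or punctuation trigger a single search:

    import string

    def simplify(text: str) -> str:
        # strip punctuation and whitespace, lowercase the rest
        return "".join(
            ch for ch in text if ch not in string.punctuation and not ch.isspace()
        ).lower()

    # all three differences collapse to the same dedup key
    assert simplify("What is Danswer?") == simplify("  what is danswer ")
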
- query_rephrases.append(query.query) - for rephrase in set(query_rephrases): - # Sometimes the model rephrases the query in the same language with minor changes - # Avoid doing an extra search with the minor changes as this biases the results - simplified_rephrase = _simplify_text(rephrase) - if simplified_rephrase in simplified_queries: - continue - simplified_queries.add(simplified_rephrase) - - q_copy = query.copy(update={"query": rephrase}, deep=True) - run_queries.append( - ( - doc_index_retrieval, - (q_copy, document_index, db_session, hybrid_alpha), - ) - ) - parallel_search_results = run_functions_tuples_in_parallel(run_queries) - top_chunks = combine_retrieval_results(parallel_search_results) - - if not top_chunks: - logger.info( - f"{query.search_type.value.capitalize()} search returned no results " - f"with filters: {query.filters}" - ) - return [] - - if retrieval_metrics_callback is not None: - chunk_metrics = [ - ChunkMetric( - document_id=chunk.document_id, - chunk_content_start=chunk.content[:MAX_METRICS_CONTENT], - first_link=chunk.source_links[0] if chunk.source_links else None, - score=chunk.score if chunk.score is not None else 0, - ) - for chunk in top_chunks - ] - retrieval_metrics_callback( - RetrievalMetricsContainer( - search_type=query.search_type, metrics=chunk_metrics - ) - ) - - return top_chunks - - -def should_rerank(query: SearchQuery) -> bool: - # Don't re-rank for keyword search - return query.search_type != SearchType.KEYWORD and not query.skip_rerank - - -def should_apply_llm_based_relevance_filter(query: SearchQuery) -> bool: - return not query.skip_llm_chunk_filter - - -def rerank_chunks( - query: SearchQuery, - chunks_to_rerank: list[InferenceChunk], - rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, -) -> list[InferenceChunk]: - ranked_chunks, _ = semantic_reranking( - query=query.query, - chunks=chunks_to_rerank[: query.num_rerank], - rerank_metrics_callback=rerank_metrics_callback, - ) - lower_chunks = chunks_to_rerank[query.num_rerank :] - # Scores from rerank cannot be meaningfully combined with scores without rerank - for lower_chunk in lower_chunks: - lower_chunk.score = None - ranked_chunks.extend(lower_chunks) - return ranked_chunks - - -@log_function_time(print_only=True) -def filter_chunks( - query: SearchQuery, - chunks_to_filter: list[InferenceChunk], -) -> list[str]: - """Filters chunks based on whether the LLM thought they were relevant to the query. - - Returns a list of the unique chunk IDs that were marked as relevant""" - chunks_to_filter = chunks_to_filter[: query.max_llm_filter_chunks] - llm_chunk_selection = llm_batch_eval_chunks( - query=query.query, - chunk_contents=[chunk.content for chunk in chunks_to_filter], - ) - return [ - chunk.unique_id - for ind, chunk in enumerate(chunks_to_filter) - if llm_chunk_selection[ind] - ] - - -def full_chunk_search( - query: SearchQuery, - document_index: DocumentIndex, - db_session: Session, - hybrid_alpha: float = HYBRID_ALPHA, # Only applicable to hybrid search - multilingual_expansion_str: str | None = MULTILINGUAL_QUERY_EXPANSION, - retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] - | None = None, - rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, -) -> tuple[list[InferenceChunk], list[bool]]: - """A utility which provides an easier interface than `full_chunk_search_generator`. 
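
The multilingual fan-out earlier in this hunk hands (function, args) tuples to run_functions_tuples_in_parallel; its behavior is roughly the following standard-library sketch, ignoring the real helper's extra options such as failure handling (an assumption, since that helper's definition is not shown in this patch):

    from collections.abc import Callable
    from concurrent.futures import ThreadPoolExecutor

    def run_tuples_in_parallel(tasks: list[tuple[Callable, tuple]]) -> list:
        # one worker per task; results are returned in submission order
        with ThreadPoolExecutor(max_workers=max(len(tasks), 1)) as pool:
            futures = [pool.submit(fn, *args) for fn, args in tasks]
            return [future.result() for future in futures]

    assert run_tuples_in_parallel([(pow, (2, 8)), (pow, (3, 2))]) == [256, 9]
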
- Rather than returning the chunks and llm relevance filter results in two separate - yields, just returns them both at once.""" - search_generator = full_chunk_search_generator( - search_query=query, - document_index=document_index, - db_session=db_session, - hybrid_alpha=hybrid_alpha, - multilingual_expansion_str=multilingual_expansion_str, - retrieval_metrics_callback=retrieval_metrics_callback, - rerank_metrics_callback=rerank_metrics_callback, - ) - top_chunks = cast(list[InferenceChunk], next(search_generator)) - llm_chunk_selection = cast(list[bool], next(search_generator)) - return top_chunks, llm_chunk_selection - - -def empty_search_generator() -> Iterator[list[InferenceChunk] | list[bool]]: - yield cast(list[InferenceChunk], []) - yield cast(list[bool], []) - - -def full_chunk_search_generator( - search_query: SearchQuery, - document_index: DocumentIndex, - db_session: Session, - hybrid_alpha: float = HYBRID_ALPHA, # Only applicable to hybrid search - multilingual_expansion_str: str | None = MULTILINGUAL_QUERY_EXPANSION, - retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] - | None = None, - rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, -) -> Iterator[list[InferenceChunk] | list[bool]]: - """Always yields twice. Once with the selected chunks and once with the LLM relevance filter result. - If LLM filter results are turned off, returns a list of False - """ - chunks_yielded = False - - retrieved_chunks = retrieve_chunks( - query=search_query, - document_index=document_index, - db_session=db_session, - hybrid_alpha=hybrid_alpha, - multilingual_expansion_str=multilingual_expansion_str, - retrieval_metrics_callback=retrieval_metrics_callback, - ) - - if not retrieved_chunks: - yield cast(list[InferenceChunk], []) - yield cast(list[bool], []) - return - - post_processing_tasks: list[FunctionCall] = [] - - rerank_task_id = None - if should_rerank(search_query): - post_processing_tasks.append( - FunctionCall( - rerank_chunks, - ( - search_query, - retrieved_chunks, - rerank_metrics_callback, - ), - ) - ) - rerank_task_id = post_processing_tasks[-1].result_id - else: - final_chunks = retrieved_chunks - # NOTE: if we don't rerank, we can return the chunks immediately - # since we know this is the final order - _log_top_chunk_links(search_query.search_type.value, final_chunks) - yield final_chunks - chunks_yielded = True - - llm_filter_task_id = None - if should_apply_llm_based_relevance_filter(search_query): - post_processing_tasks.append( - FunctionCall( - filter_chunks, - (search_query, retrieved_chunks[: search_query.max_llm_filter_chunks]), - ) - ) - llm_filter_task_id = post_processing_tasks[-1].result_id - - post_processing_results = ( - run_functions_in_parallel(post_processing_tasks) - if post_processing_tasks - else {} - ) - reranked_chunks = cast( - list[InferenceChunk] | None, - post_processing_results.get(str(rerank_task_id)) if rerank_task_id else None, - ) - if reranked_chunks: - if chunks_yielded: - logger.error( - "Trying to yield re-ranked chunks, but chunks were already yielded. This should never happen." 
- ) - else: - _log_top_chunk_links(search_query.search_type.value, reranked_chunks) - yield reranked_chunks - - llm_chunk_selection = cast( - list[str] | None, - post_processing_results.get(str(llm_filter_task_id)) - if llm_filter_task_id - else None, - ) - if llm_chunk_selection is not None: - yield [ - chunk.unique_id in llm_chunk_selection - for chunk in reranked_chunks or retrieved_chunks - ] - else: - yield [False for _ in reranked_chunks or retrieved_chunks] - - -def combine_inference_chunks(inf_chunks: list[InferenceChunk]) -> LlmDoc: - if not inf_chunks: - raise ValueError("Cannot combine empty list of chunks") - - # Use the first link of the document - first_chunk = inf_chunks[0] - chunk_texts = [chunk.content for chunk in inf_chunks] - return LlmDoc( - document_id=first_chunk.document_id, - content="\n".join(chunk_texts), - semantic_identifier=first_chunk.semantic_identifier, - source_type=first_chunk.source_type, - metadata=first_chunk.metadata, - updated_at=first_chunk.updated_at, - link=first_chunk.source_links[0] if first_chunk.source_links else None, - ) - - -def inference_documents_from_ids( - doc_identifiers: list[tuple[str, int]], - document_index: DocumentIndex, -) -> list[LlmDoc]: - # Currently only fetches whole docs - doc_ids_set = set(doc_id for doc_id, chunk_id in doc_identifiers) - - # No need for ACL here because the doc ids were validated beforehand - filters = IndexFilters(access_control_list=None) - - functions_with_args: list[tuple[Callable, tuple]] = [ - (document_index.id_based_retrieval, (doc_id, None, filters)) - for doc_id in doc_ids_set - ] - - parallel_results = run_functions_tuples_in_parallel( - functions_with_args, allow_failures=True - ) - - # Any failures to retrieve would give a None, drop the Nones and empty lists - inference_chunks_sets = [res for res in parallel_results if res] - - return [combine_inference_chunks(chunk_set) for chunk_set in inference_chunks_sets] diff --git a/backend/danswer/search/utils.py b/backend/danswer/search/utils.py new file mode 100644 index 0000000000..4b01f70eb9 --- /dev/null +++ b/backend/danswer/search/utils.py @@ -0,0 +1,29 @@ +from danswer.indexing.models import InferenceChunk +from danswer.search.models import SearchDoc + + +def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc]: + search_docs = ( + [ + SearchDoc( + document_id=chunk.document_id, + chunk_ind=chunk.chunk_id, + semantic_identifier=chunk.semantic_identifier or "Unknown", + link=chunk.source_links.get(0) if chunk.source_links else None, + blurb=chunk.blurb, + source_type=chunk.source_type, + boost=chunk.boost, + hidden=chunk.hidden, + metadata=chunk.metadata, + score=chunk.score, + match_highlights=chunk.match_highlights, + updated_at=chunk.updated_at, + primary_owners=chunk.primary_owners, + secondary_owners=chunk.secondary_owners, + ) + for chunk in chunks + ] + if chunks + else [] + ) + return search_docs diff --git a/backend/danswer/server/documents/document.py b/backend/danswer/server/documents/document.py index ea080b0335..3abab33029 100644 --- a/backend/danswer/server/documents/document.py +++ b/backend/danswer/server/documents/document.py @@ -11,8 +11,8 @@ from danswer.document_index.factory import get_default_document_index from danswer.llm.utils import get_default_llm_token_encode from danswer.prompts.prompt_utils import build_doc_context_str -from danswer.search.access_filters import build_access_filters_for_user from danswer.search.models import IndexFilters +from danswer.search.preprocessing.access_filters 
import build_access_filters_for_user from danswer.server.documents.models import ChunkInfo from danswer.server.documents.models import DocumentInfo diff --git a/backend/danswer/server/features/persona/models.py b/backend/danswer/server/features/persona/models.py index a724ac5f3e..4cc80eec0e 100644 --- a/backend/danswer/server/features/persona/models.py +++ b/backend/danswer/server/features/persona/models.py @@ -4,7 +4,7 @@ from danswer.db.models import Persona from danswer.db.models import StarterMessage -from danswer.search.models import RecencyBiasSetting +from danswer.search.enums import RecencyBiasSetting from danswer.server.features.document_set.models import DocumentSet from danswer.server.features.prompt.models import PromptSnapshot diff --git a/backend/danswer/server/gpts/api.py b/backend/danswer/server/gpts/api.py index 9800032520..bfada9b559 100644 --- a/backend/danswer/server/gpts/api.py +++ b/backend/danswer/server/gpts/api.py @@ -6,13 +6,9 @@ from pydantic import BaseModel from sqlalchemy.orm import Session -from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.engine import get_session -from danswer.document_index.factory import get_default_document_index -from danswer.search.access_filters import build_access_filters_for_user -from danswer.search.models import IndexFilters -from danswer.search.models import SearchQuery -from danswer.search.search_runner import full_chunk_search +from danswer.search.models import SearchRequest +from danswer.search.pipeline import SearchPipeline from danswer.server.danswer_api.ingestion import api_key_dep from danswer.utils.logger import setup_logger @@ -70,27 +66,13 @@ def gpt_search( _: str | None = Depends(api_key_dep), db_session: Session = Depends(get_session), ) -> GptSearchResponse: - query = search_request.query - - user_acl_filters = build_access_filters_for_user(None, db_session) - final_filters = IndexFilters(access_control_list=user_acl_filters) - - search_query = SearchQuery( - query=query, - filters=final_filters, - recency_bias_multiplier=1.0, - skip_llm_chunk_filter=True, - ) - - embedding_model = get_current_db_embedding_model(db_session) - - document_index = get_default_document_index( - primary_index_name=embedding_model.index_name, secondary_index_name=None - ) - - top_chunks, __ = full_chunk_search( - query=search_query, document_index=document_index, db_session=db_session - ) + top_chunks = SearchPipeline( + search_request=SearchRequest( + query=search_request.query, + ), + user=None, + db_session=db_session, + ).reranked_docs return GptSearchResponse( matching_document_chunks=[ diff --git a/backend/danswer/server/query_and_chat/query_backend.py b/backend/danswer/server/query_and_chat/query_backend.py index 6d8529486f..5150eb9ce1 100644 --- a/backend/danswer/server/query_and_chat/query_backend.py +++ b/backend/danswer/server/query_and_chat/query_backend.py @@ -15,11 +15,11 @@ from danswer.document_index.vespa.index import VespaIndex from danswer.one_shot_answer.answer_question import stream_search_answer from danswer.one_shot_answer.models import DirectQARequest -from danswer.search.access_filters import build_access_filters_for_user -from danswer.search.danswer_helper import recommend_search_flow from danswer.search.models import IndexFilters from danswer.search.models import SearchDoc -from danswer.search.search_runner import chunks_to_search_docs +from danswer.search.preprocessing.access_filters import build_access_filters_for_user +from danswer.search.preprocessing.danswer_helper import 
recommend_search_flow +from danswer.search.utils import chunks_to_search_docs from danswer.secondary_llm_flows.query_validation import get_query_answerability from danswer.secondary_llm_flows.query_validation import stream_query_answerability from danswer.server.query_and_chat.models import AdminSearchRequest diff --git a/backend/tests/regression/search_quality/eval_search.py b/backend/tests/regression/search_quality/eval_search.py index 7cd3e6068c..d40ae13480 100644 --- a/backend/tests/regression/search_quality/eval_search.py +++ b/backend/tests/regression/search_quality/eval_search.py @@ -8,15 +8,12 @@ from sqlalchemy.orm import Session from danswer.chat.chat_utils import get_chunks_for_qa -from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.engine import get_sqlalchemy_engine -from danswer.document_index.factory import get_default_document_index from danswer.indexing.models import InferenceChunk -from danswer.search.models import IndexFilters from danswer.search.models import RerankMetricsContainer from danswer.search.models import RetrievalMetricsContainer -from danswer.search.models import SearchQuery -from danswer.search.search_runner import full_chunk_search +from danswer.search.models import SearchRequest +from danswer.search.pipeline import SearchPipeline from danswer.utils.callbacks import MetricsHander @@ -81,35 +78,22 @@ def get_search_results( RetrievalMetricsContainer | None, RerankMetricsContainer | None, ]: - filters = IndexFilters( - source_type=None, - document_set=None, - time_cutoff=None, - access_control_list=None, - ) - search_query = SearchQuery( - query=query, - filters=filters, - recency_bias_multiplier=1.0, - ) - retrieval_metrics = MetricsHander[RetrievalMetricsContainer]() rerank_metrics = MetricsHander[RerankMetricsContainer]() with Session(get_sqlalchemy_engine()) as db_session: - embedding_model = get_current_db_embedding_model(db_session) - - document_index = get_default_document_index( - primary_index_name=embedding_model.index_name, secondary_index_name=None - ) - - top_chunks, llm_chunk_selection = full_chunk_search( - query=search_query, - document_index=document_index, - db_session=db_session, - retrieval_metrics_callback=retrieval_metrics.record_metric, - rerank_metrics_callback=rerank_metrics.record_metric, - ) + search_pipeline = SearchPipeline( + search_request=SearchRequest( + query=query, + ), + user=None, + db_session=db_session, + retrieval_metrics_callback=retrieval_metrics.record_metric, + rerank_metrics_callback=rerank_metrics.record_metric, + ) + + top_chunks = search_pipeline.reranked_docs + llm_chunk_selection = search_pipeline.chunk_relevance_list llm_chunks_indices = get_chunks_for_qa( chunks=top_chunks, From f135ba9c0cb02e2141acced80463ed88358a1721 Mon Sep 17 00:00:00 2001 From: Weves Date: Mon, 25 Mar 2024 12:09:27 -0700 Subject: [PATCH 19/58] Rework LLM answering flow --- backend/danswer/chat/chat_utils.py | 456 +----------------- backend/danswer/chat/models.py | 11 +- backend/danswer/chat/process_message.py | 262 ++-------- .../slack/handlers/handle_message.py | 2 +- backend/danswer/llm/answering/answer.py | 176 +++++++ backend/danswer/llm/answering/doc_pruning.py | 205 ++++++++ backend/danswer/llm/answering/models.py | 77 +++ .../llm/answering/prompts/citations_prompt.py | 281 +++++++++++ .../llm/answering/prompts/quotes_prompt.py | 88 ++++ .../danswer/llm/answering/prompts/utils.py | 20 + .../stream_processing/citation_processing.py | 126 +++++ .../stream_processing/quotes_processing.py | 282 
+++++++++++ .../llm/answering/stream_processing/utils.py | 17 + backend/danswer/llm/utils.py | 7 +- .../one_shot_answer/answer_question.py | 228 ++------- backend/danswer/one_shot_answer/factory.py | 48 -- backend/danswer/one_shot_answer/interfaces.py | 26 - backend/danswer/one_shot_answer/qa_block.py | 313 ------------ backend/danswer/one_shot_answer/qa_utils.py | 261 ---------- backend/danswer/search/pipeline.py | 46 +- .../danswer/search/retrieval/search_runner.py | 2 + .../danswer/server/features/persona/api.py | 2 +- .../server/query_and_chat/chat_backend.py | 2 +- .../answer_quality/eval_direct_qa.py | 11 - .../regression/search_quality/eval_search.py | 12 +- .../unit/danswer/direct_qa/test_qa_utils.py | 8 +- 26 files changed, 1404 insertions(+), 1565 deletions(-) create mode 100644 backend/danswer/llm/answering/answer.py create mode 100644 backend/danswer/llm/answering/doc_pruning.py create mode 100644 backend/danswer/llm/answering/models.py create mode 100644 backend/danswer/llm/answering/prompts/citations_prompt.py create mode 100644 backend/danswer/llm/answering/prompts/quotes_prompt.py create mode 100644 backend/danswer/llm/answering/prompts/utils.py create mode 100644 backend/danswer/llm/answering/stream_processing/citation_processing.py create mode 100644 backend/danswer/llm/answering/stream_processing/quotes_processing.py create mode 100644 backend/danswer/llm/answering/stream_processing/utils.py delete mode 100644 backend/danswer/one_shot_answer/factory.py delete mode 100644 backend/danswer/one_shot_answer/interfaces.py delete mode 100644 backend/danswer/one_shot_answer/qa_block.py diff --git a/backend/danswer/chat/chat_utils.py b/backend/danswer/chat/chat_utils.py index fe97b0b392..ee2f582c95 100644 --- a/backend/danswer/chat/chat_utils.py +++ b/backend/danswer/chat/chat_utils.py @@ -1,97 +1,29 @@ import re -from collections.abc import Callable -from collections.abc import Iterator from collections.abc import Sequence -from functools import lru_cache -from typing import cast -from langchain.schema.messages import BaseMessage -from langchain.schema.messages import HumanMessage -from langchain.schema.messages import SystemMessage from sqlalchemy.orm import Session -from tiktoken.core import Encoding from danswer.chat.models import CitationInfo -from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import LlmDoc -from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION -from danswer.configs.chat_configs import STOP_STREAM_PAT -from danswer.configs.constants import IGNORE_FOR_QA -from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE -from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS from danswer.db.chat import get_chat_messages_by_session -from danswer.db.chat import get_default_prompt from danswer.db.models import ChatMessage -from danswer.db.models import Persona -from danswer.db.models import Prompt from danswer.indexing.models import InferenceChunk -from danswer.llm.utils import check_number_of_tokens -from danswer.llm.utils import get_default_llm_tokenizer -from danswer.llm.utils import get_default_llm_version -from danswer.llm.utils import get_max_input_tokens -from danswer.llm.utils import tokenizer_trim_content -from danswer.prompts.chat_prompts import ADDITIONAL_INFO -from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT -from danswer.prompts.chat_prompts import CHAT_USER_PROMPT -from danswer.prompts.chat_prompts import NO_CITATION_STATEMENT -from 
danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT -from danswer.prompts.constants import DEFAULT_IGNORE_STATEMENT -from danswer.prompts.constants import TRIPLE_BACKTICK -from danswer.prompts.prompt_utils import build_complete_context_str -from danswer.prompts.prompt_utils import build_task_prompt_reminders -from danswer.prompts.prompt_utils import get_current_llm_day_time -from danswer.prompts.token_counts import ADDITIONAL_INFO_TOKEN_CNT -from danswer.prompts.token_counts import ( - CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT, -) -from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT -from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT -from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT from danswer.utils.logger import setup_logger logger = setup_logger() -@lru_cache() -def build_chat_system_message( - prompt: Prompt, - context_exists: bool, - llm_tokenizer_encode_func: Callable, - citation_line: str = REQUIRE_CITATION_STATEMENT, - no_citation_line: str = NO_CITATION_STATEMENT, -) -> tuple[SystemMessage | None, int]: - system_prompt = prompt.system_prompt.strip() - if prompt.include_citations: - if context_exists: - system_prompt += citation_line - else: - system_prompt += no_citation_line - if prompt.datetime_aware: - if system_prompt: - system_prompt += ADDITIONAL_INFO.format( - datetime_info=get_current_llm_day_time() - ) - else: - system_prompt = get_current_llm_day_time() - - if not system_prompt: - return None, 0 - - token_count = len(llm_tokenizer_encode_func(system_prompt)) - system_msg = SystemMessage(content=system_prompt) - - return system_msg, token_count - - def llm_doc_from_inference_chunk(inf_chunk: InferenceChunk) -> LlmDoc: return LlmDoc( document_id=inf_chunk.document_id, content=inf_chunk.content, + blurb=inf_chunk.blurb, semantic_identifier=inf_chunk.semantic_identifier, source_type=inf_chunk.source_type, metadata=inf_chunk.metadata, updated_at=inf_chunk.updated_at, link=inf_chunk.source_links[0] if inf_chunk.source_links else None, + source_links=inf_chunk.source_links, ) @@ -108,170 +40,6 @@ def map_document_id_order( return order_mapping -def build_chat_user_message( - chat_message: ChatMessage, - prompt: Prompt, - context_docs: list[LlmDoc], - llm_tokenizer_encode_func: Callable, - all_doc_useful: bool, - user_prompt_template: str = CHAT_USER_PROMPT, - context_free_template: str = CHAT_USER_CONTEXT_FREE_PROMPT, - ignore_str: str = DEFAULT_IGNORE_STATEMENT, -) -> tuple[HumanMessage, int]: - user_query = chat_message.message - - if not context_docs: - # Simpler prompt for cases where there is no context - user_prompt = ( - context_free_template.format( - task_prompt=prompt.task_prompt, user_query=user_query - ) - if prompt.task_prompt - else user_query - ) - user_prompt = user_prompt.strip() - token_count = len(llm_tokenizer_encode_func(user_prompt)) - user_msg = HumanMessage(content=user_prompt) - return user_msg, token_count - - context_docs_str = build_complete_context_str( - cast(list[LlmDoc | InferenceChunk], context_docs) - ) - optional_ignore = "" if all_doc_useful else ignore_str - - task_prompt_with_reminder = build_task_prompt_reminders(prompt) - - user_prompt = user_prompt_template.format( - optional_ignore_statement=optional_ignore, - context_docs_str=context_docs_str, - task_prompt=task_prompt_with_reminder, - user_query=user_query, - ) - - user_prompt = user_prompt.strip() - token_count = len(llm_tokenizer_encode_func(user_prompt)) - user_msg = HumanMessage(content=user_prompt) - - return 
user_msg, token_count - - -def _get_usable_chunks( - chunks: list[InferenceChunk], token_limit: int -) -> list[InferenceChunk]: - total_token_count = 0 - usable_chunks = [] - for chunk in chunks: - chunk_token_count = check_number_of_tokens(chunk.content) - if total_token_count + chunk_token_count > token_limit: - break - - total_token_count += chunk_token_count - usable_chunks.append(chunk) - - # try and return at least one chunk if possible. This chunk will - # get truncated later on in the pipeline. This would only occur if - # the first chunk is larger than the token limit (usually due to character - # count -> token count mismatches caused by special characters / non-ascii - # languages) - if not usable_chunks and chunks: - usable_chunks = [chunks[0]] - - return usable_chunks - - -def get_usable_chunks( - chunks: list[InferenceChunk], - token_limit: int, - offset: int = 0, -) -> list[InferenceChunk]: - offset_into_chunks = 0 - usable_chunks: list[InferenceChunk] = [] - for _ in range(min(offset + 1, 1)): # go through this process at least once - if offset_into_chunks >= len(chunks) and offset_into_chunks > 0: - raise ValueError( - "Chunks offset too large, should not retry this many times" - ) - - usable_chunks = _get_usable_chunks( - chunks=chunks[offset_into_chunks:], token_limit=token_limit - ) - offset_into_chunks += len(usable_chunks) - - return usable_chunks - - -def get_chunks_for_qa( - chunks: list[InferenceChunk], - llm_chunk_selection: list[bool], - token_limit: int | None, - llm_tokenizer: Encoding | None = None, - batch_offset: int = 0, -) -> list[int]: - """ - Gives back indices of chunks to pass into the LLM for Q&A. - - Only selects chunks viable for Q&A, within the token limit, and prioritize those selected - by the LLM in a separate flow (this can be turned off) - - Note, the batch_offset calculation has to count the batches from the beginning each time as - there's no way to know which chunks were included in the prior batches without recounting atm, - this is somewhat slow as it requires tokenizing all the chunks again - """ - token_leeway = 50 - batch_index = 0 - latest_batch_indices: list[int] = [] - token_count = 0 - - # First iterate the LLM selected chunks, then iterate the rest if tokens remaining - for selection_target in [True, False]: - for ind, chunk in enumerate(chunks): - if llm_chunk_selection[ind] is not selection_target or chunk.metadata.get( - IGNORE_FOR_QA - ): - continue - - # We calculate it live in case the user uses a different LLM + tokenizer - chunk_token = check_number_of_tokens(chunk.content) - if chunk_token > DOC_EMBEDDING_CONTEXT_SIZE + token_leeway: - logger.warning( - "Found more tokens in chunk than expected, " - "likely mismatch between embedding and LLM tokenizers. Trimming content..." 
- ) - chunk.content = tokenizer_trim_content( - content=chunk.content, - desired_length=DOC_EMBEDDING_CONTEXT_SIZE, - tokenizer=llm_tokenizer or get_default_llm_tokenizer(), - ) - - # 50 for an approximate/slight overestimate for # tokens for metadata for the chunk - token_count += chunk_token + token_leeway - - # Always use at least 1 chunk - if ( - token_limit is None - or token_count <= token_limit - or not latest_batch_indices - ): - latest_batch_indices.append(ind) - current_chunk_unused = False - else: - current_chunk_unused = True - - if token_limit is not None and token_count >= token_limit: - if batch_index < batch_offset: - batch_index += 1 - if current_chunk_unused: - latest_batch_indices = [ind] - token_count = chunk_token - else: - latest_batch_indices = [] - token_count = 0 - else: - return latest_batch_indices - - return latest_batch_indices - - def create_chat_chain( chat_session_id: int, db_session: Session, @@ -341,157 +109,6 @@ def combine_message_chain( return "\n\n".join(message_strs) -_PER_MESSAGE_TOKEN_BUFFER = 7 - - -def find_last_index(lst: list[int], max_prompt_tokens: int) -> int: - """From the back, find the index of the last element to include - before the list exceeds the maximum""" - running_sum = 0 - - last_ind = 0 - for i in range(len(lst) - 1, -1, -1): - running_sum += lst[i] + _PER_MESSAGE_TOKEN_BUFFER - if running_sum > max_prompt_tokens: - last_ind = i + 1 - break - if last_ind >= len(lst): - raise ValueError("Last message alone is too large!") - return last_ind - - -def drop_messages_history_overflow( - system_msg: BaseMessage | None, - system_token_count: int, - history_msgs: list[BaseMessage], - history_token_counts: list[int], - final_msg: BaseMessage, - final_msg_token_count: int, - max_allowed_tokens: int, -) -> list[BaseMessage]: - """As message history grows, messages need to be dropped starting from the furthest in the past. 
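
A condensed, runnable form of the find_last_index budget scan defined just above (code this patch deletes): walk the token counts from the newest message backwards, adding the per-message buffer of 7 taken from _PER_MESSAGE_TOKEN_BUFFER in this hunk, and return the earliest index that still fits the budget.

    def last_fitting_index(token_counts: list[int], max_prompt_tokens: int) -> int:
        running_sum = 0
        last_ind = 0
        for i in range(len(token_counts) - 1, -1, -1):
            running_sum += token_counts[i] + 7  # _PER_MESSAGE_TOKEN_BUFFER
            if running_sum > max_prompt_tokens:
                last_ind = i + 1
                break
        if last_ind >= len(token_counts):
            raise ValueError("Last message alone is too large!")
        return last_ind

    # with a 100-token budget, only the newest message (70 + 7 buffer) fits
    assert last_fitting_index([50, 60, 70], max_prompt_tokens=100) == 2
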
- The System message should be kept if at all possible and the latest user input which is inserted in the - prompt template must be included""" - if len(history_msgs) != len(history_token_counts): - # This should never happen - raise ValueError("Need exactly 1 token count per message for tracking overflow") - - prompt: list[BaseMessage] = [] - - # Start dropping from the history if necessary - all_tokens = history_token_counts + [system_token_count, final_msg_token_count] - ind_prev_msg_start = find_last_index( - all_tokens, max_prompt_tokens=max_allowed_tokens - ) - - if system_msg and ind_prev_msg_start <= len(history_msgs): - prompt.append(system_msg) - - prompt.extend(history_msgs[ind_prev_msg_start:]) - - prompt.append(final_msg) - - return prompt - - -def in_code_block(llm_text: str) -> bool: - count = llm_text.count(TRIPLE_BACKTICK) - return count % 2 != 0 - - -def extract_citations_from_stream( - tokens: Iterator[str], - context_docs: list[LlmDoc], - doc_id_to_rank_map: dict[str, int], - stop_stream: str | None = STOP_STREAM_PAT, -) -> Iterator[DanswerAnswerPiece | CitationInfo]: - llm_out = "" - max_citation_num = len(context_docs) - curr_segment = "" - prepend_bracket = False - cited_inds = set() - hold = "" - for raw_token in tokens: - if stop_stream: - next_hold = hold + raw_token - - if stop_stream in next_hold: - break - - if next_hold == stop_stream[: len(next_hold)]: - hold = next_hold - continue - - token = next_hold - hold = "" - else: - token = raw_token - - # Special case of [1][ where ][ is a single token - # This is where the model attempts to do consecutive citations like [1][2] - if prepend_bracket: - curr_segment += "[" + curr_segment - prepend_bracket = False - - curr_segment += token - llm_out += token - - possible_citation_pattern = r"(\[\d*$)" # [1, [, etc - possible_citation_found = re.search(possible_citation_pattern, curr_segment) - - citation_pattern = r"\[(\d+)\]" # [1], [2] etc - citation_found = re.search(citation_pattern, curr_segment) - - if citation_found and not in_code_block(llm_out): - numerical_value = int(citation_found.group(1)) - if 1 <= numerical_value <= max_citation_num: - context_llm_doc = context_docs[ - numerical_value - 1 - ] # remove 1 index offset - - link = context_llm_doc.link - target_citation_num = doc_id_to_rank_map[context_llm_doc.document_id] - - # Use the citation number for the document's rank in - # the search (or selected docs) results - curr_segment = re.sub( - rf"\[{numerical_value}\]", f"[{target_citation_num}]", curr_segment - ) - - if target_citation_num not in cited_inds: - cited_inds.add(target_citation_num) - yield CitationInfo( - citation_num=target_citation_num, - document_id=context_llm_doc.document_id, - ) - - if link: - curr_segment = re.sub(r"\[", "[[", curr_segment, count=1) - curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1) - - # In case there's another open bracket like [1][, don't want to match this - possible_citation_found = None - - # if we see "[", but haven't seen the right side, hold back - this may be a - # citation that needs to be replaced with a link - if possible_citation_found: - continue - - # Special case with back to back citations [1][2] - if curr_segment and curr_segment[-1] == "[": - curr_segment = curr_segment[:-1] - prepend_bracket = True - - yield DanswerAnswerPiece(answer_piece=curr_segment) - curr_segment = "" - - if curr_segment: - if prepend_bracket: - yield DanswerAnswerPiece(answer_piece="[" + curr_segment) - else: - yield 
DanswerAnswerPiece(answer_piece=curr_segment) - - def reorganize_citations( answer: str, citations: list[CitationInfo] ) -> tuple[str, list[CitationInfo]]: @@ -547,72 +164,3 @@ def slack_link_format(match: re.Match) -> str: new_citation_info[citation.citation_num] = citation return new_answer, list(new_citation_info.values()) - - -def get_prompt_tokens(prompt: Prompt) -> int: - # Note: currently custom prompts do not allow datetime aware, only default prompts - return ( - check_number_of_tokens(prompt.system_prompt) - + check_number_of_tokens(prompt.task_prompt) - + CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT - + CITATION_STATEMENT_TOKEN_CNT - + CITATION_REMINDER_TOKEN_CNT - + (LANGUAGE_HINT_TOKEN_CNT if bool(MULTILINGUAL_QUERY_EXPANSION) else 0) - + (ADDITIONAL_INFO_TOKEN_CNT if prompt.datetime_aware else 0) - ) - - -# buffer just to be safe so that we don't overflow the token limit due to -# a small miscalculation -_MISC_BUFFER = 40 - - -def compute_max_document_tokens( - persona: Persona, - actual_user_input: str | None = None, - max_llm_token_override: int | None = None, -) -> int: - """Estimates the number of tokens available for context documents. Formula is roughly: - - ( - model_context_window - reserved_output_tokens - prompt_tokens - - (actual_user_input OR reserved_user_message_tokens) - buffer (just to be safe) - ) - - The actual_user_input is used at query time. If we are calculating this before knowing the exact input (e.g. - if we're trying to determine if the user should be able to select another document) then we just set an - arbitrary "upper bound". - """ - llm_name = get_default_llm_version()[0] - if persona.llm_model_version_override: - llm_name = persona.llm_model_version_override - - # if we can't find a number of tokens, just assume some common default - max_input_tokens = ( - max_llm_token_override - if max_llm_token_override - else get_max_input_tokens(model_name=llm_name) - ) - if persona.prompts: - # TODO this may not always be the first prompt - prompt_tokens = get_prompt_tokens(persona.prompts[0]) - else: - prompt_tokens = get_prompt_tokens(get_default_prompt()) - - user_input_tokens = ( - check_number_of_tokens(actual_user_input) - if actual_user_input is not None - else GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS - ) - - return max_input_tokens - prompt_tokens - user_input_tokens - _MISC_BUFFER - - -def compute_max_llm_input_tokens(persona: Persona) -> int: - """Maximum tokens allows in the input to the LLM (of any type).""" - llm_name = get_default_llm_version()[0] - if persona.llm_model_version_override: - llm_name = persona.llm_model_version_override - - input_tokens = get_max_input_tokens(model_name=llm_name) - return input_tokens - _MISC_BUFFER diff --git a/backend/danswer/chat/models.py b/backend/danswer/chat/models.py index 47d554de77..d2dd9f31fa 100644 --- a/backend/danswer/chat/models.py +++ b/backend/danswer/chat/models.py @@ -16,11 +16,13 @@ class LlmDoc(BaseModel): document_id: str content: str + blurb: str semantic_identifier: str source_type: DocumentSource metadata: dict[str, str | list[str]] updated_at: datetime | None link: str | None + source_links: dict[int, str] | None # First chunk of info for streaming QA @@ -100,9 +102,12 @@ class QAResponse(SearchResponse, DanswerAnswer): error_msg: str | None = None -AnswerQuestionStreamReturn = Iterator[ - DanswerAnswerPiece | DanswerQuotes | DanswerContexts | StreamingError -] +AnswerQuestionPossibleReturn = ( + DanswerAnswerPiece | DanswerQuotes | CitationInfo | DanswerContexts | 
StreamingError +) + + +AnswerQuestionStreamReturn = Iterator[AnswerQuestionPossibleReturn] class LLMMetricsContainer(BaseModel): diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 9cd78c963b..270afc67e2 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -5,16 +5,8 @@ from sqlalchemy.orm import Session -from danswer.chat.chat_utils import build_chat_system_message -from danswer.chat.chat_utils import build_chat_user_message -from danswer.chat.chat_utils import compute_max_document_tokens -from danswer.chat.chat_utils import compute_max_llm_input_tokens from danswer.chat.chat_utils import create_chat_chain -from danswer.chat.chat_utils import drop_messages_history_overflow -from danswer.chat.chat_utils import extract_citations_from_stream -from danswer.chat.chat_utils import get_chunks_for_qa from danswer.chat.chat_utils import llm_doc_from_inference_chunk -from danswer.chat.chat_utils import map_document_id_order from danswer.chat.models import CitationInfo from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import LlmDoc @@ -23,9 +15,7 @@ from danswer.chat.models import StreamingError from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT -from danswer.configs.constants import DISABLED_GEN_AI_MSG from danswer.configs.constants import MessageType -from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE from danswer.db.chat import create_db_search_doc from danswer.db.chat import create_new_chat_message from danswer.db.chat import get_chat_message @@ -37,21 +27,17 @@ from danswer.db.chat import translate_db_search_doc_to_server_search_doc from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.engine import get_session_context_manager -from danswer.db.models import ChatMessage -from danswer.db.models import Persona from danswer.db.models import SearchDoc as DbSearchDoc from danswer.db.models import User from danswer.document_index.factory import get_default_document_index -from danswer.indexing.models import InferenceChunk +from danswer.llm.answering.answer import Answer +from danswer.llm.answering.models import AnswerStyleConfig +from danswer.llm.answering.models import CitationConfig +from danswer.llm.answering.models import DocumentPruningConfig +from danswer.llm.answering.models import PreviousMessage from danswer.llm.exceptions import GenAIDisabledException from danswer.llm.factory import get_default_llm -from danswer.llm.interfaces import LLM from danswer.llm.utils import get_default_llm_tokenizer -from danswer.llm.utils import get_default_llm_version -from danswer.llm.utils import get_max_input_tokens -from danswer.llm.utils import tokenizer_trim_content -from danswer.llm.utils import translate_history_to_basemessages -from danswer.prompts.prompt_utils import build_doc_context_str from danswer.search.models import OptionalSearchSetting from danswer.search.models import SearchRequest from danswer.search.pipeline import SearchPipeline @@ -68,72 +54,6 @@ logger = setup_logger() -def generate_ai_chat_response( - query_message: ChatMessage, - history: list[ChatMessage], - persona: Persona, - context_docs: list[LlmDoc], - doc_id_to_rank_map: dict[str, int], - llm: LLM | None, - llm_tokenizer_encode_func: Callable, - all_doc_useful: bool, -) -> Iterator[DanswerAnswerPiece | CitationInfo | StreamingError]: - if llm is None: - try: - llm = 
get_default_llm() - except GenAIDisabledException: - # Not an error if it's a user configuration - yield DanswerAnswerPiece(answer_piece=DISABLED_GEN_AI_MSG) - return - - if query_message.prompt is None: - raise RuntimeError("No prompt received for generating Gen AI answer.") - - try: - context_exists = len(context_docs) > 0 - - system_message_or_none, system_tokens = build_chat_system_message( - prompt=query_message.prompt, - context_exists=context_exists, - llm_tokenizer_encode_func=llm_tokenizer_encode_func, - ) - - history_basemessages, history_token_counts = translate_history_to_basemessages( - history - ) - - # Be sure the context_docs passed to build_chat_user_message - # Is the same as passed in later for extracting citations - user_message, user_tokens = build_chat_user_message( - chat_message=query_message, - prompt=query_message.prompt, - context_docs=context_docs, - llm_tokenizer_encode_func=llm_tokenizer_encode_func, - all_doc_useful=all_doc_useful, - ) - - prompt = drop_messages_history_overflow( - system_msg=system_message_or_none, - system_token_count=system_tokens, - history_msgs=history_basemessages, - history_token_counts=history_token_counts, - final_msg=user_message, - final_msg_token_count=user_tokens, - max_allowed_tokens=compute_max_llm_input_tokens(persona), - ) - - # Good Debug/Breakpoint - tokens = llm.stream(prompt) - - yield from extract_citations_from_stream( - tokens, context_docs, doc_id_to_rank_map - ) - - except Exception as e: - logger.exception(f"LLM failed to produce valid chat message, error: {e}") - yield StreamingError(error=str(e)) - - def translate_citations( citations_list: list[CitationInfo], db_docs: list[DbSearchDoc] ) -> dict[int, int]: @@ -154,24 +74,26 @@ def translate_citations( return citation_to_saved_doc_id_map +ChatPacketStream = Iterator[ + StreamingError + | QADocsResponse + | LLMRelevanceFilterResponse + | ChatMessageDetail + | DanswerAnswerPiece + | CitationInfo +] + + def stream_chat_message_objects( new_msg_req: CreateChatMessageRequest, user: User | None, db_session: Session, # Needed to translate persona num_chunks to tokens to the LLM default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT, - default_chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE, # For flow with search, don't include as many chunks as possible since we need to leave space # for the chat history, for smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE, -) -> Iterator[ - StreamingError - | QADocsResponse - | LLMRelevanceFilterResponse - | ChatMessageDetail - | DanswerAnswerPiece - | CitationInfo -]: +) -> ChatPacketStream: """Streams in order: 1. [conditional] Retrieved documents if a search needs to be run 2. 
[conditional] LLM selected chunk indices if LLM chunk filtering is turned on @@ -277,10 +199,6 @@ def stream_chat_message_objects( query_message=final_msg, history=history_msgs, llm=llm ) - max_document_tokens = compute_max_document_tokens( - persona=persona, actual_user_input=message_text - ) - rephrased_query = None if reference_doc_ids: identifier_tuples = get_doc_query_identifiers_from_model( @@ -296,64 +214,8 @@ def stream_chat_message_objects( doc_identifiers=identifier_tuples, document_index=document_index, ) - - # truncate the last document if it exceeds the token limit - tokens_per_doc = [ - len( - llm_tokenizer_encode_func( - build_doc_context_str( - semantic_identifier=llm_doc.semantic_identifier, - source_type=llm_doc.source_type, - content=llm_doc.content, - metadata_dict=llm_doc.metadata, - updated_at=llm_doc.updated_at, - ind=ind, - ) - ) - ) - for ind, llm_doc in enumerate(llm_docs) - ] - final_doc_ind = None - total_tokens = 0 - for ind, tokens in enumerate(tokens_per_doc): - total_tokens += tokens - if total_tokens > max_document_tokens: - final_doc_ind = ind - break - if final_doc_ind is not None: - # only allow the final document to get truncated - # if more than that, then the user message is too long - if final_doc_ind != len(tokens_per_doc) - 1: - yield StreamingError( - error="LLM context window exceeded. Please de-select some documents or shorten your query." - ) - return - - final_doc_desired_length = tokens_per_doc[final_doc_ind] - ( - total_tokens - max_document_tokens - ) - # 75 tokens is a reasonable over-estimate of the metadata and title - final_doc_content_length = final_doc_desired_length - 75 - # this could occur if we only have space for the title / metadata - # not ideal, but it's the most reasonable thing to do - # NOTE: the frontend prevents documents from being selected if - # less than 75 tokens are available to try and avoid this situation - # from occuring in the first place - if final_doc_content_length <= 0: - logger.error( - f"Final doc ({llm_docs[final_doc_ind].semantic_identifier}) content " - "length is less than 0. Removing this doc from the final prompt." 
- ) - llm_docs.pop() - else: - llm_docs[final_doc_ind].content = tokenizer_trim_content( - content=llm_docs[final_doc_ind].content, - desired_length=final_doc_content_length, - tokenizer=llm_tokenizer, - ) - - doc_id_to_rank_map = map_document_id_order( - cast(list[InferenceChunk | LlmDoc], llm_docs) + document_pruning_config = DocumentPruningConfig( + is_manually_selected_docs=True ) # In case the search doc is deleted, just don't include it @@ -393,9 +255,6 @@ def stream_chat_message_objects( top_chunks = search_pipeline.reranked_docs top_docs = chunks_to_search_docs(top_chunks) - # Get ranking of the documents for citation purposes later - doc_id_to_rank_map = map_document_id_order(top_chunks) - reference_db_search_docs = [ create_db_search_doc(server_search_doc=top_doc, db_session=db_session) for top_doc in top_docs @@ -423,41 +282,21 @@ def stream_chat_message_objects( ) yield llm_relevance_filtering_response - # Prep chunks to pass to LLM - num_llm_chunks = ( - persona.num_chunks - if persona.num_chunks is not None - else default_num_chunks + document_pruning_config = DocumentPruningConfig( + max_chunks=int( + persona.num_chunks + if persona.num_chunks is not None + else default_num_chunks + ), + max_window_percentage=max_document_percentage, ) - llm_name = get_default_llm_version()[0] - if persona.llm_model_version_override: - llm_name = persona.llm_model_version_override - - llm_max_input_tokens = get_max_input_tokens(model_name=llm_name) - - llm_token_based_chunk_lim = max_document_percentage * llm_max_input_tokens - - chunk_token_limit = int( - min( - num_llm_chunks * default_chunk_size, - max_document_tokens, - llm_token_based_chunk_lim, - ) - ) - llm_chunks_indices = get_chunks_for_qa( - chunks=top_chunks, - llm_chunk_selection=search_pipeline.chunk_relevance_list, - token_limit=chunk_token_limit, - llm_tokenizer=llm_tokenizer, - ) - llm_chunks = [top_chunks[i] for i in llm_chunks_indices] - llm_docs = [llm_doc_from_inference_chunk(chunk) for chunk in llm_chunks] + llm_docs = [llm_doc_from_inference_chunk(chunk) for chunk in top_chunks] else: llm_docs = [] - doc_id_to_rank_map = {} reference_db_search_docs = None + document_pruning_config = DocumentPruningConfig() # Cannot determine these without the LLM step or breaking out early partial_response = partial( @@ -495,33 +334,24 @@ def stream_chat_message_objects( return # LLM prompt building, response capturing, etc. 
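
The Answer object that replaces the hand-rolled streaming below caches its processed stream, so llm_answer and citations can replay packets without a second LLM call. A stand-alone sketch of that stream-once, replay-later pattern (hypothetical names, not the patch's actual class):

    from collections.abc import Iterator

    class CachedStream:
        def __init__(self, pieces: list[str]) -> None:
            self._source = iter(pieces)
            self._cache: list[str] | None = None

        @property
        def stream(self) -> Iterator[str]:
            # second and later consumers are served from the cache
            if self._cache is not None:
                yield from self._cache
                return
            cache: list[str] = []
            for piece in self._source:
                cache.append(piece)
                yield piece
            self._cache = cache

        @property
        def text(self) -> str:
            return "".join(self.stream)

    s = CachedStream(["Hi", " there"])
    assert list(s.stream) == ["Hi", " there"]  # consumes the source, fills cache
    assert s.text == "Hi there"                # replayed from the cache
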
- response_packets = generate_ai_chat_response( - query_message=final_msg, - history=history_msgs, + answer = Answer( + question=final_msg.message, + docs=llm_docs, + answer_style_config=AnswerStyleConfig( + citation_config=CitationConfig( + all_docs_useful=reference_db_search_docs is not None + ), + document_pruning_config=document_pruning_config, + ), + prompt=final_msg.prompt, persona=persona, - context_docs=llm_docs, - doc_id_to_rank_map=doc_id_to_rank_map, - llm=llm, - llm_tokenizer_encode_func=llm_tokenizer_encode_func, - all_doc_useful=reference_doc_ids is not None, + message_history=[ + PreviousMessage.from_chat_message(msg) for msg in history_msgs + ], ) + # generator will not include quotes, so we can cast + yield from cast(ChatPacketStream, answer.processed_streamed_output) - # Capture outputs and errors - llm_output = "" - error: str | None = None - citations: list[CitationInfo] = [] - for packet in response_packets: - if isinstance(packet, DanswerAnswerPiece): - token = packet.answer_piece - if token: - llm_output += token - elif isinstance(packet, StreamingError): - error = packet.error - elif isinstance(packet, CitationInfo): - citations.append(packet) - continue - - yield packet except Exception as e: logger.exception(e) @@ -535,16 +365,16 @@ def stream_chat_message_objects( db_citations = None if reference_db_search_docs: db_citations = translate_citations( - citations_list=citations, + citations_list=answer.citations, db_docs=reference_db_search_docs, ) # Saving Gen AI answer and responding with message info gen_ai_response_message = partial_response( - message=llm_output, - token_count=len(llm_tokenizer_encode_func(llm_output)), + message=answer.llm_answer, + token_count=len(llm_tokenizer_encode_func(answer.llm_answer)), citations=db_citations, - error=error, + error=None, ) msg_detail_response = translate_db_message_to_chat_message_detail( diff --git a/backend/danswer/danswerbot/slack/handlers/handle_message.py b/backend/danswer/danswerbot/slack/handlers/handle_message.py index 1e065dd1da..b3fdb79c88 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_message.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_message.py @@ -12,7 +12,6 @@ from slack_sdk.models.blocks import DividerBlock from sqlalchemy.orm import Session -from danswer.chat.chat_utils import compute_max_document_tokens from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER @@ -39,6 +38,7 @@ from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import SlackBotConfig from danswer.db.models import SlackBotResponseType +from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens from danswer.llm.utils import check_number_of_tokens from danswer.llm.utils import get_default_llm_version from danswer.llm.utils import get_max_input_tokens diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py new file mode 100644 index 0000000000..76d399d8bd --- /dev/null +++ b/backend/danswer/llm/answering/answer.py @@ -0,0 +1,176 @@ +from collections.abc import Iterator +from typing import cast + +from langchain.schema.messages import BaseMessage + +from danswer.chat.models import AnswerQuestionPossibleReturn +from danswer.chat.models import AnswerQuestionStreamReturn +from danswer.chat.models import CitationInfo +from danswer.chat.models 
import DanswerAnswerPiece +from danswer.chat.models import LlmDoc +from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE +from danswer.configs.chat_configs import QA_TIMEOUT +from danswer.db.models import Persona +from danswer.db.models import Prompt +from danswer.llm.answering.doc_pruning import prune_documents +from danswer.llm.answering.models import AnswerStyleConfig +from danswer.llm.answering.models import PreviousMessage +from danswer.llm.answering.models import StreamProcessor +from danswer.llm.answering.prompts.citations_prompt import build_citations_prompt +from danswer.llm.answering.prompts.quotes_prompt import ( + build_quotes_prompt, +) +from danswer.llm.answering.stream_processing.citation_processing import ( + build_citation_processor, +) +from danswer.llm.answering.stream_processing.quotes_processing import ( + build_quotes_processor, +) +from danswer.llm.factory import get_default_llm +from danswer.llm.utils import get_default_llm_tokenizer + + +def _get_stream_processor( + docs: list[LlmDoc], answer_style_configs: AnswerStyleConfig +) -> StreamProcessor: + if answer_style_configs.citation_config: + return build_citation_processor( + context_docs=docs, + ) + if answer_style_configs.quotes_config: + return build_quotes_processor( + context_docs=docs, is_json_prompt=not (QA_PROMPT_OVERRIDE == "weak") + ) + + raise RuntimeError("Not implemented yet") + + +class Answer: + def __init__( + self, + question: str, + docs: list[LlmDoc], + answer_style_config: AnswerStyleConfig, + prompt: Prompt, + persona: Persona, + # must be the same length as `docs`. If None, all docs are considered "relevant" + doc_relevance_list: list[bool] | None = None, + message_history: list[PreviousMessage] | None = None, + single_message_history: str | None = None, + timeout: int = QA_TIMEOUT, + ) -> None: + if single_message_history and message_history: + raise ValueError( + "Cannot provide both `message_history` and `single_message_history`" + ) + + self.question = question + self.docs = docs + self.doc_relevance_list = doc_relevance_list + self.message_history = message_history or [] + # used for QA flow where we only want to send a single message + self.single_message_history = single_message_history + + self.answer_style_config = answer_style_config + + self.llm = get_default_llm( + gen_ai_model_version_override=persona.llm_model_version_override, + timeout=timeout, + ) + self.llm_tokenizer = get_default_llm_tokenizer() + + self.prompt = prompt + self.persona = persona + + self.process_stream_fn = _get_stream_processor(docs, answer_style_config) + + self._final_prompt: list[BaseMessage] | None = None + + self._pruned_docs: list[LlmDoc] | None = None + + self._streamed_output: list[str] | None = None + self._processed_stream: list[AnswerQuestionPossibleReturn] | None = None + + @property + def pruned_docs(self) -> list[LlmDoc]: + if self._pruned_docs is not None: + return self._pruned_docs + + self._pruned_docs = prune_documents( + docs=self.docs, + doc_relevance_list=self.doc_relevance_list, + persona=self.persona, + question=self.question, + document_pruning_config=self.answer_style_config.document_pruning_config, + ) + return self._pruned_docs + + @property + def final_prompt(self) -> list[BaseMessage]: + if self._final_prompt is not None: + return self._final_prompt + + if self.answer_style_config.citation_config: + self._final_prompt = build_citations_prompt( + question=self.question, + message_history=self.message_history, + persona=self.persona, + prompt=self.prompt, + 
context_docs=self.pruned_docs, + all_doc_useful=self.answer_style_config.citation_config.all_docs_useful, + llm_tokenizer_encode_func=self.llm_tokenizer.encode, + history_message=self.single_message_history or "", + ) + elif self.answer_style_config.quotes_config: + self._final_prompt = build_quotes_prompt( + question=self.question, + context_docs=self.pruned_docs, + history_str=self.single_message_history or "", + prompt=self.prompt, + ) + + return cast(list[BaseMessage], self._final_prompt) + + @property + def raw_streamed_output(self) -> Iterator[str]: + if self._streamed_output is not None: + yield from self._streamed_output + return + + streamed_output = [] + for message in self.llm.stream(self.final_prompt): + streamed_output.append(message) + yield message + + self._streamed_output = streamed_output + + @property + def processed_streamed_output(self) -> AnswerQuestionStreamReturn: + if self._processed_stream is not None: + yield from self._processed_stream + return + + processed_stream = [] + for processed_packet in self.process_stream_fn(self.raw_streamed_output): + processed_stream.append(processed_packet) + yield processed_packet + + self._processed_stream = processed_stream + + @property + def llm_answer(self) -> str: + answer = "" + for packet in self.processed_streamed_output: + if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece: + answer += packet.answer_piece + + return answer + + @property + def citations(self) -> list[CitationInfo]: + citations: list[CitationInfo] = [] + for packet in self.processed_streamed_output: + if isinstance(packet, CitationInfo): + citations.append(packet) + + return citations diff --git a/backend/danswer/llm/answering/doc_pruning.py b/backend/danswer/llm/answering/doc_pruning.py new file mode 100644 index 0000000000..29c913673d --- /dev/null +++ b/backend/danswer/llm/answering/doc_pruning.py @@ -0,0 +1,205 @@ +from copy import deepcopy +from typing import TypeVar + +from danswer.chat.models import ( + LlmDoc, +) +from danswer.configs.constants import IGNORE_FOR_QA +from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE +from danswer.db.models import Persona +from danswer.indexing.models import InferenceChunk +from danswer.llm.answering.models import DocumentPruningConfig +from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens +from danswer.llm.utils import get_default_llm_tokenizer +from danswer.llm.utils import tokenizer_trim_content +from danswer.prompts.prompt_utils import build_doc_context_str +from danswer.utils.logger import setup_logger + + +logger = setup_logger() + +T = TypeVar("T", bound=LlmDoc | InferenceChunk) + +_METADATA_TOKEN_ESTIMATE = 75 + + +class PruningError(Exception): + pass + + +def _compute_limit( + persona: Persona, + question: str, + max_chunks: int | None, + max_window_percentage: float | None, + max_tokens: int | None, +) -> int: + llm_max_document_tokens = compute_max_document_tokens( + persona=persona, actual_user_input=question + ) + + window_percentage_based_limit = ( + max_window_percentage * llm_max_document_tokens + if max_window_percentage + else None + ) + chunk_count_based_limit = ( + max_chunks * DOC_EMBEDDING_CONTEXT_SIZE if max_chunks else None + ) + + limit_options = [ + lim + for lim in [ + window_percentage_based_limit, + chunk_count_based_limit, + max_tokens, + llm_max_document_tokens, + ] + if lim + ] + return int(min(limit_options)) + + +def reorder_docs( + docs: list[T], + doc_relevance_list: list[bool] | None, +) -> list[T]: + if 
doc_relevance_list is None:
+        return docs
+
+    reordered_docs: list[T] = []
+    if doc_relevance_list is not None:
+        for selection_target in [True, False]:
+            for doc, is_relevant in zip(docs, doc_relevance_list):
+                if is_relevant == selection_target:
+                    reordered_docs.append(doc)
+    return reordered_docs
+
+
+def _remove_docs_to_ignore(docs: list[LlmDoc]) -> list[LlmDoc]:
+    return [doc for doc in docs if not doc.metadata.get(IGNORE_FOR_QA)]
+
+
+def _apply_pruning(
+    docs: list[LlmDoc],
+    doc_relevance_list: list[bool] | None,
+    token_limit: int,
+    is_manually_selected_docs: bool,
+) -> list[LlmDoc]:
+    llm_tokenizer = get_default_llm_tokenizer()
+    docs = deepcopy(docs)  # don't modify in place
+
+    # re-order docs with all the "relevant" docs at the front
+    docs = reorder_docs(docs=docs, doc_relevance_list=doc_relevance_list)
+    # remove docs that are explicitly marked as not for QA
+    docs = _remove_docs_to_ignore(docs=docs)
+
+    tokens_per_doc: list[int] = []
+    final_doc_ind = None
+    total_tokens = 0
+    for ind, llm_doc in enumerate(docs):
+        doc_tokens = len(
+            llm_tokenizer.encode(
+                build_doc_context_str(
+                    semantic_identifier=llm_doc.semantic_identifier,
+                    source_type=llm_doc.source_type,
+                    content=llm_doc.content,
+                    metadata_dict=llm_doc.metadata,
+                    updated_at=llm_doc.updated_at,
+                    ind=ind,
+                )
+            )
+        )
+        # if chunks, truncate chunks that are way too long
+        # this can happen if the embedding model tokenizer is different
+        # than the LLM tokenizer
+        if (
+            not is_manually_selected_docs
+            and doc_tokens > DOC_EMBEDDING_CONTEXT_SIZE + _METADATA_TOKEN_ESTIMATE
+        ):
+            logger.warning(
+                "Found more tokens in chunk than expected, "
+                "likely mismatch between embedding and LLM tokenizers. Trimming content..."
+            )
+            llm_doc.content = tokenizer_trim_content(
+                content=llm_doc.content,
+                desired_length=DOC_EMBEDDING_CONTEXT_SIZE,
+                tokenizer=llm_tokenizer,
+            )
+            doc_tokens = DOC_EMBEDDING_CONTEXT_SIZE
+        tokens_per_doc.append(doc_tokens)
+        total_tokens += doc_tokens
+        if total_tokens > token_limit:
+            final_doc_ind = ind
+            break
+
+    if final_doc_ind is not None:
+        if is_manually_selected_docs:
+            # for document selection, only allow the final document to get truncated
+            # if more than that, then the user message is too long
+            if final_doc_ind != len(docs) - 1:
+                raise PruningError(
+                    "LLM context window exceeded. Please de-select some documents or shorten your query."
+                )
+
+            final_doc_desired_length = tokens_per_doc[final_doc_ind] - (
+                total_tokens - token_limit
+            )
+            final_doc_content_length = (
+                final_doc_desired_length - _METADATA_TOKEN_ESTIMATE
+            )
+            # this could occur if we only have space for the title / metadata
+            # not ideal, but it's the most reasonable thing to do
+            # NOTE: the frontend prevents documents from being selected if
+            # less than 75 tokens are available to try and avoid this situation
+            # from occurring in the first place
+            if final_doc_content_length <= 0:
+                logger.error(
+                    f"Final doc ({docs[final_doc_ind].semantic_identifier}) content "
+                    "length is less than 0. Removing this doc from the final prompt."
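+                    # (Editor's illustration, not part of this patch: a worked example
+                    # of the token accounting above, with made-up numbers. Suppose
+                    # token_limit=100 and three selected docs costing [60, 30, 25]
+                    # tokens. The running total crosses the limit at index 2
+                    # (60 + 30 + 25 = 115 > 100), so final_doc_ind = 2 and the last
+                    # doc may keep 25 - (115 - 100) = 10 tokens. Subtracting the
+                    # ~75-token metadata estimate leaves <= 0 tokens of content,
+                    # so this branch drops the doc instead of trimming it.)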
+ ) + docs.pop() + else: + docs[final_doc_ind].content = tokenizer_trim_content( + content=docs[final_doc_ind].content, + desired_length=final_doc_content_length, + tokenizer=llm_tokenizer, + ) + else: + # for regular search, don't truncate the final document unless it's the only one + if final_doc_ind != 0: + docs = docs[:final_doc_ind] + else: + docs[0].content = tokenizer_trim_content( + content=docs[0].content, + desired_length=token_limit - _METADATA_TOKEN_ESTIMATE, + tokenizer=llm_tokenizer, + ) + docs = [docs[0]] + + return docs + + +def prune_documents( + docs: list[LlmDoc], + doc_relevance_list: list[bool] | None, + persona: Persona, + question: str, + document_pruning_config: DocumentPruningConfig, +) -> list[LlmDoc]: + if doc_relevance_list is not None: + assert len(docs) == len(doc_relevance_list) + + doc_token_limit = _compute_limit( + persona=persona, + question=question, + max_chunks=document_pruning_config.max_chunks, + max_window_percentage=document_pruning_config.max_window_percentage, + max_tokens=document_pruning_config.max_tokens, + ) + return _apply_pruning( + docs=docs, + doc_relevance_list=doc_relevance_list, + token_limit=doc_token_limit, + is_manually_selected_docs=document_pruning_config.is_manually_selected_docs, + ) diff --git a/backend/danswer/llm/answering/models.py b/backend/danswer/llm/answering/models.py new file mode 100644 index 0000000000..360535ac80 --- /dev/null +++ b/backend/danswer/llm/answering/models.py @@ -0,0 +1,77 @@ +from collections.abc import Callable +from collections.abc import Iterator +from typing import Any +from typing import TYPE_CHECKING + +from pydantic import BaseModel +from pydantic import Field +from pydantic import root_validator + +from danswer.chat.models import AnswerQuestionStreamReturn +from danswer.configs.constants import MessageType + +if TYPE_CHECKING: + from danswer.db.models import ChatMessage + + +StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn] + + +class PreviousMessage(BaseModel): + """Simplified version of `ChatMessage`""" + + message: str + token_count: int + message_type: MessageType + + @classmethod + def from_chat_message(cls, chat_message: "ChatMessage") -> "PreviousMessage": + return cls( + message=chat_message.message, + token_count=chat_message.token_count, + message_type=chat_message.message_type, + ) + + +class DocumentPruningConfig(BaseModel): + max_chunks: int | None = None + max_window_percentage: float | None = None + max_tokens: int | None = None + # different pruning behavior is expected when the + # user manually selects documents they want to chat with + # e.g. 
we don't want to truncate each document to be no more + # than one chunk long + is_manually_selected_docs: bool = False + + +class CitationConfig(BaseModel): + all_docs_useful: bool = False + + +class QuotesConfig(BaseModel): + pass + + +class AnswerStyleConfig(BaseModel): + citation_config: CitationConfig | None = None + quotes_config: QuotesConfig | None = None + document_pruning_config: DocumentPruningConfig = Field( + default_factory=DocumentPruningConfig + ) + + @root_validator + def check_quotes_and_citation(cls, values: dict[str, Any]) -> dict[str, Any]: + citation_config = values.get("citation_config") + quotes_config = values.get("quotes_config") + + if citation_config is None and quotes_config is None: + raise ValueError( + "One of `citation_config` or `quotes_config` must be provided" + ) + + if citation_config is not None and quotes_config is not None: + raise ValueError( + "Only one of `citation_config` or `quotes_config` must be provided" + ) + + return values diff --git a/backend/danswer/llm/answering/prompts/citations_prompt.py b/backend/danswer/llm/answering/prompts/citations_prompt.py new file mode 100644 index 0000000000..61c42c19c7 --- /dev/null +++ b/backend/danswer/llm/answering/prompts/citations_prompt.py @@ -0,0 +1,281 @@ +from collections.abc import Callable +from functools import lru_cache +from typing import cast + +from langchain.schema.messages import BaseMessage +from langchain.schema.messages import HumanMessage +from langchain.schema.messages import SystemMessage + +from danswer.chat.models import LlmDoc +from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION +from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS +from danswer.db.chat import get_default_prompt +from danswer.db.models import Persona +from danswer.db.models import Prompt +from danswer.indexing.models import InferenceChunk +from danswer.llm.answering.models import PreviousMessage +from danswer.llm.utils import check_number_of_tokens +from danswer.llm.utils import get_default_llm_tokenizer +from danswer.llm.utils import get_default_llm_version +from danswer.llm.utils import get_max_input_tokens +from danswer.llm.utils import translate_history_to_basemessages +from danswer.prompts.chat_prompts import ADDITIONAL_INFO +from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT +from danswer.prompts.chat_prompts import NO_CITATION_STATEMENT +from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT +from danswer.prompts.constants import DEFAULT_IGNORE_STATEMENT +from danswer.prompts.direct_qa_prompts import ( + CITATIONS_PROMPT, +) +from danswer.prompts.prompt_utils import build_complete_context_str +from danswer.prompts.prompt_utils import build_task_prompt_reminders +from danswer.prompts.prompt_utils import get_current_llm_day_time +from danswer.prompts.token_counts import ADDITIONAL_INFO_TOKEN_CNT +from danswer.prompts.token_counts import ( + CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT, +) +from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT +from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT +from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT + + +_PER_MESSAGE_TOKEN_BUFFER = 7 + + +def find_last_index(lst: list[int], max_prompt_tokens: int) -> int: + """From the back, find the index of the last element to include + before the list exceeds the maximum""" + running_sum = 0 + + last_ind = 0 + for i in range(len(lst) - 1, -1, -1): + running_sum += lst[i] + 
_PER_MESSAGE_TOKEN_BUFFER + if running_sum > max_prompt_tokens: + last_ind = i + 1 + break + if last_ind >= len(lst): + raise ValueError("Last message alone is too large!") + return last_ind + + +def drop_messages_history_overflow( + system_msg: BaseMessage | None, + system_token_count: int, + history_msgs: list[BaseMessage], + history_token_counts: list[int], + final_msg: BaseMessage, + final_msg_token_count: int, + max_allowed_tokens: int, +) -> list[BaseMessage]: + """As message history grows, messages need to be dropped starting from the furthest in the past. + The System message should be kept if at all possible and the latest user input which is inserted in the + prompt template must be included""" + if len(history_msgs) != len(history_token_counts): + # This should never happen + raise ValueError("Need exactly 1 token count per message for tracking overflow") + + prompt: list[BaseMessage] = [] + + # Start dropping from the history if necessary + all_tokens = history_token_counts + [system_token_count, final_msg_token_count] + ind_prev_msg_start = find_last_index( + all_tokens, max_prompt_tokens=max_allowed_tokens + ) + + if system_msg and ind_prev_msg_start <= len(history_msgs): + prompt.append(system_msg) + + prompt.extend(history_msgs[ind_prev_msg_start:]) + + prompt.append(final_msg) + + return prompt + + +def get_prompt_tokens(prompt: Prompt) -> int: + # Note: currently custom prompts do not allow datetime aware, only default prompts + return ( + check_number_of_tokens(prompt.system_prompt) + + check_number_of_tokens(prompt.task_prompt) + + CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT + + CITATION_STATEMENT_TOKEN_CNT + + CITATION_REMINDER_TOKEN_CNT + + (LANGUAGE_HINT_TOKEN_CNT if bool(MULTILINGUAL_QUERY_EXPANSION) else 0) + + (ADDITIONAL_INFO_TOKEN_CNT if prompt.datetime_aware else 0) + ) + + +# buffer just to be safe so that we don't overflow the token limit due to +# a small miscalculation +_MISC_BUFFER = 40 + + +def compute_max_document_tokens( + persona: Persona, + actual_user_input: str | None = None, + max_llm_token_override: int | None = None, +) -> int: + """Estimates the number of tokens available for context documents. Formula is roughly: + + ( + model_context_window - reserved_output_tokens - prompt_tokens + - (actual_user_input OR reserved_user_message_tokens) - buffer (just to be safe) + ) + + The actual_user_input is used at query time. If we are calculating this before knowing the exact input (e.g. + if we're trying to determine if the user should be able to select another document) then we just set an + arbitrary "upper bound". 
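+
+    For example (hypothetical numbers, for illustration only): with a model whose
+    max input is 3072 tokens, a prompt costing 200 tokens, a reserved user-message
+    budget of 300 tokens, and the 40-token safety buffer, the documents get roughly
+    3072 - 200 - 300 - 40 = 2532 tokens.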
+    """
+    llm_name = get_default_llm_version()[0]
+    if persona.llm_model_version_override:
+        llm_name = persona.llm_model_version_override
+
+    # if we can't find a number of tokens, just assume some common default
+    max_input_tokens = (
+        max_llm_token_override
+        if max_llm_token_override
+        else get_max_input_tokens(model_name=llm_name)
+    )
+    if persona.prompts:
+        # TODO this may not always be the first prompt
+        prompt_tokens = get_prompt_tokens(persona.prompts[0])
+    else:
+        prompt_tokens = get_prompt_tokens(get_default_prompt())
+
+    user_input_tokens = (
+        check_number_of_tokens(actual_user_input)
+        if actual_user_input is not None
+        else GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
+    )
+
+    return max_input_tokens - prompt_tokens - user_input_tokens - _MISC_BUFFER
+
+
+def compute_max_llm_input_tokens(persona: Persona) -> int:
+    """Maximum tokens allowed in the input to the LLM (of any type)."""
+    llm_name = get_default_llm_version()[0]
+    if persona.llm_model_version_override:
+        llm_name = persona.llm_model_version_override
+
+    input_tokens = get_max_input_tokens(model_name=llm_name)
+    return input_tokens - _MISC_BUFFER
+
+
+@lru_cache()
+def build_system_message(
+    prompt: Prompt,
+    context_exists: bool,
+    llm_tokenizer_encode_func: Callable,
+    citation_line: str = REQUIRE_CITATION_STATEMENT,
+    no_citation_line: str = NO_CITATION_STATEMENT,
+) -> tuple[SystemMessage | None, int]:
+    system_prompt = prompt.system_prompt.strip()
+    if prompt.include_citations:
+        if context_exists:
+            system_prompt += citation_line
+        else:
+            system_prompt += no_citation_line
+    if prompt.datetime_aware:
+        if system_prompt:
+            system_prompt += ADDITIONAL_INFO.format(
+                datetime_info=get_current_llm_day_time()
+            )
+        else:
+            system_prompt = get_current_llm_day_time()
+
+    if not system_prompt:
+        return None, 0
+
+    token_count = len(llm_tokenizer_encode_func(system_prompt))
+    system_msg = SystemMessage(content=system_prompt)
+
+    return system_msg, token_count
+
+
+def build_user_message(
+    question: str,
+    prompt: Prompt,
+    context_docs: list[LlmDoc] | list[InferenceChunk],
+    all_doc_useful: bool,
+    history_message: str,
+) -> tuple[HumanMessage, int]:
+    llm_tokenizer = get_default_llm_tokenizer()
+    llm_tokenizer_encode_func = cast(Callable[[str], list[int]], llm_tokenizer.encode)
+
+    if not context_docs:
+        # Simpler prompt for cases where there is no context
+        user_prompt = (
+            CHAT_USER_CONTEXT_FREE_PROMPT.format(
+                task_prompt=prompt.task_prompt, user_query=question
+            )
+            if prompt.task_prompt
+            else question
+        )
+        user_prompt = user_prompt.strip()
+        token_count = len(llm_tokenizer_encode_func(user_prompt))
+        user_msg = HumanMessage(content=user_prompt)
+        return user_msg, token_count
+
+    context_docs_str = build_complete_context_str(context_docs)
+    optional_ignore = "" if all_doc_useful else DEFAULT_IGNORE_STATEMENT
+
+    task_prompt_with_reminder = build_task_prompt_reminders(prompt)
+
+    user_prompt = CITATIONS_PROMPT.format(
+        optional_ignore_statement=optional_ignore,
+        context_docs_str=context_docs_str,
+        task_prompt=task_prompt_with_reminder,
+        user_query=question,
+        history_block=history_message,
+    )
+
+    user_prompt = user_prompt.strip()
+    token_count = len(llm_tokenizer_encode_func(user_prompt))
+    user_msg = HumanMessage(content=user_prompt)
+
+    return user_msg, token_count
+
+
+def build_citations_prompt(
+    question: str,
+    message_history: list[PreviousMessage],
+    persona: Persona,
+    prompt: Prompt,
+    context_docs: list[LlmDoc] | list[InferenceChunk],
+    all_doc_useful: bool,
+    history_message: str,
+    llm_tokenizer_encode_func: Callable,
+) -> list[BaseMessage]:
+    context_exists = len(context_docs) > 0
+
+    system_message_or_none, system_tokens = build_system_message(
+        prompt=prompt,
+        context_exists=context_exists,
+        llm_tokenizer_encode_func=llm_tokenizer_encode_func,
+    )
+
+    history_basemessages, history_token_counts = translate_history_to_basemessages(
+        message_history
+    )
+
+    # Be sure the context_docs passed to build_user_message
+    # is the same as passed in later for extracting citations
+    user_message, user_tokens = build_user_message(
+        question=question,
+        prompt=prompt,
+        context_docs=context_docs,
+        all_doc_useful=all_doc_useful,
+        history_message=history_message,
+    )
+
+    final_prompt_msgs = drop_messages_history_overflow(
+        system_msg=system_message_or_none,
+        system_token_count=system_tokens,
+        history_msgs=history_basemessages,
+        history_token_counts=history_token_counts,
+        final_msg=user_message,
+        final_msg_token_count=user_tokens,
+        max_allowed_tokens=compute_max_llm_input_tokens(persona),
+    )
+
+    return final_prompt_msgs
diff --git a/backend/danswer/llm/answering/prompts/quotes_prompt.py b/backend/danswer/llm/answering/prompts/quotes_prompt.py
new file mode 100644
index 0000000000..c9e145e810
--- /dev/null
+++ b/backend/danswer/llm/answering/prompts/quotes_prompt.py
@@ -0,0 +1,88 @@
+from langchain.schema.messages import BaseMessage
+from langchain.schema.messages import HumanMessage
+
+from danswer.chat.models import LlmDoc
+from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
+from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
+from danswer.db.models import Prompt
+from danswer.indexing.models import InferenceChunk
+from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
+from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK
+from danswer.prompts.direct_qa_prompts import JSON_PROMPT
+from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
+from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT
+from danswer.prompts.prompt_utils import build_complete_context_str
+
+
+def _build_weak_llm_quotes_prompt(
+    question: str,
+    context_docs: list[LlmDoc] | list[InferenceChunk],
+    history_str: str,
+    prompt: Prompt,
+    use_language_hint: bool,
+) -> list[BaseMessage]:
+    """Since Danswer supports a variety of LLMs, this less demanding prompt is provided
+    as an option to use with weaker LLMs such as small versions, low float precision,
+    quantized, or distilled models. It only uses one context document and has very weak
+    requirements of output format.
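+
+    For example, when a deployment sets QA_PROMPT_OVERRIDE="weak", only
+    context_docs[0] is stuffed into the prompt and the model is not asked to
+    produce JSON.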
+ """ + context_block = "" + if context_docs: + context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs[0].content) + + prompt_str = WEAK_LLM_PROMPT.format( + system_prompt=prompt.system_prompt, + context_block=context_block, + task_prompt=prompt.task_prompt, + user_query=question, + ) + return [HumanMessage(content=prompt_str)] + + +def _build_strong_llm_quotes_prompt( + question: str, + context_docs: list[LlmDoc] | list[InferenceChunk], + history_str: str, + prompt: Prompt, + use_language_hint: bool, +) -> list[BaseMessage]: + context_block = "" + if context_docs: + context_docs_str = build_complete_context_str(context_docs) + context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs_str) + + history_block = "" + if history_str: + history_block = HISTORY_BLOCK.format(history_str=history_str) + + full_prompt = JSON_PROMPT.format( + system_prompt=prompt.system_prompt, + context_block=context_block, + history_block=history_block, + task_prompt=prompt.task_prompt, + user_query=question, + language_hint_or_none=LANGUAGE_HINT.strip() if use_language_hint else "", + ).strip() + return [HumanMessage(content=full_prompt)] + + +def build_quotes_prompt( + question: str, + context_docs: list[LlmDoc] | list[InferenceChunk], + history_str: str, + prompt: Prompt, + use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION), +) -> list[BaseMessage]: + prompt_builder = ( + _build_weak_llm_quotes_prompt + if QA_PROMPT_OVERRIDE == "weak" + else _build_strong_llm_quotes_prompt + ) + + return prompt_builder( + question=question, + context_docs=context_docs, + history_str=history_str, + prompt=prompt, + use_language_hint=use_language_hint, + ) diff --git a/backend/danswer/llm/answering/prompts/utils.py b/backend/danswer/llm/answering/prompts/utils.py new file mode 100644 index 0000000000..bcc8b89181 --- /dev/null +++ b/backend/danswer/llm/answering/prompts/utils.py @@ -0,0 +1,20 @@ +from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT +from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT + + +def build_dummy_prompt( + system_prompt: str, task_prompt: str, retrieval_disabled: bool +) -> str: + if retrieval_disabled: + return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format( + user_query="", + system_prompt=system_prompt, + task_prompt=task_prompt, + ).strip() + + return PARAMATERIZED_PROMPT.format( + context_docs_str="", + user_query="", + system_prompt=system_prompt, + task_prompt=task_prompt, + ).strip() diff --git a/backend/danswer/llm/answering/stream_processing/citation_processing.py b/backend/danswer/llm/answering/stream_processing/citation_processing.py new file mode 100644 index 0000000000..a26021835c --- /dev/null +++ b/backend/danswer/llm/answering/stream_processing/citation_processing.py @@ -0,0 +1,126 @@ +import re +from collections.abc import Iterator + +from danswer.chat.models import AnswerQuestionStreamReturn +from danswer.chat.models import CitationInfo +from danswer.chat.models import DanswerAnswerPiece +from danswer.chat.models import LlmDoc +from danswer.configs.chat_configs import STOP_STREAM_PAT +from danswer.llm.answering.models import StreamProcessor +from danswer.llm.answering.stream_processing.utils import map_document_id_order +from danswer.prompts.constants import TRIPLE_BACKTICK +from danswer.utils.logger import setup_logger + + +logger = setup_logger() + + +def in_code_block(llm_text: str) -> bool: + count = llm_text.count(TRIPLE_BACKTICK) + return count % 2 != 0 + + +def extract_citations_from_stream( + tokens: 
Iterator[str],
+    context_docs: list[LlmDoc],
+    doc_id_to_rank_map: dict[str, int],
+    stop_stream: str | None = STOP_STREAM_PAT,
+) -> Iterator[DanswerAnswerPiece | CitationInfo]:
+    llm_out = ""
+    max_citation_num = len(context_docs)
+    curr_segment = ""
+    prepend_bracket = False
+    cited_inds = set()
+    hold = ""
+    for raw_token in tokens:
+        if stop_stream:
+            next_hold = hold + raw_token
+
+            if stop_stream in next_hold:
+                break
+
+            if next_hold == stop_stream[: len(next_hold)]:
+                hold = next_hold
+                continue
+
+            token = next_hold
+            hold = ""
+        else:
+            token = raw_token
+
+        # Special case of [1][ where ][ is a single token
+        # This is where the model attempts to do consecutive citations like [1][2]
+        if prepend_bracket:
+            curr_segment = "[" + curr_segment
+            prepend_bracket = False
+
+        curr_segment += token
+        llm_out += token
+
+        possible_citation_pattern = r"(\[\d*$)"  # [1, [, etc
+        possible_citation_found = re.search(possible_citation_pattern, curr_segment)
+
+        citation_pattern = r"\[(\d+)\]"  # [1], [2] etc
+        citation_found = re.search(citation_pattern, curr_segment)
+
+        if citation_found and not in_code_block(llm_out):
+            numerical_value = int(citation_found.group(1))
+            if 1 <= numerical_value <= max_citation_num:
+                context_llm_doc = context_docs[
+                    numerical_value - 1
+                ]  # remove 1 index offset
+
+                link = context_llm_doc.link
+                target_citation_num = doc_id_to_rank_map[context_llm_doc.document_id]
+
+                # Use the citation number for the document's rank in
+                # the search (or selected docs) results
+                curr_segment = re.sub(
+                    rf"\[{numerical_value}\]", f"[{target_citation_num}]", curr_segment
+                )
+
+                if target_citation_num not in cited_inds:
+                    cited_inds.add(target_citation_num)
+                    yield CitationInfo(
+                        citation_num=target_citation_num,
+                        document_id=context_llm_doc.document_id,
+                    )
+
+                if link:
+                    curr_segment = re.sub(r"\[", "[[", curr_segment, count=1)
+                    curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1)
+
+                # In case there's another open bracket like [1][, don't want to match this
+                possible_citation_found = None
+
+        # if we see "[", but haven't seen the right side, hold back - this may be a
+        # citation that needs to be replaced with a link
+        if possible_citation_found:
+            continue
+
+        # Special case with back to back citations [1][2]
+        if curr_segment and curr_segment[-1] == "[":
+            curr_segment = curr_segment[:-1]
+            prepend_bracket = True
+
+        yield DanswerAnswerPiece(answer_piece=curr_segment)
+        curr_segment = ""
+
+    if curr_segment:
+        if prepend_bracket:
+            yield DanswerAnswerPiece(answer_piece="[" + curr_segment)
+        else:
+            yield DanswerAnswerPiece(answer_piece=curr_segment)
+
+
+def build_citation_processor(
+    context_docs: list[LlmDoc],
+) -> StreamProcessor:
+    def stream_processor(tokens: Iterator[str]) -> AnswerQuestionStreamReturn:
+        yield from extract_citations_from_stream(
+            tokens=tokens,
+            context_docs=context_docs,
+            doc_id_to_rank_map=map_document_id_order(context_docs),
+        )
+
+    return stream_processor
diff --git a/backend/danswer/llm/answering/stream_processing/quotes_processing.py b/backend/danswer/llm/answering/stream_processing/quotes_processing.py
new file mode 100644
index 0000000000..daa966e694
--- /dev/null
+++ b/backend/danswer/llm/answering/stream_processing/quotes_processing.py
@@ -0,0 +1,282 @@
+import math
+import re
+from collections.abc import Callable
+from collections.abc import Generator
+from collections.abc import Iterator
+from json import JSONDecodeError
+from typing import Optional
+
+import regex
+
+from danswer.chat.models import AnswerQuestionStreamReturn
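+# (Editor's illustration, not part of this patch; it refers to
+# extract_citations_from_stream in citation_processing.py above. Assume two
+# context docs, where doc 1 has link "https://example.com/a" (hypothetical).
+# If the model streams the text "Configs live in Postgres [1].", the processor
+# yields CitationInfo(citation_num=1, document_id=<doc 1's id>) the first time
+# doc 1 is cited and rewrites the visible text to
+# "Configs live in Postgres [[1]](https://example.com/a)." Tokens ending in a
+# bare "[" are held back until the citation is completed or ruled out.)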
+from danswer.chat.models import DanswerAnswer
+from danswer.chat.models import DanswerAnswerPiece
+from danswer.chat.models import DanswerQuote
+from danswer.chat.models import DanswerQuotes
+from danswer.chat.models import LlmDoc
+from danswer.configs.chat_configs import QUOTE_ALLOWED_ERROR_PERCENT
+from danswer.indexing.models import InferenceChunk
+from danswer.prompts.constants import ANSWER_PAT
+from danswer.prompts.constants import QUOTE_PAT
+from danswer.prompts.constants import UNCERTAINTY_PAT
+from danswer.utils.logger import setup_logger
+from danswer.utils.text_processing import clean_model_quote
+from danswer.utils.text_processing import clean_up_code_blocks
+from danswer.utils.text_processing import extract_embedded_json
+from danswer.utils.text_processing import shared_precompare_cleanup
+
+
+logger = setup_logger()
+
+
+def _extract_answer_quotes_freeform(
+    answer_raw: str,
+) -> tuple[Optional[str], Optional[list[str]]]:
+    """Splits the model output into an Answer and 0 or more Quote sections.
+    Splits on the Quote pattern; if it is not present, assume it's all answer and no quotes.
+    """
+    # If no answer section, don't care about the quote
+    if answer_raw.lower().strip().startswith(QUOTE_PAT.lower()):
+        return None, None
+
+    # Sometimes the model regenerates the Answer: pattern despite it being provided in the prompt
+    if answer_raw.lower().startswith(ANSWER_PAT.lower()):
+        answer_raw = answer_raw[len(ANSWER_PAT) :]
+
+    # Accept quote sections starting with the lower case version
+    answer_raw = answer_raw.replace(
+        f"\n{QUOTE_PAT}".lower(), f"\n{QUOTE_PAT}"
+    )  # Just in case the model is unreliable
+
+    sections = re.split(rf"(?<=\n){QUOTE_PAT}", answer_raw)
+    sections_clean = [
+        str(section).strip() for section in sections if str(section).strip()
+    ]
+    if not sections_clean:
+        return None, None
+
+    answer = str(sections_clean[0])
+    if len(sections) == 1:
+        return answer, None
+    return answer, sections_clean[1:]
+
+
+def _extract_answer_quotes_json(
+    answer_dict: dict[str, str | list[str]]
+) -> tuple[Optional[str], Optional[list[str]]]:
+    answer_dict = {k.lower(): v for k, v in answer_dict.items()}
+    answer = str(answer_dict.get("answer"))
+    quotes = answer_dict.get("quotes") or answer_dict.get("quote")
+    if isinstance(quotes, str):
+        quotes = [quotes]
+    return answer, quotes
+
+
+def _extract_answer_json(raw_model_output: str) -> dict:
+    try:
+        answer_json = extract_embedded_json(raw_model_output)
+    except (ValueError, JSONDecodeError):
+        # LLMs get confused when handling the list in the JSON. Sometimes the model doesn't
+        # attend enough to the opening { token, so it just ends the list of quotes and stops
+        # there; here, we add logic to try to fix this LLM error.
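+        # (Editor's illustration: e.g. a model emits
+        #     {"answer": "Danswer", "quotes": ["KV store moved to Postgres"]
+        # and stops before the closing brace; appending "}" below repairs it
+        # into valid JSON.)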
+ answer_json = extract_embedded_json(raw_model_output + "}") + + if "answer" not in answer_json: + raise ValueError("Model did not output an answer as expected.") + + return answer_json + + +def match_quotes_to_docs( + quotes: list[str], + docs: list[LlmDoc] | list[InferenceChunk], + max_error_percent: float = QUOTE_ALLOWED_ERROR_PERCENT, + fuzzy_search: bool = False, + prefix_only_length: int = 100, +) -> DanswerQuotes: + danswer_quotes: list[DanswerQuote] = [] + for quote in quotes: + max_edits = math.ceil(float(len(quote)) * max_error_percent) + + for doc in docs: + if not doc.source_links: + continue + + quote_clean = shared_precompare_cleanup( + clean_model_quote(quote, trim_length=prefix_only_length) + ) + chunk_clean = shared_precompare_cleanup(doc.content) + + # Finding the offset of the quote in the plain text + if fuzzy_search: + re_search_str = ( + r"(" + re.escape(quote_clean) + r"){e<=" + str(max_edits) + r"}" + ) + found = regex.search(re_search_str, chunk_clean) + if not found: + continue + offset = found.span()[0] + else: + if quote_clean not in chunk_clean: + continue + offset = chunk_clean.index(quote_clean) + + # Extracting the link from the offset + curr_link = None + for link_offset, link in doc.source_links.items(): + # Should always find one because offset is at least 0 and there + # must be a 0 link_offset + if int(link_offset) <= offset: + curr_link = link + else: + break + + danswer_quotes.append( + DanswerQuote( + quote=quote, + document_id=doc.document_id, + link=curr_link, + source_type=doc.source_type, + semantic_identifier=doc.semantic_identifier, + blurb=doc.blurb, + ) + ) + break + + return DanswerQuotes(quotes=danswer_quotes) + + +def separate_answer_quotes( + answer_raw: str, is_json_prompt: bool = False +) -> tuple[Optional[str], Optional[list[str]]]: + """Takes in a raw model output and pulls out the answer and the quotes sections.""" + if is_json_prompt: + model_raw_json = _extract_answer_json(answer_raw) + return _extract_answer_quotes_json(model_raw_json) + + return _extract_answer_quotes_freeform(clean_up_code_blocks(answer_raw)) + + +def process_answer( + answer_raw: str, + docs: list[LlmDoc], + is_json_prompt: bool = True, +) -> tuple[DanswerAnswer, DanswerQuotes]: + """Used (1) in the non-streaming case to process the model output + into an Answer and Quotes AND (2) after the complete streaming response + has been received to process the model output into an Answer and Quotes.""" + answer, quote_strings = separate_answer_quotes(answer_raw, is_json_prompt) + if answer == UNCERTAINTY_PAT or not answer: + if answer == UNCERTAINTY_PAT: + logger.debug("Answer matched UNCERTAINTY_PAT") + else: + logger.debug("No answer extracted from raw output") + return DanswerAnswer(answer=None), DanswerQuotes(quotes=[]) + + logger.info(f"Answer: {answer}") + if not quote_strings: + logger.debug("No quotes extracted from raw output") + return DanswerAnswer(answer=answer), DanswerQuotes(quotes=[]) + logger.info(f"All quotes (including unmatched): {quote_strings}") + quotes = match_quotes_to_docs(quote_strings, docs) + logger.debug(f"Final quotes: {quotes}") + + return DanswerAnswer(answer=answer), quotes + + +def _stream_json_answer_end(answer_so_far: str, next_token: str) -> bool: + next_token = next_token.replace('\\"', "") + # If the previous character is an escape token, don't consider the first character of next_token + # This does not work if it's an escaped escape sign before the " but this is rare, not worth handling + if answer_so_far and answer_so_far[-1] 
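+# (Editor's illustration, not part of this patch; it refers to
+# match_quotes_to_docs above. With QUOTE_ALLOWED_ERROR_PERCENT at a
+# hypothetical 0.05, a 100-character quote may differ from the document text
+# by up to ceil(100 * 0.05) = 5 edits when fuzzy_search is enabled; the
+# `regex` module's fuzzy syntax r"(...){e<=5}" performs that matching. The
+# match offset is then compared against the source_links offsets to pick the
+# deepest link at or before the quote.)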
== "\\": + next_token = next_token[1:] + if '"' in next_token: + return True + return False + + +def _extract_quotes_from_completed_token_stream( + model_output: str, context_docs: list[LlmDoc], is_json_prompt: bool = True +) -> DanswerQuotes: + answer, quotes = process_answer(model_output, context_docs, is_json_prompt) + if answer: + logger.info(answer) + elif model_output: + logger.warning("Answer extraction from model output failed.") + + return quotes + + +def process_model_tokens( + tokens: Iterator[str], + context_docs: list[LlmDoc], + is_json_prompt: bool = True, +) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]: + """Used in the streaming case to process the model output + into an Answer and Quotes + + Yields Answer tokens back out in a dict for streaming to frontend + When Answer section ends, yields dict with answer_finished key + Collects all the tokens at the end to form the complete model output""" + quote_pat = f"\n{QUOTE_PAT}" + # Sometimes worse model outputs new line instead of : + quote_loose = f"\n{quote_pat[:-1]}\n" + # Sometime model outputs two newlines before quote section + quote_pat_full = f"\n{quote_pat}" + model_output: str = "" + found_answer_start = False if is_json_prompt else True + found_answer_end = False + hold_quote = "" + for token in tokens: + model_previous = model_output + model_output += token + + if not found_answer_start and '{"answer":"' in re.sub(r"\s", "", model_output): + # Note, if the token that completes the pattern has additional text, for example if the token is "? + # Then the chars after " will not be streamed, but this is ok as it prevents streaming the ? in the + # event that the model outputs the UNCERTAINTY_PAT + found_answer_start = True + + # Prevent heavy cases of hallucinations where model is not even providing a json until later + if is_json_prompt and len(model_output) > 40: + logger.warning("LLM did not produce json as prompted") + found_answer_end = True + + continue + + if found_answer_start and not found_answer_end: + if is_json_prompt and _stream_json_answer_end(model_previous, token): + found_answer_end = True + yield DanswerAnswerPiece(answer_piece=None) + continue + elif not is_json_prompt: + if quote_pat in hold_quote + token or quote_loose in hold_quote + token: + found_answer_end = True + yield DanswerAnswerPiece(answer_piece=None) + continue + if hold_quote + token in quote_pat_full: + hold_quote += token + continue + yield DanswerAnswerPiece(answer_piece=hold_quote + token) + hold_quote = "" + + logger.debug(f"Raw Model QnA Output: {model_output}") + + yield _extract_quotes_from_completed_token_stream( + model_output=model_output, + context_docs=context_docs, + is_json_prompt=is_json_prompt, + ) + + +def build_quotes_processor( + context_docs: list[LlmDoc], is_json_prompt: bool +) -> Callable[[Iterator[str]], AnswerQuestionStreamReturn]: + def stream_processor(tokens: Iterator[str]) -> AnswerQuestionStreamReturn: + yield from process_model_tokens( + tokens=tokens, + context_docs=context_docs, + is_json_prompt=is_json_prompt, + ) + + return stream_processor diff --git a/backend/danswer/llm/answering/stream_processing/utils.py b/backend/danswer/llm/answering/stream_processing/utils.py new file mode 100644 index 0000000000..1ddcdf605e --- /dev/null +++ b/backend/danswer/llm/answering/stream_processing/utils.py @@ -0,0 +1,17 @@ +from collections.abc import Sequence + +from danswer.chat.models import LlmDoc +from danswer.indexing.models import InferenceChunk + + +def map_document_id_order( + chunks: 
Sequence[InferenceChunk | LlmDoc], one_indexed: bool = True +) -> dict[str, int]: + order_mapping = {} + current = 1 if one_indexed else 0 + for chunk in chunks: + if chunk.document_id not in order_mapping: + order_mapping[chunk.document_id] = current + current += 1 + + return order_mapping diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py index f36f285461..c07b708bb5 100644 --- a/backend/danswer/llm/utils.py +++ b/backend/danswer/llm/utils.py @@ -33,6 +33,7 @@ from danswer.dynamic_configs.factory import get_dynamic_config_store from danswer.dynamic_configs.interface import ConfigNotFoundError from danswer.indexing.models import InferenceChunk +from danswer.llm.answering.models import PreviousMessage from danswer.llm.interfaces import LLM from danswer.utils.logger import setup_logger @@ -114,7 +115,9 @@ def tokenizer_trim_chunks( return new_chunks -def translate_danswer_msg_to_langchain(msg: ChatMessage) -> BaseMessage: +def translate_danswer_msg_to_langchain( + msg: ChatMessage | PreviousMessage, +) -> BaseMessage: if msg.message_type == MessageType.SYSTEM: raise ValueError("System messages are not currently part of history") if msg.message_type == MessageType.ASSISTANT: @@ -126,7 +129,7 @@ def translate_danswer_msg_to_langchain(msg: ChatMessage) -> BaseMessage: def translate_history_to_basemessages( - history: list[ChatMessage], + history: list[ChatMessage] | list[PreviousMessage], ) -> tuple[list[BaseMessage], list[int]]: history_basemessages = [ translate_danswer_msg_to_langchain(msg) diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py index db5ef6f0f9..e863f4ac09 100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -1,54 +1,37 @@ -import itertools from collections.abc import Callable from collections.abc import Iterator -from langchain.schema.messages import BaseMessage -from langchain.schema.messages import HumanMessage from sqlalchemy.orm import Session -from danswer.chat.chat_utils import build_chat_system_message -from danswer.chat.chat_utils import compute_max_document_tokens -from danswer.chat.chat_utils import extract_citations_from_stream -from danswer.chat.chat_utils import get_chunks_for_qa from danswer.chat.chat_utils import llm_doc_from_inference_chunk -from danswer.chat.chat_utils import map_document_id_order from danswer.chat.chat_utils import reorganize_citations from danswer.chat.models import CitationInfo from danswer.chat.models import DanswerAnswerPiece -from danswer.chat.models import DanswerContext from danswer.chat.models import DanswerContexts from danswer.chat.models import DanswerQuotes -from danswer.chat.models import LLMMetricsContainer from danswer.chat.models import LLMRelevanceFilterResponse from danswer.chat.models import QADocsResponse from danswer.chat.models import StreamingError from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT from danswer.configs.chat_configs import QA_TIMEOUT from danswer.configs.constants import MessageType -from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE from danswer.db.chat import create_chat_session from danswer.db.chat import create_new_chat_message from danswer.db.chat import get_or_create_root_message -from danswer.db.chat import get_persona_by_id from danswer.db.chat import get_prompt_by_id from danswer.db.chat import translate_db_message_to_chat_message_detail from danswer.db.engine import get_session_context_manager -from 
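# (Editor's illustration, not part of this patch; it refers to
# map_document_id_order in stream_processing/utils.py above:
#
#   map_document_id_order(chunks)  # chunks' document_ids: ["a", "a", "b"]
#   -> {"a": 1, "b": 2}            # first occurrence wins, 1-indexed
#
# This is the doc_id_to_rank_map consumed by the citation processor.)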
danswer.db.models import Prompt from danswer.db.models import User -from danswer.indexing.models import InferenceChunk -from danswer.llm.factory import get_default_llm +from danswer.llm.answering.answer import Answer +from danswer.llm.answering.models import AnswerStyleConfig +from danswer.llm.answering.models import CitationConfig +from danswer.llm.answering.models import DocumentPruningConfig +from danswer.llm.answering.models import QuotesConfig from danswer.llm.utils import get_default_llm_token_encode -from danswer.llm.utils import get_default_llm_tokenizer -from danswer.one_shot_answer.factory import get_question_answer_model from danswer.one_shot_answer.models import DirectQARequest from danswer.one_shot_answer.models import OneShotQAResponse from danswer.one_shot_answer.models import QueryRephrase -from danswer.one_shot_answer.models import ThreadMessage -from danswer.one_shot_answer.qa_block import no_gen_ai_response from danswer.one_shot_answer.qa_utils import combine_message_thread -from danswer.prompts.direct_qa_prompts import CITATIONS_PROMPT -from danswer.prompts.prompt_utils import build_complete_context_str -from danswer.prompts.prompt_utils import build_task_prompt_reminders from danswer.search.models import RerankMetricsContainer from danswer.search.models import RetrievalMetricsContainer from danswer.search.models import SavedSearchDoc @@ -77,106 +60,6 @@ ] -def quote_based_qa( - prompt: Prompt, - query_message: ThreadMessage, - history_str: str, - context_chunks: list[InferenceChunk], - llm_override: str | None, - timeout: int, - use_chain_of_thought: bool, - return_contexts: bool, - llm_metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, -) -> AnswerObjectIterator: - qa_model = get_question_answer_model( - prompt=prompt, - timeout=timeout, - chain_of_thought=use_chain_of_thought, - llm_version=llm_override, - ) - - full_prompt_str = ( - qa_model.build_prompt( - query=query_message.message, - history_str=history_str, - context_chunks=context_chunks, - ) - if qa_model is not None - else "Gen AI Disabled" - ) - - response_packets = ( - qa_model.answer_question_stream( - prompt=full_prompt_str, - llm_context_docs=context_chunks, - metrics_callback=llm_metrics_callback, - ) - if qa_model is not None - else no_gen_ai_response() - ) - - if qa_model is not None and return_contexts: - contexts = DanswerContexts( - contexts=[ - DanswerContext( - content=context_chunk.content, - document_id=context_chunk.document_id, - semantic_identifier=context_chunk.semantic_identifier, - blurb=context_chunk.semantic_identifier, - ) - for context_chunk in context_chunks - ] - ) - - response_packets = itertools.chain(response_packets, [contexts]) - - yield from response_packets - - -def citation_based_qa( - prompt: Prompt, - query_message: ThreadMessage, - history_str: str, - context_chunks: list[InferenceChunk], - llm_override: str | None, - timeout: int, -) -> AnswerObjectIterator: - llm_tokenizer = get_default_llm_tokenizer() - - system_prompt_or_none, _ = build_chat_system_message( - prompt=prompt, - context_exists=True, - llm_tokenizer_encode_func=llm_tokenizer.encode, - ) - - task_prompt_with_reminder = build_task_prompt_reminders(prompt) - - context_docs_str = build_complete_context_str(context_chunks) - user_message = HumanMessage( - content=CITATIONS_PROMPT.format( - task_prompt=task_prompt_with_reminder, - user_query=query_message.message, - history_block=history_str, - context_docs_str=context_docs_str, - ) - ) - - llm = get_default_llm( - timeout=timeout, - 
gen_ai_model_version_override=llm_override, - ) - - llm_prompt: list[BaseMessage] = [user_message] - if system_prompt_or_none is not None: - llm_prompt = [system_prompt_or_none] + llm_prompt - - llm_docs = [llm_doc_from_inference_chunk(chunk) for chunk in context_chunks] - doc_id_to_rank_map = map_document_id_order(llm_docs) - - tokens = llm.stream(llm_prompt) - yield from extract_citations_from_stream(tokens, llm_docs, doc_id_to_rank_map) - - def stream_answer_objects( query_req: DirectQARequest, user: User | None, @@ -188,14 +71,12 @@ def stream_answer_objects( db_session: Session, # Needed to translate persona num_chunks to tokens to the LLM default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT, - default_chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE, timeout: int = QA_TIMEOUT, bypass_acl: bool = False, use_citations: bool = False, retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] | None = None, rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, - llm_metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, ) -> AnswerObjectIterator: """Streams in order: 1. [always] Retrieved documents, stops flow if nothing is found @@ -273,43 +154,11 @@ def stream_answer_objects( ) yield llm_relevance_filtering_response - # Prep chunks to pass to LLM - num_llm_chunks = ( - chat_session.persona.num_chunks - if chat_session.persona.num_chunks is not None - else default_num_chunks - ) - - chunk_token_limit = int(num_llm_chunks * default_chunk_size) - if max_document_tokens: - chunk_token_limit = min(chunk_token_limit, max_document_tokens) - else: - max_document_tokens = compute_max_document_tokens( - persona=chat_session.persona, actual_user_input=query_msg.message - ) - chunk_token_limit = min(chunk_token_limit, max_document_tokens) - - llm_chunks_indices = get_chunks_for_qa( - chunks=top_chunks, - llm_chunk_selection=search_pipeline.chunk_relevance_list, - token_limit=chunk_token_limit, - ) - llm_chunks = [top_chunks[i] for i in llm_chunks_indices] - - logger.debug( - f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in llm_chunks]}" - ) - prompt = None - llm_override = None if query_req.prompt_id is not None: prompt = get_prompt_by_id( prompt_id=query_req.prompt_id, user_id=user_id, db_session=db_session ) - persona = get_persona_by_id( - persona_id=query_req.persona_id, user_id=user_id, db_session=db_session - ) - llm_override = persona.llm_model_version_override if prompt is None: if not chat_session.persona.prompts: raise RuntimeError( @@ -329,52 +178,39 @@ def stream_answer_objects( commit=True, ) - if use_citations: - qa_stream = citation_based_qa( - prompt=prompt, - query_message=query_msg, - history_str=history_str, - context_chunks=llm_chunks, - llm_override=llm_override, - timeout=timeout, - ) - else: - qa_stream = quote_based_qa( - prompt=prompt, - query_message=query_msg, - history_str=history_str, - context_chunks=llm_chunks, - llm_override=llm_override, - timeout=timeout, - use_chain_of_thought=False, - return_contexts=False, - llm_metrics_callback=llm_metrics_callback, - ) - - # Capture outputs and errors - llm_output = "" - error: str | None = None - for packet in qa_stream: - logger.debug(packet) - - if isinstance(packet, DanswerAnswerPiece): - token = packet.answer_piece - if token: - llm_output += token - elif isinstance(packet, StreamingError): - error = packet.error - - yield packet + answer_config = AnswerStyleConfig( + citation_config=CitationConfig() if use_citations else None, + quotes_config=QuotesConfig() 
if not use_citations else None, + document_pruning_config=DocumentPruningConfig( + max_chunks=int( + chat_session.persona.num_chunks + if chat_session.persona.num_chunks is not None + else default_num_chunks + ), + max_tokens=max_document_tokens, + ), + ) + answer = Answer( + question=query_msg.message, + docs=[llm_doc_from_inference_chunk(chunk) for chunk in top_chunks], + answer_style_config=answer_config, + prompt=prompt, + persona=chat_session.persona, + doc_relevance_list=search_pipeline.chunk_relevance_list, + single_message_history=history_str, + timeout=timeout, + ) + yield from answer.processed_streamed_output # Saving Gen AI answer and responding with message info gen_ai_response_message = create_new_chat_message( chat_session_id=chat_session.id, parent_message=new_user_message, prompt_id=query_req.prompt_id, - message=llm_output, - token_count=len(llm_tokenizer(llm_output)), + message=answer.llm_answer, + token_count=len(llm_tokenizer(answer.llm_answer)), message_type=MessageType.ASSISTANT, - error=error, + error=None, reference_docs=None, # Don't need to save reference docs for one shot flow db_session=db_session, commit=True, @@ -419,7 +255,6 @@ def get_search_answer( retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] | None = None, rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, - llm_metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, ) -> OneShotQAResponse: """Collects the streamed one shot answer responses into a single object""" qa_response = OneShotQAResponse() @@ -435,7 +270,6 @@ def get_search_answer( timeout=answer_generation_timeout, retrieval_metrics_callback=retrieval_metrics_callback, rerank_metrics_callback=rerank_metrics_callback, - llm_metrics_callback=llm_metrics_callback, ) answer = "" diff --git a/backend/danswer/one_shot_answer/factory.py b/backend/danswer/one_shot_answer/factory.py deleted file mode 100644 index 122ed6ac06..0000000000 --- a/backend/danswer/one_shot_answer/factory.py +++ /dev/null @@ -1,48 +0,0 @@ -from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE -from danswer.configs.chat_configs import QA_TIMEOUT -from danswer.db.models import Prompt -from danswer.llm.exceptions import GenAIDisabledException -from danswer.llm.factory import get_default_llm -from danswer.one_shot_answer.interfaces import QAModel -from danswer.one_shot_answer.qa_block import QABlock -from danswer.one_shot_answer.qa_block import QAHandler -from danswer.one_shot_answer.qa_block import SingleMessageQAHandler -from danswer.one_shot_answer.qa_block import WeakLLMQAHandler -from danswer.utils.logger import setup_logger - -logger = setup_logger() - - -def get_question_answer_model( - prompt: Prompt | None, - api_key: str | None = None, - timeout: int = QA_TIMEOUT, - chain_of_thought: bool = False, - llm_version: str | None = None, - qa_model_version: str | None = QA_PROMPT_OVERRIDE, -) -> QAModel | None: - if chain_of_thought: - raise NotImplementedError("COT has been disabled") - - system_prompt = prompt.system_prompt if prompt is not None else None - task_prompt = prompt.task_prompt if prompt is not None else None - - try: - llm = get_default_llm( - api_key=api_key, - timeout=timeout, - gen_ai_model_version_override=llm_version, - ) - except GenAIDisabledException: - return None - - if qa_model_version == "weak": - qa_handler: QAHandler = WeakLLMQAHandler( - system_prompt=system_prompt, task_prompt=task_prompt - ) - else: - qa_handler = SingleMessageQAHandler( - system_prompt=system_prompt, 
task_prompt=task_prompt - ) - - return QABlock(llm=llm, qa_handler=qa_handler) diff --git a/backend/danswer/one_shot_answer/interfaces.py b/backend/danswer/one_shot_answer/interfaces.py deleted file mode 100644 index ca916d699d..0000000000 --- a/backend/danswer/one_shot_answer/interfaces.py +++ /dev/null @@ -1,26 +0,0 @@ -import abc -from collections.abc import Callable - -from danswer.chat.models import AnswerQuestionStreamReturn -from danswer.chat.models import LLMMetricsContainer -from danswer.indexing.models import InferenceChunk - - -class QAModel: - @abc.abstractmethod - def build_prompt( - self, - query: str, - history_str: str, - context_chunks: list[InferenceChunk], - ) -> str: - raise NotImplementedError - - @abc.abstractmethod - def answer_question_stream( - self, - prompt: str, - llm_context_docs: list[InferenceChunk], - metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, - ) -> AnswerQuestionStreamReturn: - raise NotImplementedError diff --git a/backend/danswer/one_shot_answer/qa_block.py b/backend/danswer/one_shot_answer/qa_block.py deleted file mode 100644 index 68cb6e4a82..0000000000 --- a/backend/danswer/one_shot_answer/qa_block.py +++ /dev/null @@ -1,313 +0,0 @@ -import abc -import re -from collections.abc import Callable -from collections.abc import Iterator -from typing import cast - -from danswer.chat.models import AnswerQuestionStreamReturn -from danswer.chat.models import DanswerAnswer -from danswer.chat.models import DanswerAnswerPiece -from danswer.chat.models import DanswerQuotes -from danswer.chat.models import LlmDoc -from danswer.chat.models import LLMMetricsContainer -from danswer.chat.models import StreamingError -from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION -from danswer.configs.constants import DISABLED_GEN_AI_MSG -from danswer.indexing.models import InferenceChunk -from danswer.llm.interfaces import LLM -from danswer.llm.utils import check_number_of_tokens -from danswer.llm.utils import get_default_llm_token_encode -from danswer.one_shot_answer.interfaces import QAModel -from danswer.one_shot_answer.qa_utils import process_answer -from danswer.one_shot_answer.qa_utils import process_model_tokens -from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK -from danswer.prompts.direct_qa_prompts import COT_PROMPT -from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK -from danswer.prompts.direct_qa_prompts import JSON_PROMPT -from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT -from danswer.prompts.direct_qa_prompts import ONE_SHOT_SYSTEM_PROMPT -from danswer.prompts.direct_qa_prompts import ONE_SHOT_TASK_PROMPT -from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT -from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT -from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT -from danswer.prompts.direct_qa_prompts import WEAK_MODEL_SYSTEM_PROMPT -from danswer.prompts.direct_qa_prompts import WEAK_MODEL_TASK_PROMPT -from danswer.prompts.prompt_utils import build_complete_context_str -from danswer.utils.logger import setup_logger -from danswer.utils.text_processing import clean_up_code_blocks -from danswer.utils.text_processing import escape_newlines - -logger = setup_logger() - - -class QAHandler(abc.ABC): - @property - @abc.abstractmethod - def is_json_output(self) -> bool: - """Does the model output a valid json with answer and quotes keys? Most flows with a - capable model should output a json. 
This hints to the model that the output is used - with a downstream system rather than freeform creative output. Most models should be - finetuned to recognize this.""" - raise NotImplementedError - - @abc.abstractmethod - def build_prompt( - self, - query: str, - history_str: str, - context_chunks: list[InferenceChunk], - ) -> str: - raise NotImplementedError - - def process_llm_token_stream( - self, tokens: Iterator[str], context_chunks: list[InferenceChunk] - ) -> AnswerQuestionStreamReturn: - yield from process_model_tokens( - tokens=tokens, - context_docs=context_chunks, - is_json_prompt=self.is_json_output, - ) - - -class WeakLLMQAHandler(QAHandler): - """Since Danswer supports a variety of LLMs, this less demanding prompt is provided - as an option to use with weaker LLMs such as small version, low float precision, quantized, - or distilled models. It only uses one context document and has very weak requirements of - output format. - """ - - def __init__( - self, - system_prompt: str | None, - task_prompt: str | None, - ) -> None: - if not system_prompt and not task_prompt: - self.system_prompt = WEAK_MODEL_SYSTEM_PROMPT - self.task_prompt = WEAK_MODEL_TASK_PROMPT - else: - self.system_prompt = system_prompt or "" - self.task_prompt = task_prompt or "" - - self.task_prompt = self.task_prompt.rstrip() - if self.task_prompt and self.task_prompt[0] != "\n": - self.task_prompt = "\n" + self.task_prompt - - @property - def is_json_output(self) -> bool: - return False - - def build_prompt( - self, - query: str, - history_str: str, - context_chunks: list[InferenceChunk], - ) -> str: - context_block = "" - if context_chunks: - context_block = CONTEXT_BLOCK.format( - context_docs_str=context_chunks[0].content - ) - - prompt_str = WEAK_LLM_PROMPT.format( - system_prompt=self.system_prompt, - context_block=context_block, - task_prompt=self.task_prompt, - user_query=query, - ) - return prompt_str - - -class SingleMessageQAHandler(QAHandler): - def __init__( - self, - system_prompt: str | None, - task_prompt: str | None, - use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION), - ) -> None: - self.use_language_hint = use_language_hint - if not system_prompt and not task_prompt: - self.system_prompt = ONE_SHOT_SYSTEM_PROMPT - self.task_prompt = ONE_SHOT_TASK_PROMPT - else: - self.system_prompt = system_prompt or "" - self.task_prompt = task_prompt or "" - - self.task_prompt = self.task_prompt.rstrip() - if self.task_prompt and self.task_prompt[0] != "\n": - self.task_prompt = "\n" + self.task_prompt - - @property - def is_json_output(self) -> bool: - return True - - def build_prompt( - self, query: str, history_str: str, context_chunks: list[InferenceChunk] - ) -> str: - context_block = "" - if context_chunks: - context_docs_str = build_complete_context_str( - cast(list[LlmDoc | InferenceChunk], context_chunks) - ) - context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs_str) - - history_block = "" - if history_str: - history_block = HISTORY_BLOCK.format(history_str=history_str) - - full_prompt = JSON_PROMPT.format( - system_prompt=self.system_prompt, - context_block=context_block, - history_block=history_block, - task_prompt=self.task_prompt, - user_query=query, - language_hint_or_none=LANGUAGE_HINT.strip() - if self.use_language_hint - else "", - ).strip() - return full_prompt - - -# This one isn't used, currently only streaming prompts are used -class SingleMessageScratchpadHandler(QAHandler): - def __init__( - self, - system_prompt: str | None, - task_prompt: str | None, - 
use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION), - ) -> None: - self.use_language_hint = use_language_hint - if not system_prompt and not task_prompt: - self.system_prompt = ONE_SHOT_SYSTEM_PROMPT - self.task_prompt = ONE_SHOT_TASK_PROMPT - else: - self.system_prompt = system_prompt or "" - self.task_prompt = task_prompt or "" - - self.task_prompt = self.task_prompt.rstrip() - if self.task_prompt and self.task_prompt[0] != "\n": - self.task_prompt = "\n" + self.task_prompt - - @property - def is_json_output(self) -> bool: - return True - - def build_prompt( - self, query: str, history_str: str, context_chunks: list[InferenceChunk] - ) -> str: - context_docs_str = build_complete_context_str( - cast(list[LlmDoc | InferenceChunk], context_chunks) - ) - - # Outdated - prompt = COT_PROMPT.format( - context_docs_str=context_docs_str, - user_query=query, - language_hint_or_none=LANGUAGE_HINT.strip() - if self.use_language_hint - else "", - ).strip() - - return prompt - - def process_llm_output( - self, model_output: str, context_chunks: list[InferenceChunk] - ) -> tuple[DanswerAnswer, DanswerQuotes]: - logger.debug(model_output) - - model_clean = clean_up_code_blocks(model_output) - - match = re.search(r'{\s*"answer":', model_clean) - if not match: - return DanswerAnswer(answer=None), DanswerQuotes(quotes=[]) - - final_json = escape_newlines(model_clean[match.start() :]) - - return process_answer( - final_json, context_chunks, is_json_prompt=self.is_json_output - ) - - def process_llm_token_stream( - self, tokens: Iterator[str], context_chunks: list[InferenceChunk] - ) -> AnswerQuestionStreamReturn: - # Can be supported but the parsing is more involved, not handling until needed - raise ValueError( - "This Scratchpad approach is not suitable for real time uses like streaming" - ) - - -def build_dummy_prompt( - system_prompt: str, task_prompt: str, retrieval_disabled: bool -) -> str: - if retrieval_disabled: - return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format( - user_query="", - system_prompt=system_prompt, - task_prompt=task_prompt, - ).strip() - - return PARAMATERIZED_PROMPT.format( - context_docs_str="", - user_query="", - system_prompt=system_prompt, - task_prompt=task_prompt, - ).strip() - - -def no_gen_ai_response() -> Iterator[DanswerAnswerPiece]: - yield DanswerAnswerPiece(answer_piece=DISABLED_GEN_AI_MSG) - - -class QABlock(QAModel): - def __init__(self, llm: LLM, qa_handler: QAHandler) -> None: - self._llm = llm - self._qa_handler = qa_handler - - def build_prompt( - self, - query: str, - history_str: str, - context_chunks: list[InferenceChunk], - ) -> str: - prompt = self._qa_handler.build_prompt( - query=query, history_str=history_str, context_chunks=context_chunks - ) - return prompt - - def answer_question_stream( - self, - prompt: str, - llm_context_docs: list[InferenceChunk], - metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, - ) -> AnswerQuestionStreamReturn: - tokens_stream = self._llm.stream(prompt) - - captured_tokens = [] - - try: - for answer_piece in self._qa_handler.process_llm_token_stream( - iter(tokens_stream), llm_context_docs - ): - if ( - isinstance(answer_piece, DanswerAnswerPiece) - and answer_piece.answer_piece - ): - captured_tokens.append(answer_piece.answer_piece) - yield answer_piece - - except Exception as e: - yield StreamingError(error=str(e)) - - if metrics_callback is not None: - prompt_tokens = check_number_of_tokens( - text=str(prompt), encode_fn=get_default_llm_token_encode() - ) - - response_tokens = 
check_number_of_tokens( - text="".join(captured_tokens), encode_fn=get_default_llm_token_encode() - ) - - metrics_callback( - LLMMetricsContainer( - prompt_tokens=prompt_tokens, response_tokens=response_tokens - ) - ) diff --git a/backend/danswer/one_shot_answer/qa_utils.py b/backend/danswer/one_shot_answer/qa_utils.py index 032d243459..e912a915e2 100644 --- a/backend/danswer/one_shot_answer/qa_utils.py +++ b/backend/danswer/one_shot_answer/qa_utils.py @@ -1,275 +1,14 @@ -import math -import re from collections.abc import Callable from collections.abc import Generator -from collections.abc import Iterator -from json.decoder import JSONDecodeError -from typing import Optional -from typing import Tuple -import regex - -from danswer.chat.models import DanswerAnswer -from danswer.chat.models import DanswerAnswerPiece -from danswer.chat.models import DanswerQuote -from danswer.chat.models import DanswerQuotes -from danswer.configs.chat_configs import QUOTE_ALLOWED_ERROR_PERCENT from danswer.configs.constants import MessageType -from danswer.indexing.models import InferenceChunk from danswer.llm.utils import get_default_llm_token_encode from danswer.one_shot_answer.models import ThreadMessage -from danswer.prompts.constants import ANSWER_PAT -from danswer.prompts.constants import QUOTE_PAT -from danswer.prompts.constants import UNCERTAINTY_PAT from danswer.utils.logger import setup_logger -from danswer.utils.text_processing import clean_model_quote -from danswer.utils.text_processing import clean_up_code_blocks -from danswer.utils.text_processing import extract_embedded_json -from danswer.utils.text_processing import shared_precompare_cleanup logger = setup_logger() -def _extract_answer_quotes_freeform( - answer_raw: str, -) -> Tuple[Optional[str], Optional[list[str]]]: - """Splits the model output into an Answer and 0 or more Quote sections. - Splits by the Quote pattern, if not exist then assume it's all answer and no quotes - """ - # If no answer section, don't care about the quote - if answer_raw.lower().strip().startswith(QUOTE_PAT.lower()): - return None, None - - # Sometimes model regenerates the Answer: pattern despite it being provided in the prompt - if answer_raw.lower().startswith(ANSWER_PAT.lower()): - answer_raw = answer_raw[len(ANSWER_PAT) :] - - # Accept quote sections starting with the lower case version - answer_raw = answer_raw.replace( - f"\n{QUOTE_PAT}".lower(), f"\n{QUOTE_PAT}" - ) # Just in case model unreliable - - sections = re.split(rf"(?<=\n){QUOTE_PAT}", answer_raw) - sections_clean = [ - str(section).strip() for section in sections if str(section).strip() - ] - if not sections_clean: - return None, None - - answer = str(sections_clean[0]) - if len(sections) == 1: - return answer, None - return answer, sections_clean[1:] - - -def _extract_answer_quotes_json( - answer_dict: dict[str, str | list[str]] -) -> Tuple[Optional[str], Optional[list[str]]]: - answer_dict = {k.lower(): v for k, v in answer_dict.items()} - answer = str(answer_dict.get("answer")) - quotes = answer_dict.get("quotes") or answer_dict.get("quote") - if isinstance(quotes, str): - quotes = [quotes] - return answer, quotes - - -def _extract_answer_json(raw_model_output: str) -> dict: - try: - answer_json = extract_embedded_json(raw_model_output) - except (ValueError, JSONDecodeError): - # LLMs get confused when handling the list in the json. 
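The freeform splitter above treats everything before the first quote marker as the answer and each following section as a quote. A minimal sketch of that split — QUOTE_PAT is assumed here to be "Quote:", matching Danswer's prompt constants:

# Split freeform model output into an answer plus zero or more quotes.
import re

QUOTE_PAT = "Quote:"  # assumed value of danswer.prompts.constants.QUOTE_PAT


def split_answer_and_quotes(raw: str) -> tuple[str | None, list[str]]:
    # Split only at quote markers that start a line, as the original does.
    sections = re.split(rf"(?<=\n){QUOTE_PAT}", raw)
    sections = [s.strip() for s in sections if s.strip()]
    if not sections:
        return None, []
    return sections[0], sections[1:]


raw = "The KV store was ported to Postgres.\nQuote: key_value_store table\nQuote: JSONB value column"
print(split_answer_and_quotes(raw))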
Sometimes it doesn't attend - # enough to the previous { token so it just ends the list of quotes and stops there - # here, we add logic to try to fix this LLM error. - answer_json = extract_embedded_json(raw_model_output + "}") - - if "answer" not in answer_json: - raise ValueError("Model did not output an answer as expected.") - - return answer_json - - -def separate_answer_quotes( - answer_raw: str, is_json_prompt: bool = False -) -> Tuple[Optional[str], Optional[list[str]]]: - """Takes in a raw model output and pulls out the answer and the quotes sections.""" - if is_json_prompt: - model_raw_json = _extract_answer_json(answer_raw) - return _extract_answer_quotes_json(model_raw_json) - - return _extract_answer_quotes_freeform(clean_up_code_blocks(answer_raw)) - - -def match_quotes_to_docs( - quotes: list[str], - chunks: list[InferenceChunk], - max_error_percent: float = QUOTE_ALLOWED_ERROR_PERCENT, - fuzzy_search: bool = False, - prefix_only_length: int = 100, -) -> DanswerQuotes: - danswer_quotes: list[DanswerQuote] = [] - for quote in quotes: - max_edits = math.ceil(float(len(quote)) * max_error_percent) - - for chunk in chunks: - if not chunk.source_links: - continue - - quote_clean = shared_precompare_cleanup( - clean_model_quote(quote, trim_length=prefix_only_length) - ) - chunk_clean = shared_precompare_cleanup(chunk.content) - - # Finding the offset of the quote in the plain text - if fuzzy_search: - re_search_str = ( - r"(" + re.escape(quote_clean) + r"){e<=" + str(max_edits) + r"}" - ) - found = regex.search(re_search_str, chunk_clean) - if not found: - continue - offset = found.span()[0] - else: - if quote_clean not in chunk_clean: - continue - offset = chunk_clean.index(quote_clean) - - # Extracting the link from the offset - curr_link = None - for link_offset, link in chunk.source_links.items(): - # Should always find one because offset is at least 0 and there - # must be a 0 link_offset - if int(link_offset) <= offset: - curr_link = link - else: - break - - danswer_quotes.append( - DanswerQuote( - quote=quote, - document_id=chunk.document_id, - link=curr_link, - source_type=chunk.source_type, - semantic_identifier=chunk.semantic_identifier, - blurb=chunk.blurb, - ) - ) - break - - return DanswerQuotes(quotes=danswer_quotes) - - -def process_answer( - answer_raw: str, - chunks: list[InferenceChunk], - is_json_prompt: bool = True, -) -> tuple[DanswerAnswer, DanswerQuotes]: - """Used (1) in the non-streaming case to process the model output - into an Answer and Quotes AND (2) after the complete streaming response - has been received to process the model output into an Answer and Quotes.""" - answer, quote_strings = separate_answer_quotes(answer_raw, is_json_prompt) - if answer == UNCERTAINTY_PAT or not answer: - if answer == UNCERTAINTY_PAT: - logger.debug("Answer matched UNCERTAINTY_PAT") - else: - logger.debug("No answer extracted from raw output") - return DanswerAnswer(answer=None), DanswerQuotes(quotes=[]) - - logger.info(f"Answer: {answer}") - if not quote_strings: - logger.debug("No quotes extracted from raw output") - return DanswerAnswer(answer=answer), DanswerQuotes(quotes=[]) - logger.info(f"All quotes (including unmatched): {quote_strings}") - quotes = match_quotes_to_docs(quote_strings, chunks) - logger.debug(f"Final quotes: {quotes}") - - return DanswerAnswer(answer=answer), quotes - - -def _stream_json_answer_end(answer_so_far: str, next_token: str) -> bool: - next_token = next_token.replace('\\"', "") - # If the previous character is an escape token, don't 
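The comment in _extract_answer_json above describes a common failure: the model closes the quotes list but never closes the object. Appending a "}" and retrying often recovers valid JSON. A minimal sketch of that retry, with json.loads standing in for Danswer's extract_embedded_json helper:

# Repair a truncated LLM JSON object by appending the missing closing brace.
import json


def parse_model_json(raw: str) -> dict:
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        # One retry with a closing brace, as in _extract_answer_json.
        return json.loads(raw + "}")


truncated = '{"answer": "42", "quotes": ["the answer is 42"]'
print(parse_model_json(truncated))  # {'answer': '42', 'quotes': ['the answer is 42']}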
consider the first character of next_token - # This does not work if it's an escaped escape sign before the " but this is rare, not worth handling - if answer_so_far and answer_so_far[-1] == "\\": - next_token = next_token[1:] - if '"' in next_token: - return True - return False - - -def _extract_quotes_from_completed_token_stream( - model_output: str, context_chunks: list[InferenceChunk], is_json_prompt: bool = True -) -> DanswerQuotes: - answer, quotes = process_answer(model_output, context_chunks, is_json_prompt) - if answer: - logger.info(answer) - elif model_output: - logger.warning("Answer extraction from model output failed.") - - return quotes - - -def process_model_tokens( - tokens: Iterator[str], - context_docs: list[InferenceChunk], - is_json_prompt: bool = True, -) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]: - """Used in the streaming case to process the model output - into an Answer and Quotes - - Yields Answer tokens back out in a dict for streaming to frontend - When Answer section ends, yields dict with answer_finished key - Collects all the tokens at the end to form the complete model output""" - quote_pat = f"\n{QUOTE_PAT}" - # Sometimes worse model outputs new line instead of : - quote_loose = f"\n{quote_pat[:-1]}\n" - # Sometime model outputs two newlines before quote section - quote_pat_full = f"\n{quote_pat}" - model_output: str = "" - found_answer_start = False if is_json_prompt else True - found_answer_end = False - hold_quote = "" - for token in tokens: - model_previous = model_output - model_output += token - - if not found_answer_start and '{"answer":"' in re.sub(r"\s", "", model_output): - # Note, if the token that completes the pattern has additional text, for example if the token is "? - # Then the chars after " will not be streamed, but this is ok as it prevents streaming the ? 
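match_quotes_to_docs above tolerates small transcription errors when locating a model quote inside a chunk by using the third-party regex module's fuzzy syntax, where {e<=N} permits up to N edits. A minimal sketch of that lookup under the same default error budget:

# Error-tolerant quote lookup via the `regex` module (pip install regex).
import math

import regex

QUOTE_ALLOWED_ERROR_PERCENT = 0.05  # same default idea as the config above


def find_quote_offset(quote: str, chunk_text: str) -> int | None:
    # Allow a number of edits proportional to the quote's length.
    max_edits = math.ceil(len(quote) * QUOTE_ALLOWED_ERROR_PERCENT)
    pattern = r"(" + regex.escape(quote) + r"){e<=" + str(max_edits) + r"}"
    found = regex.search(pattern, chunk_text)
    return found.span()[0] if found else None


# "KV-store" vs "KV store" is one substitution, within the edit budget.
print(find_quote_offset("KV store lives in Postgres", "Note: the KV-store lives in Postgres now."))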
in the - # event that the model outputs the UNCERTAINTY_PAT - found_answer_start = True - - # Prevent heavy cases of hallucinations where model is not even providing a json until later - if is_json_prompt and len(model_output) > 40: - logger.warning("LLM did not produce json as prompted") - found_answer_end = True - - continue - - if found_answer_start and not found_answer_end: - if is_json_prompt and _stream_json_answer_end(model_previous, token): - found_answer_end = True - yield DanswerAnswerPiece(answer_piece=None) - continue - elif not is_json_prompt: - if quote_pat in hold_quote + token or quote_loose in hold_quote + token: - found_answer_end = True - yield DanswerAnswerPiece(answer_piece=None) - continue - if hold_quote + token in quote_pat_full: - hold_quote += token - continue - yield DanswerAnswerPiece(answer_piece=hold_quote + token) - hold_quote = "" - - logger.debug(f"Raw Model QnA Output: {model_output}") - - yield _extract_quotes_from_completed_token_stream( - model_output=model_output, - context_chunks=context_docs, - is_json_prompt=is_json_prompt, - ) - - def simulate_streaming_response(model_out: str) -> Generator[str, None, None]: """Mock streaming by generating the passed in model output, character by character""" for token in model_out: diff --git a/backend/danswer/search/pipeline.py b/backend/danswer/search/pipeline.py index 972f510db9..5c590939b5 100644 --- a/backend/danswer/search/pipeline.py +++ b/backend/danswer/search/pipeline.py @@ -1,4 +1,5 @@ from collections.abc import Callable +from collections.abc import Generator from typing import cast from sqlalchemy.orm import Session @@ -51,6 +52,11 @@ def __init__( self._reranked_docs: list[InferenceChunk] | None = None self._relevant_chunk_indicies: list[int] | None = None + # generator state + self._postprocessing_generator: Generator[ + list[InferenceChunk] | list[str], None, None + ] | None = None + """Pre-processing""" def _run_preprocessing(self) -> None: @@ -113,36 +119,38 @@ def retrieved_docs(self) -> list[InferenceChunk]: """Post-Processing""" - def _run_postprocessing(self) -> None: - postprocessing_generator = search_postprocessing( - search_query=self.search_query, - retrieved_chunks=self.retrieved_docs, - rerank_metrics_callback=self.rerank_metrics_callback, - ) - self._reranked_docs = cast(list[InferenceChunk], next(postprocessing_generator)) - - relevant_chunk_ids = cast(list[str], next(postprocessing_generator)) - self._relevant_chunk_indicies = [ - ind - for ind, chunk in enumerate(self._reranked_docs) - if chunk.unique_id in relevant_chunk_ids - ] - @property def reranked_docs(self) -> list[InferenceChunk]: if self._reranked_docs is not None: return self._reranked_docs - self._run_postprocessing() - return cast(list[InferenceChunk], self._reranked_docs) + self._postprocessing_generator = search_postprocessing( + search_query=self.search_query, + retrieved_chunks=self.retrieved_docs, + rerank_metrics_callback=self.rerank_metrics_callback, + ) + self._reranked_docs = cast( + list[InferenceChunk], next(self._postprocessing_generator) + ) + return self._reranked_docs @property def relevant_chunk_indicies(self) -> list[int]: if self._relevant_chunk_indicies is not None: return self._relevant_chunk_indicies - self._run_postprocessing() - return cast(list[int], self._relevant_chunk_indicies) + # run first step of postprocessing generator if not already done + reranked_docs = self.reranked_docs + + relevant_chunk_ids = next( + cast(Generator[list[str], None, None], self._postprocessing_generator) + ) + 
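process_model_tokens above gates streaming on spotting the '{"answer":"' prefix and stops at the first un-escaped closing quote. A character-level sketch of that gate — the real code additionally buffers tokens to suppress a trailing "Quotes:" section, which is omitted here:

# Stream only the answer value out of a JSON-shaped token stream.
import re
from collections.abc import Iterator


def stream_answer(tokens: Iterator[str]) -> Iterator[str]:
    buffer = ""
    started = False
    for tok in tokens:
        prev = buffer
        buffer += tok
        if not started:
            # Whitespace-insensitive check for the answer-field prefix.
            if '{"answer":"' in re.sub(r"\s", "", buffer):
                started = True
            continue
        # Drop the first char if the previous text ends in an escape.
        candidate = tok[1:] if prev.endswith("\\") else tok
        # Escaped-quote stripping is a no-op for single-char tokens but
        # matters when tokens span multiple characters.
        if '"' in candidate.replace('\\"', ""):
            return
        yield tok


model_out = '{"answer": "Paris", "quotes": ["Paris is the capital"]}'
print("".join(stream_answer(iter(model_out))))  # Paris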
self._relevant_chunk_indicies = [ + ind + for ind, chunk in enumerate(reranked_docs) + if chunk.unique_id in relevant_chunk_ids + ] + return self._relevant_chunk_indicies @property def chunk_relevance_list(self) -> list[bool]: diff --git a/backend/danswer/search/retrieval/search_runner.py b/backend/danswer/search/retrieval/search_runner.py index 3dff76d96e..41aa3a3c7e 100644 --- a/backend/danswer/search/retrieval/search_runner.py +++ b/backend/danswer/search/retrieval/search_runner.py @@ -223,11 +223,13 @@ def combine_inference_chunks(inf_chunks: list[InferenceChunk]) -> LlmDoc: return LlmDoc( document_id=first_chunk.document_id, content="\n".join(chunk_texts), + blurb=first_chunk.blurb, semantic_identifier=first_chunk.semantic_identifier, source_type=first_chunk.source_type, metadata=first_chunk.metadata, updated_at=first_chunk.updated_at, link=first_chunk.source_links[0] if first_chunk.source_links else None, + source_links=first_chunk.source_links, ) diff --git a/backend/danswer/server/features/persona/api.py b/backend/danswer/server/features/persona/api.py index 8762f40b51..d75ff69480 100644 --- a/backend/danswer/server/features/persona/api.py +++ b/backend/danswer/server/features/persona/api.py @@ -14,8 +14,8 @@ from danswer.db.engine import get_session from danswer.db.models import User from danswer.db.persona import create_update_persona +from danswer.llm.answering.prompts.utils import build_dummy_prompt from danswer.llm.utils import get_default_llm_version -from danswer.one_shot_answer.qa_block import build_dummy_prompt from danswer.server.features.persona.models import CreatePersonaRequest from danswer.server.features.persona.models import PersonaSnapshot from danswer.server.features.persona.models import PromptTemplateResponse diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index a8076659c6..4fb98c5a15 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -6,7 +6,6 @@ from sqlalchemy.orm import Session from danswer.auth.users import current_user -from danswer.chat.chat_utils import compute_max_document_tokens from danswer.chat.chat_utils import create_chat_chain from danswer.chat.process_message import stream_chat_message from danswer.db.chat import create_chat_session @@ -25,6 +24,7 @@ from danswer.db.models import User from danswer.document_index.document_index_utils import get_both_index_names from danswer.document_index.factory import get_default_document_index +from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens from danswer.secondary_llm_flows.chat_session_naming import ( get_renamed_conversation_name, ) diff --git a/backend/tests/regression/answer_quality/eval_direct_qa.py b/backend/tests/regression/answer_quality/eval_direct_qa.py index bd2f70010e..d32f275472 100644 --- a/backend/tests/regression/answer_quality/eval_direct_qa.py +++ b/backend/tests/regression/answer_quality/eval_direct_qa.py @@ -77,7 +77,6 @@ def get_answer_for_question( str | None, RetrievalMetricsContainer | None, RerankMetricsContainer | None, - LLMMetricsContainer | None, ]: filters = IndexFilters( source_type=None, @@ -103,7 +102,6 @@ def get_answer_for_question( retrieval_metrics = MetricsHander[RetrievalMetricsContainer]() rerank_metrics = MetricsHander[RerankMetricsContainer]() - llm_metrics = MetricsHander[LLMMetricsContainer]() answer = get_search_answer( query_req=new_message_request, @@ -116,14 +114,12 @@ def 
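The SearchPipeline diff above replaces an eager _run_postprocessing with a two-step generator: the first next() yields the reranked docs, the second yields the LLM-selected relevant ids, so a caller that only needs reranked docs never pays for the slower relevance filter. A toy sketch of that laziness — all names here are illustrative stand-ins, not the real pipeline API:

# Lazy two-step postprocessing via a generator held as instance state.
from collections.abc import Generator


def postprocess(retrieved: list[str]) -> Generator[list[str], None, None]:
    reranked = sorted(retrieved)  # stand-in for the cross-encoder rerank step
    yield reranked
    # stand-in for the (slower) LLM chunk-relevance filter
    yield [doc for doc in reranked if "postgres" in doc]


class PipelineSketch:
    def __init__(self, retrieved: list[str]) -> None:
        self._retrieved = retrieved
        self._gen: Generator[list[str], None, None] | None = None
        self._reranked: list[str] | None = None
        self._relevant: list[str] | None = None

    @property
    def reranked_docs(self) -> list[str]:
        if self._reranked is None:
            self._gen = postprocess(self._retrieved)
            self._reranked = next(self._gen)
        return self._reranked

    @property
    def relevant_docs(self) -> list[str]:
        if self._relevant is None:
            _ = self.reranked_docs  # ensure step one has run
            assert self._gen is not None
            self._relevant = next(self._gen)
        return self._relevant


p = PipelineSketch(["postgres kv store", "alembic migration"])
print(p.reranked_docs)   # both docs, reranked
print(p.relevant_docs)   # only the "postgres" doc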
get_answer_for_question( bypass_acl=True, retrieval_metrics_callback=retrieval_metrics.record_metric, rerank_metrics_callback=rerank_metrics.record_metric, - llm_metrics_callback=llm_metrics.record_metric, ) return ( answer.answer, retrieval_metrics.metrics, rerank_metrics.metrics, - llm_metrics.metrics, ) @@ -221,7 +217,6 @@ def _print_llm_metrics(metrics_container: LLMMetricsContainer) -> None: answer, retrieval_metrics, rerank_metrics, - llm_metrics, ) = get_answer_for_question(sample["question"], db_session) end_time = datetime.now() @@ -237,12 +232,6 @@ def _print_llm_metrics(metrics_container: LLMMetricsContainer) -> None: else "\tFailed, either crashed or refused to answer." ) if not args.discard_metrics: - print("\nLLM Tokens Usage:") - if llm_metrics is None: - print("No LLM Metrics Available") - else: - _print_llm_metrics(llm_metrics) - print("\nRetrieval Metrics:") if retrieval_metrics is None: print("No Retrieval Metrics Available") diff --git a/backend/tests/regression/search_quality/eval_search.py b/backend/tests/regression/search_quality/eval_search.py index d40ae13480..5bf9406b41 100644 --- a/backend/tests/regression/search_quality/eval_search.py +++ b/backend/tests/regression/search_quality/eval_search.py @@ -7,9 +7,9 @@ from sqlalchemy.orm import Session -from danswer.chat.chat_utils import get_chunks_for_qa from danswer.db.engine import get_sqlalchemy_engine from danswer.indexing.models import InferenceChunk +from danswer.llm.answering.doc_pruning import reorder_docs from danswer.search.models import RerankMetricsContainer from danswer.search.models import RetrievalMetricsContainer from danswer.search.models import SearchRequest @@ -95,16 +95,8 @@ def get_search_results( top_chunks = search_pipeline.reranked_docs llm_chunk_selection = search_pipeline.chunk_relevance_list - llm_chunks_indices = get_chunks_for_qa( - chunks=top_chunks, - llm_chunk_selection=llm_chunk_selection, - token_limit=None, - ) - - llm_chunks = [top_chunks[i] for i in llm_chunks_indices] - return ( - llm_chunks, + reorder_docs(top_chunks, llm_chunk_selection), retrieval_metrics.metrics, rerank_metrics.metrics, ) diff --git a/backend/tests/unit/danswer/direct_qa/test_qa_utils.py b/backend/tests/unit/danswer/direct_qa/test_qa_utils.py index b30d08b169..b7b30b63d2 100644 --- a/backend/tests/unit/danswer/direct_qa/test_qa_utils.py +++ b/backend/tests/unit/danswer/direct_qa/test_qa_utils.py @@ -3,8 +3,12 @@ from danswer.configs.constants import DocumentSource from danswer.indexing.models import InferenceChunk -from danswer.one_shot_answer.qa_utils import match_quotes_to_docs -from danswer.one_shot_answer.qa_utils import separate_answer_quotes +from danswer.llm.answering.stream_processing.quotes_processing import ( + match_quotes_to_docs, +) +from danswer.llm.answering.stream_processing.quotes_processing import ( + separate_answer_quotes, +) class TestQAPostprocessing(unittest.TestCase): From efc7d6e0989551954777534b77602580ae892996 Mon Sep 17 00:00:00 2001 From: Weves Date: Tue, 26 Mar 2024 11:05:04 -0700 Subject: [PATCH 20/58] Add support for Github Flavored Markdown --- web/package-lock.json | 1183 ++++++++++++----- web/package.json | 4 +- web/src/app/chat/message/Messages.tsx | 3 + .../search/results/AnswerSection.tsx | 11 +- 4 files changed, 845 insertions(+), 356 deletions(-) diff --git a/web/package-lock.json b/web/package-lock.json index ef78893802..41b3655089 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -20,6 +20,7 @@ "autoprefixer": "^10.4.14", "formik": "^2.2.9", 
"js-cookie": "^3.0.5", + "mdast-util-find-and-replace": "^3.0.1", "next": "^14.1.0", "postcss": "^8.4.31", "react": "^18.2.0", @@ -27,7 +28,8 @@ "react-dropzone": "^14.2.3", "react-icons": "^4.8.0", "react-loader-spinner": "^5.4.5", - "react-markdown": "^8.0.7", + "react-markdown": "^9.0.1", + "remark-gfm": "^4.0.0", "semver": "^7.5.4", "sharp": "^0.32.6", "swr": "^2.1.5", @@ -1256,12 +1258,25 @@ "@types/ms": "*" } }, + "node_modules/@types/estree": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.5.tgz", + "integrity": "sha512-/kYRxGDLWzHOB7q+wtSUQlFrtcdUccpfy+X+9iMBpHK8QLLhx2wIPYuS5DYtR9Wa/YlZAbIovy7qVdB1Aq6Lyw==" + }, + "node_modules/@types/estree-jsx": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/@types/estree-jsx/-/estree-jsx-1.0.5.tgz", + "integrity": "sha512-52CcUVNFyfb1A2ALocQw/Dd1BQFNmSdkuC3BkZ6iqhdMfQz7JWOFRuJFloOzjk+6WijU56m9oKXFAXc7o3Towg==", + "dependencies": { + "@types/estree": "*" + } + }, "node_modules/@types/hast": { - "version": "2.3.8", - "resolved": "https://registry.npmjs.org/@types/hast/-/hast-2.3.8.tgz", - "integrity": "sha512-aMIqAlFd2wTIDZuvLbhUT+TGvMxrNC8ECUIVtH6xxy0sQLs3iu6NO8Kp/VT5je7i5ufnebXzdV1dNDMnvaH6IQ==", + "version": "3.0.4", + "resolved": "https://registry.npmjs.org/@types/hast/-/hast-3.0.4.tgz", + "integrity": "sha512-WPs+bbQw5aCj+x6laNGWLH3wviHtoCv/P3+otBhbOhJgG8qtpdAMlTCxLtsTWA7LH1Oh/bFCHsBn0TPS5m30EQ==", "dependencies": { - "@types/unist": "^2" + "@types/unist": "*" } }, "node_modules/@types/hoist-non-react-statics": { @@ -1285,11 +1300,11 @@ "dev": true }, "node_modules/@types/mdast": { - "version": "3.0.15", - "resolved": "https://registry.npmjs.org/@types/mdast/-/mdast-3.0.15.tgz", - "integrity": "sha512-LnwD+mUEfxWMa1QpDraczIn6k0Ee3SMicuYSSzS6ZYl2gKS09EClnJYGd8Du6rfc5r/GZEk5o1mRb8TaTj03sQ==", + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/@types/mdast/-/mdast-4.0.3.tgz", + "integrity": "sha512-LsjtqsyF+d2/yFOYaN22dHZI1Cpwkrj+g06G8+qtUKlhovPW89YhqSnfKtMbkgmEtYpH2gydRNULd6y8mciAFg==", "dependencies": { - "@types/unist": "^2" + "@types/unist": "*" } }, "node_modules/@types/ms": { @@ -1331,9 +1346,9 @@ "integrity": "sha512-WZLiwShhwLRmeV6zH+GkbOFT6Z6VklCItrDioxUnv+u4Ll+8vKeFySoFyK/0ctcRpOmwAicELfmys1sDc/Rw+A==" }, "node_modules/@types/unist": { - "version": "2.0.10", - "resolved": "https://registry.npmjs.org/@types/unist/-/unist-2.0.10.tgz", - "integrity": "sha512-IfYcSBWE3hLpBg8+X2SEa8LVkJdJEkT2Ese2aaLs3ptGdVtABxndrMaxuFlQ1qdFf9Q5rDvDpxI3WwgvKFAsQA==" + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/@types/unist/-/unist-3.0.2.tgz", + "integrity": "sha512-dqId9J8K/vGi5Zr7oo212BGii5m3q5Hxlkwy3WpYuKPklmBEvsbMYYyLxAQpSffdLl/gdW0XUpKWFvYmyoWCoQ==" }, "node_modules/@typescript-eslint/parser": { "version": "6.13.1", @@ -1440,8 +1455,7 @@ "node_modules/@ungap/structured-clone": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/@ungap/structured-clone/-/structured-clone-1.2.0.tgz", - "integrity": "sha512-zuVdFrMJiuCDQUMCzQaD6KL28MjnqqN8XnAqiEq9PNm/hCPTSGfrXCOfwj1ow4LFb/tNymJPwsNbVePc1xFqrQ==", - "dev": true + "integrity": "sha512-zuVdFrMJiuCDQUMCzQaD6KL28MjnqqN8XnAqiEq9PNm/hCPTSGfrXCOfwj1ow4LFb/tNymJPwsNbVePc1xFqrQ==" }, "node_modules/acorn": { "version": "8.11.2", @@ -1985,6 +1999,15 @@ } ] }, + "node_modules/ccount": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/ccount/-/ccount-2.0.1.tgz", + "integrity": "sha512-eyrF0jiFpY+3drT6383f1qhkbGsLSifNAjA61IUjZjmLCWjItY6LB9ft9YhoDgwfmclB2zhu51Lc7+95b8NRAg==", + "funding": { + "type": 
"github", + "url": "https://github.com/sponsors/wooorm" + } + }, "node_modules/chalk": { "version": "4.1.2", "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", @@ -2010,6 +2033,33 @@ "url": "https://github.com/sponsors/wooorm" } }, + "node_modules/character-entities-html4": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/character-entities-html4/-/character-entities-html4-2.1.0.tgz", + "integrity": "sha512-1v7fgQRj6hnSwFpq1Eu0ynr/CDEw0rXo2B61qXrLNdHZmPKgb7fqS1a2JwF0rISo9q77jDI8VMEHoApn8qDoZA==", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/character-entities-legacy": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/character-entities-legacy/-/character-entities-legacy-3.0.0.tgz", + "integrity": "sha512-RpPp0asT/6ufRm//AJVwpViZbGM/MkjQFxJccQRHmISF/22NBtsHqAWmL+/pmkPWoIUJdWyeVleTl1wydHATVQ==", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/character-reference-invalid": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/character-reference-invalid/-/character-reference-invalid-2.0.1.tgz", + "integrity": "sha512-iBZ4F4wRbyORVsu0jPV7gXkOsGYjGHPmAyv+HiHG8gi5PtC9KI2j1+v8/tlibRvjoWX027ypmG/n0HtO5t7unw==", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, "node_modules/chokidar": { "version": "3.5.3", "resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.5.3.tgz", @@ -2425,19 +2475,23 @@ "node": ">=8" } }, + "node_modules/devlop": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/devlop/-/devlop-1.1.0.tgz", + "integrity": "sha512-RWmIqhcFf1lRYBvNmr7qTNuyCt/7/ns2jbpp1+PalgE/rDQcBT0fioSMUpJ93irlUhC5hrg4cYqe6U+0ImW0rA==", + "dependencies": { + "dequal": "^2.0.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, "node_modules/didyoumean": { "version": "1.2.2", "resolved": "https://registry.npmjs.org/didyoumean/-/didyoumean-1.2.2.tgz", "integrity": "sha512-gxtyfqMg7GKyhQmb056K7M3xszy/myH8w+B4RT+QXBQsvAOdc3XymqDDPHx1BgPgsdAA5SIifona89YtRATDzw==" }, - "node_modules/diff": { - "version": "5.1.0", - "resolved": "https://registry.npmjs.org/diff/-/diff-5.1.0.tgz", - "integrity": "sha512-D+mk+qE8VC/PAUrlAU34N+VfXev0ghe5ywmpqrawphmVZc1bEfn56uo9qpyGp1p4xpzOHkSW4ztBd6L7Xx4ACw==", - "engines": { - "node": ">=0.3.1" - } - }, "node_modules/dir-glob": { "version": "3.0.1", "resolved": "https://registry.npmjs.org/dir-glob/-/dir-glob-3.0.1.tgz", @@ -3050,6 +3104,15 @@ "node": ">=4.0" } }, + "node_modules/estree-util-is-identifier-name": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/estree-util-is-identifier-name/-/estree-util-is-identifier-name-3.0.0.tgz", + "integrity": "sha512-hFtqIDZTIUZ9BXLb8y4pYGyk6+wekIivNVTcmvk8NoOh+VeRn5y6cEHzbURrWbfp1fIqdVipilzj+lfaadNZmg==", + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, "node_modules/esutils": { "version": "2.0.3", "resolved": "https://registry.npmjs.org/esutils/-/esutils-2.0.3.tgz", @@ -3572,10 +3635,39 @@ "node": ">= 0.4" } }, + "node_modules/hast-util-to-jsx-runtime": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/hast-util-to-jsx-runtime/-/hast-util-to-jsx-runtime-2.3.0.tgz", + "integrity": "sha512-H/y0+IWPdsLLS738P8tDnrQ8Z+dj12zQQ6WC11TIM21C8WFVoIxcqWXf2H3hiTVZjF1AWqoimGwrTWecWrnmRQ==", + "dependencies": { + "@types/estree": "^1.0.0", + "@types/hast": "^3.0.0", + "@types/unist": "^3.0.0", + 
"comma-separated-tokens": "^2.0.0", + "devlop": "^1.0.0", + "estree-util-is-identifier-name": "^3.0.0", + "hast-util-whitespace": "^3.0.0", + "mdast-util-mdx-expression": "^2.0.0", + "mdast-util-mdx-jsx": "^3.0.0", + "mdast-util-mdxjs-esm": "^2.0.0", + "property-information": "^6.0.0", + "space-separated-tokens": "^2.0.0", + "style-to-object": "^1.0.0", + "unist-util-position": "^5.0.0", + "vfile-message": "^4.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, "node_modules/hast-util-whitespace": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/hast-util-whitespace/-/hast-util-whitespace-2.0.1.tgz", - "integrity": "sha512-nAxA0v8+vXSBDt3AnRUNjyRIQ0rD+ntpbAp4LnPkumc5M9yUbSMa4XDU9Q6etY4f1Wp4bNgvc1yjiZtsTTrSng==", + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/hast-util-whitespace/-/hast-util-whitespace-3.0.0.tgz", + "integrity": "sha512-88JUN06ipLwsnv+dVn+OIYOvAuvBMy/Qoi6O7mQHxdPXpjy+Cd6xRkWwux7DKO+4sYILtLBRIKgsdpS2gQc7qw==", + "dependencies": { + "@types/hast": "^3.0.0" + }, "funding": { "type": "opencollective", "url": "https://opencollective.com/unified" @@ -3589,6 +3681,15 @@ "react-is": "^16.7.0" } }, + "node_modules/html-url-attributes": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/html-url-attributes/-/html-url-attributes-3.0.0.tgz", + "integrity": "sha512-/sXbVCWayk6GDVg3ctOX6nxaVj7So40FcFAnWlWGNAB1LpYKcV5Cd10APjPjW80O7zYW2MsjBV4zZ7IZO5fVow==", + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, "node_modules/ieee754": { "version": "1.2.1", "resolved": "https://registry.npmjs.org/ieee754/-/ieee754-1.2.1.tgz", @@ -3662,9 +3763,9 @@ "integrity": "sha512-JV/yugV2uzW5iMRSiZAyDtQd+nxtUnjeLt0acNdw98kKLrvuRVyB80tsREOE7yvGVgalhZ6RNXCmEHkUKBKxew==" }, "node_modules/inline-style-parser": { - "version": "0.1.1", - "resolved": "https://registry.npmjs.org/inline-style-parser/-/inline-style-parser-0.1.1.tgz", - "integrity": "sha512-7NXolsK4CAS5+xvdj5OMMbI962hU/wvwoxk+LWR9Ek9bVtyuuYScDN6eS0rUm6TxApFpw7CX1o4uJzcd4AyD3Q==" + "version": "0.2.2", + "resolved": "https://registry.npmjs.org/inline-style-parser/-/inline-style-parser-0.2.2.tgz", + "integrity": "sha512-EcKzdTHVe8wFVOGEYXiW9WmJXPjqi1T+234YpJr98RiFYKHV3cdy1+3mkTE+KHTHxFFLH51SfaGOoUdW+v7ViQ==" }, "node_modules/internal-slot": { "version": "1.0.6", @@ -3688,6 +3789,28 @@ "node": ">=12" } }, + "node_modules/is-alphabetical": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/is-alphabetical/-/is-alphabetical-2.0.1.tgz", + "integrity": "sha512-FWyyY60MeTNyeSRpkM2Iry0G9hpr7/9kD40mD/cGQEuilcZYS4okz8SN2Q6rLCJ8gbCt6fN+rC+6tMGS99LaxQ==", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/is-alphanumerical": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/is-alphanumerical/-/is-alphanumerical-2.0.1.tgz", + "integrity": "sha512-hmbYhX/9MUMF5uh7tOXyK/n0ZvWpad5caBA17GsC6vyuCqaWliRG5K1qS9inmUhEMaOBIW7/whAnSwveW/LtZw==", + "dependencies": { + "is-alphabetical": "^2.0.0", + "is-decimal": "^2.0.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, "node_modules/is-array-buffer": { "version": "3.0.2", "resolved": "https://registry.npmjs.org/is-array-buffer/-/is-array-buffer-3.0.2.tgz", @@ -3761,28 +3884,6 @@ "url": "https://github.com/sponsors/ljharb" } }, - "node_modules/is-buffer": { - "version": "2.0.5", - "resolved": 
"https://registry.npmjs.org/is-buffer/-/is-buffer-2.0.5.tgz", - "integrity": "sha512-i2R6zNFDwgEHJyQUtJEk0XFi1i0dPFn/oqjK3/vPCcDeJvW5NQ83V8QbicfF1SupOaB0h8ntgBC2YiE7dfyctQ==", - "funding": [ - { - "type": "github", - "url": "https://github.com/sponsors/feross" - }, - { - "type": "patreon", - "url": "https://www.patreon.com/feross" - }, - { - "type": "consulting", - "url": "https://feross.org/support" - } - ], - "engines": { - "node": ">=4" - } - }, "node_modules/is-callable": { "version": "1.2.7", "resolved": "https://registry.npmjs.org/is-callable/-/is-callable-1.2.7.tgz", @@ -3821,6 +3922,15 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/is-decimal": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/is-decimal/-/is-decimal-2.0.1.tgz", + "integrity": "sha512-AAB9hiomQs5DXWcRB1rqsxGUstbRroFOPPVAomNk/3XHR5JyEZChOyTWe2oayKnsSsr/kcGqF+z6yuH6HHpN0A==", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, "node_modules/is-extglob": { "version": "2.1.1", "resolved": "https://registry.npmjs.org/is-extglob/-/is-extglob-2.1.1.tgz", @@ -3876,6 +3986,15 @@ "node": ">=0.10.0" } }, + "node_modules/is-hexadecimal": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/is-hexadecimal/-/is-hexadecimal-2.0.1.tgz", + "integrity": "sha512-DgZQp241c8oO6cA1SbTEWiXeoxV42vlcJxgH+B3hi1AiqqKruZR3ZGF8In3fj4+/y/7rHvlOZLZtgJ/4ttYGZg==", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, "node_modules/is-map": { "version": "2.0.2", "resolved": "https://registry.npmjs.org/is-map/-/is-map-2.0.2.tgz", @@ -4197,14 +4316,6 @@ "json-buffer": "3.0.1" } }, - "node_modules/kleur": { - "version": "4.1.5", - "resolved": "https://registry.npmjs.org/kleur/-/kleur-4.1.5.tgz", - "integrity": "sha512-o+NO+8WrRiQEE4/7nwRJhN1HWpVmJm511pBHUxPLtp0BUISzlBplORYSmTclCnJvQq2tKu/sgl3xVpkc7ZWuQQ==", - "engines": { - "node": ">=6" - } - }, "node_modules/language-subtag-registry": { "version": "0.3.22", "resolved": "https://registry.npmjs.org/language-subtag-registry/-/language-subtag-registry-0.3.22.tgz", @@ -4292,6 +4403,15 @@ "integrity": "sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ==", "dev": true }, + "node_modules/longest-streak": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/longest-streak/-/longest-streak-3.1.0.tgz", + "integrity": "sha512-9Ri+o0JYgehTaVBBDoMqIl8GXtbWg711O3srftcHhZ0dqnETqLaoIK0x17fUw9rFSlK/0NlsKe0Ahhyl5pXE2g==", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, "node_modules/loose-envify": { "version": "1.4.0", "resolved": "https://registry.npmjs.org/loose-envify/-/loose-envify-1.4.0.tgz", @@ -4314,37 +4434,224 @@ "node": ">=10" } }, - "node_modules/mdast-util-definitions": { - "version": "5.1.2", - "resolved": "https://registry.npmjs.org/mdast-util-definitions/-/mdast-util-definitions-5.1.2.tgz", - "integrity": "sha512-8SVPMuHqlPME/z3gqVwWY4zVXn8lqKv/pAhC57FuJ40ImXyBpmO5ukh98zB2v7Blql2FiHjHv9LVztSIqjY+MA==", + "node_modules/markdown-table": { + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/markdown-table/-/markdown-table-3.0.3.tgz", + "integrity": "sha512-Z1NL3Tb1M9wH4XESsCDEksWoKTdlUafKc4pt0GRwjUyXaCFZ+dc3g2erqB6zm3szA2IUSi7VnPI+o/9jnxh9hw==", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/mdast-util-find-and-replace": { + "version": "3.0.1", + "resolved": 
"https://registry.npmjs.org/mdast-util-find-and-replace/-/mdast-util-find-and-replace-3.0.1.tgz", + "integrity": "sha512-SG21kZHGC3XRTSUhtofZkBzZTJNM5ecCi0SK2IMKmSXR8vO3peL+kb1O0z7Zl83jKtutG4k5Wv/W7V3/YHvzPA==", "dependencies": { - "@types/mdast": "^3.0.0", - "@types/unist": "^2.0.0", - "unist-util-visit": "^4.0.0" + "@types/mdast": "^4.0.0", + "escape-string-regexp": "^5.0.0", + "unist-util-is": "^6.0.0", + "unist-util-visit-parents": "^6.0.0" }, "funding": { "type": "opencollective", "url": "https://opencollective.com/unified" } }, + "node_modules/mdast-util-find-and-replace/node_modules/escape-string-regexp": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-5.0.0.tgz", + "integrity": "sha512-/veY75JbMK4j1yjvuUxuVsiS/hr/4iHs9FTT6cgTexxdE0Ly/glccBAkloH/DofkjRbZU3bnoj38mOmhkZ0lHw==", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, "node_modules/mdast-util-from-markdown": { - "version": "1.3.1", - "resolved": "https://registry.npmjs.org/mdast-util-from-markdown/-/mdast-util-from-markdown-1.3.1.tgz", - "integrity": "sha512-4xTO/M8c82qBcnQc1tgpNtubGUW/Y1tBQ1B0i5CtSoelOLKFYlElIr3bvgREYYO5iRqbMY1YuqZng0GVOI8Qww==", + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/mdast-util-from-markdown/-/mdast-util-from-markdown-2.0.0.tgz", + "integrity": "sha512-n7MTOr/z+8NAX/wmhhDji8O3bRvPTV/U0oTCaZJkjhPSKTPhS3xufVhKGF8s1pJ7Ox4QgoIU7KHseh09S+9rTA==", "dependencies": { - "@types/mdast": "^3.0.0", - "@types/unist": "^2.0.0", + "@types/mdast": "^4.0.0", + "@types/unist": "^3.0.0", "decode-named-character-reference": "^1.0.0", - "mdast-util-to-string": "^3.1.0", - "micromark": "^3.0.0", - "micromark-util-decode-numeric-character-reference": "^1.0.0", - "micromark-util-decode-string": "^1.0.0", - "micromark-util-normalize-identifier": "^1.0.0", - "micromark-util-symbol": "^1.0.0", - "micromark-util-types": "^1.0.0", - "unist-util-stringify-position": "^3.0.0", - "uvu": "^0.5.0" + "devlop": "^1.0.0", + "mdast-util-to-string": "^4.0.0", + "micromark": "^4.0.0", + "micromark-util-decode-numeric-character-reference": "^2.0.0", + "micromark-util-decode-string": "^2.0.0", + "micromark-util-normalize-identifier": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0", + "unist-util-stringify-position": "^4.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-gfm": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/mdast-util-gfm/-/mdast-util-gfm-3.0.0.tgz", + "integrity": "sha512-dgQEX5Amaq+DuUqf26jJqSK9qgixgd6rYDHAv4aTBuA92cTknZlKpPfa86Z/s8Dj8xsAQpFfBmPUHWJBWqS4Bw==", + "dependencies": { + "mdast-util-from-markdown": "^2.0.0", + "mdast-util-gfm-autolink-literal": "^2.0.0", + "mdast-util-gfm-footnote": "^2.0.0", + "mdast-util-gfm-strikethrough": "^2.0.0", + "mdast-util-gfm-table": "^2.0.0", + "mdast-util-gfm-task-list-item": "^2.0.0", + "mdast-util-to-markdown": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-gfm-autolink-literal": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/mdast-util-gfm-autolink-literal/-/mdast-util-gfm-autolink-literal-2.0.0.tgz", + "integrity": "sha512-FyzMsduZZHSc3i0Px3PQcBT4WJY/X/RCtEJKuybiC6sjPqLv7h1yqAkmILZtuxMSsUyaLUWNp71+vQH2zqp5cg==", + "dependencies": { + "@types/mdast": "^4.0.0", + "ccount": "^2.0.0", + "devlop": 
"^1.0.0", + "mdast-util-find-and-replace": "^3.0.0", + "micromark-util-character": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-gfm-footnote": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/mdast-util-gfm-footnote/-/mdast-util-gfm-footnote-2.0.0.tgz", + "integrity": "sha512-5jOT2boTSVkMnQ7LTrd6n/18kqwjmuYqo7JUPe+tRCY6O7dAuTFMtTPauYYrMPpox9hlN0uOx/FL8XvEfG9/mQ==", + "dependencies": { + "@types/mdast": "^4.0.0", + "devlop": "^1.1.0", + "mdast-util-from-markdown": "^2.0.0", + "mdast-util-to-markdown": "^2.0.0", + "micromark-util-normalize-identifier": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-gfm-strikethrough": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/mdast-util-gfm-strikethrough/-/mdast-util-gfm-strikethrough-2.0.0.tgz", + "integrity": "sha512-mKKb915TF+OC5ptj5bJ7WFRPdYtuHv0yTRxK2tJvi+BDqbkiG7h7u/9SI89nRAYcmap2xHQL9D+QG/6wSrTtXg==", + "dependencies": { + "@types/mdast": "^4.0.0", + "mdast-util-from-markdown": "^2.0.0", + "mdast-util-to-markdown": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-gfm-table": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/mdast-util-gfm-table/-/mdast-util-gfm-table-2.0.0.tgz", + "integrity": "sha512-78UEvebzz/rJIxLvE7ZtDd/vIQ0RHv+3Mh5DR96p7cS7HsBhYIICDBCu8csTNWNO6tBWfqXPWekRuj2FNOGOZg==", + "dependencies": { + "@types/mdast": "^4.0.0", + "devlop": "^1.0.0", + "markdown-table": "^3.0.0", + "mdast-util-from-markdown": "^2.0.0", + "mdast-util-to-markdown": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-gfm-task-list-item": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/mdast-util-gfm-task-list-item/-/mdast-util-gfm-task-list-item-2.0.0.tgz", + "integrity": "sha512-IrtvNvjxC1o06taBAVJznEnkiHxLFTzgonUdy8hzFVeDun0uTjxxrRGVaNFqkU1wJR3RBPEfsxmU6jDWPofrTQ==", + "dependencies": { + "@types/mdast": "^4.0.0", + "devlop": "^1.0.0", + "mdast-util-from-markdown": "^2.0.0", + "mdast-util-to-markdown": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-mdx-expression": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/mdast-util-mdx-expression/-/mdast-util-mdx-expression-2.0.0.tgz", + "integrity": "sha512-fGCu8eWdKUKNu5mohVGkhBXCXGnOTLuFqOvGMvdikr+J1w7lDJgxThOKpwRWzzbyXAU2hhSwsmssOY4yTokluw==", + "dependencies": { + "@types/estree-jsx": "^1.0.0", + "@types/hast": "^3.0.0", + "@types/mdast": "^4.0.0", + "devlop": "^1.0.0", + "mdast-util-from-markdown": "^2.0.0", + "mdast-util-to-markdown": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-mdx-jsx": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/mdast-util-mdx-jsx/-/mdast-util-mdx-jsx-3.1.2.tgz", + "integrity": "sha512-eKMQDeywY2wlHc97k5eD8VC+9ASMjN8ItEZQNGwJ6E0XWKiW/Z0V5/H8pvoXUf+y+Mj0VIgeRRbujBmFn4FTyA==", + "dependencies": { + "@types/estree-jsx": "^1.0.0", + "@types/hast": "^3.0.0", + "@types/mdast": "^4.0.0", + "@types/unist": "^3.0.0", + "ccount": "^2.0.0", + "devlop": "^1.1.0", + "mdast-util-from-markdown": "^2.0.0", + "mdast-util-to-markdown": "^2.0.0", + 
"parse-entities": "^4.0.0", + "stringify-entities": "^4.0.0", + "unist-util-remove-position": "^5.0.0", + "unist-util-stringify-position": "^4.0.0", + "vfile-message": "^4.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-mdxjs-esm": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/mdast-util-mdxjs-esm/-/mdast-util-mdxjs-esm-2.0.1.tgz", + "integrity": "sha512-EcmOpxsZ96CvlP03NghtH1EsLtr0n9Tm4lPUJUBccV9RwUOneqSycg19n5HGzCf+10LozMRSObtVr3ee1WoHtg==", + "dependencies": { + "@types/estree-jsx": "^1.0.0", + "@types/hast": "^3.0.0", + "@types/mdast": "^4.0.0", + "devlop": "^1.0.0", + "mdast-util-from-markdown": "^2.0.0", + "mdast-util-to-markdown": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-phrasing": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/mdast-util-phrasing/-/mdast-util-phrasing-4.1.0.tgz", + "integrity": "sha512-TqICwyvJJpBwvGAMZjj4J2n0X8QWp21b9l0o7eXyVJ25YNWYbJDVIyD1bZXE6WtV6RmKJVYmQAKWa0zWOABz2w==", + "dependencies": { + "@types/mdast": "^4.0.0", + "unist-util-is": "^6.0.0" }, "funding": { "type": "opencollective", @@ -4352,18 +4659,38 @@ } }, "node_modules/mdast-util-to-hast": { - "version": "12.3.0", - "resolved": "https://registry.npmjs.org/mdast-util-to-hast/-/mdast-util-to-hast-12.3.0.tgz", - "integrity": "sha512-pits93r8PhnIoU4Vy9bjW39M2jJ6/tdHyja9rrot9uujkN7UTU9SDnE6WNJz/IGyQk3XHX6yNNtrBH6cQzm8Hw==", - "dependencies": { - "@types/hast": "^2.0.0", - "@types/mdast": "^3.0.0", - "mdast-util-definitions": "^5.0.0", - "micromark-util-sanitize-uri": "^1.1.0", + "version": "13.1.0", + "resolved": "https://registry.npmjs.org/mdast-util-to-hast/-/mdast-util-to-hast-13.1.0.tgz", + "integrity": "sha512-/e2l/6+OdGp/FB+ctrJ9Avz71AN/GRH3oi/3KAx/kMnoUsD6q0woXlDT8lLEeViVKE7oZxE7RXzvO3T8kF2/sA==", + "dependencies": { + "@types/hast": "^3.0.0", + "@types/mdast": "^4.0.0", + "@ungap/structured-clone": "^1.0.0", + "devlop": "^1.0.0", + "micromark-util-sanitize-uri": "^2.0.0", "trim-lines": "^3.0.0", - "unist-util-generated": "^2.0.0", - "unist-util-position": "^4.0.0", - "unist-util-visit": "^4.0.0" + "unist-util-position": "^5.0.0", + "unist-util-visit": "^5.0.0", + "vfile": "^6.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-to-markdown": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/mdast-util-to-markdown/-/mdast-util-to-markdown-2.1.0.tgz", + "integrity": "sha512-SR2VnIEdVNCJbP6y7kVTJgPLifdr8WEU440fQec7qHoHOUz/oJ2jmNRqdDQ3rbiStOXb2mCDGTuwsK5OPUgYlQ==", + "dependencies": { + "@types/mdast": "^4.0.0", + "@types/unist": "^3.0.0", + "longest-streak": "^3.0.0", + "mdast-util-phrasing": "^4.0.0", + "mdast-util-to-string": "^4.0.0", + "micromark-util-decode-string": "^2.0.0", + "unist-util-visit": "^5.0.0", + "zwitch": "^2.0.0" }, "funding": { "type": "opencollective", @@ -4371,11 +4698,11 @@ } }, "node_modules/mdast-util-to-string": { - "version": "3.2.0", - "resolved": "https://registry.npmjs.org/mdast-util-to-string/-/mdast-util-to-string-3.2.0.tgz", - "integrity": "sha512-V4Zn/ncyN1QNSqSBxTrMOLpjr+IKdHl2v3KVLoWmDPscP4r9GcCi71gjgvUV1SFSKh92AjAG4peFuBl2/YgCJg==", + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/mdast-util-to-string/-/mdast-util-to-string-4.0.0.tgz", + "integrity": 
"sha512-0H44vDimn51F0YwvxSJSm0eCDOJTRlmN0R1yBh4HLj9wiV1Dn0QoXGbvFAWj2hSItVTlCmBF1hqKlIyUBVFLPg==", "dependencies": { - "@types/mdast": "^3.0.0" + "@types/mdast": "^4.0.0" }, "funding": { "type": "opencollective", @@ -4391,9 +4718,9 @@ } }, "node_modules/micromark": { - "version": "3.2.0", - "resolved": "https://registry.npmjs.org/micromark/-/micromark-3.2.0.tgz", - "integrity": "sha512-uD66tJj54JLYq0De10AhWycZWGQNUvDI55xPgk2sQM5kn1JYlhbCMTtEeT27+vAhW2FBQxLlOmS3pmA7/2z4aA==", + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/micromark/-/micromark-4.0.0.tgz", + "integrity": "sha512-o/sd0nMof8kYff+TqcDx3VSrgBTcZpSvYcAHIfHhv5VAuNmisCxjhx6YmxS8PFEpb9z5WKWKPdzf0jM23ro3RQ==", "funding": [ { "type": "GitHub Sponsors", @@ -4408,26 +4735,26 @@ "@types/debug": "^4.0.0", "debug": "^4.0.0", "decode-named-character-reference": "^1.0.0", - "micromark-core-commonmark": "^1.0.1", - "micromark-factory-space": "^1.0.0", - "micromark-util-character": "^1.0.0", - "micromark-util-chunked": "^1.0.0", - "micromark-util-combine-extensions": "^1.0.0", - "micromark-util-decode-numeric-character-reference": "^1.0.0", - "micromark-util-encode": "^1.0.0", - "micromark-util-normalize-identifier": "^1.0.0", - "micromark-util-resolve-all": "^1.0.0", - "micromark-util-sanitize-uri": "^1.0.0", - "micromark-util-subtokenize": "^1.0.0", - "micromark-util-symbol": "^1.0.0", - "micromark-util-types": "^1.0.1", - "uvu": "^0.5.0" + "devlop": "^1.0.0", + "micromark-core-commonmark": "^2.0.0", + "micromark-factory-space": "^2.0.0", + "micromark-util-character": "^2.0.0", + "micromark-util-chunked": "^2.0.0", + "micromark-util-combine-extensions": "^2.0.0", + "micromark-util-decode-numeric-character-reference": "^2.0.0", + "micromark-util-encode": "^2.0.0", + "micromark-util-normalize-identifier": "^2.0.0", + "micromark-util-resolve-all": "^2.0.0", + "micromark-util-sanitize-uri": "^2.0.0", + "micromark-util-subtokenize": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" } }, "node_modules/micromark-core-commonmark": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/micromark-core-commonmark/-/micromark-core-commonmark-1.1.0.tgz", - "integrity": "sha512-BgHO1aRbolh2hcrzL2d1La37V0Aoz73ymF8rAcKnohLy93titmv62E0gP8Hrx9PKcKrqCZ1BbLGbP3bEhoXYlw==", + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/micromark-core-commonmark/-/micromark-core-commonmark-2.0.0.tgz", + "integrity": "sha512-jThOz/pVmAYUtkroV3D5c1osFXAMv9e0ypGDOIZuCeAe91/sD6BoE2Sjzt30yuXtwOYUmySOhMas/PVyh02itA==", "funding": [ { "type": "GitHub Sponsors", @@ -4440,27 +4767,141 @@ ], "dependencies": { "decode-named-character-reference": "^1.0.0", - "micromark-factory-destination": "^1.0.0", - "micromark-factory-label": "^1.0.0", - "micromark-factory-space": "^1.0.0", - "micromark-factory-title": "^1.0.0", - "micromark-factory-whitespace": "^1.0.0", - "micromark-util-character": "^1.0.0", - "micromark-util-chunked": "^1.0.0", - "micromark-util-classify-character": "^1.0.0", - "micromark-util-html-tag-name": "^1.0.0", - "micromark-util-normalize-identifier": "^1.0.0", - "micromark-util-resolve-all": "^1.0.0", - "micromark-util-subtokenize": "^1.0.0", - "micromark-util-symbol": "^1.0.0", - "micromark-util-types": "^1.0.1", - "uvu": "^0.5.0" + "devlop": "^1.0.0", + "micromark-factory-destination": "^2.0.0", + "micromark-factory-label": "^2.0.0", + "micromark-factory-space": "^2.0.0", + "micromark-factory-title": "^2.0.0", + "micromark-factory-whitespace": "^2.0.0", + "micromark-util-character": "^2.0.0", + 
"micromark-util-chunked": "^2.0.0", + "micromark-util-classify-character": "^2.0.0", + "micromark-util-html-tag-name": "^2.0.0", + "micromark-util-normalize-identifier": "^2.0.0", + "micromark-util-resolve-all": "^2.0.0", + "micromark-util-subtokenize": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" + } + }, + "node_modules/micromark-extension-gfm": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/micromark-extension-gfm/-/micromark-extension-gfm-3.0.0.tgz", + "integrity": "sha512-vsKArQsicm7t0z2GugkCKtZehqUm31oeGBV/KVSorWSy8ZlNAv7ytjFhvaryUiCUJYqs+NoE6AFhpQvBTM6Q4w==", + "dependencies": { + "micromark-extension-gfm-autolink-literal": "^2.0.0", + "micromark-extension-gfm-footnote": "^2.0.0", + "micromark-extension-gfm-strikethrough": "^2.0.0", + "micromark-extension-gfm-table": "^2.0.0", + "micromark-extension-gfm-tagfilter": "^2.0.0", + "micromark-extension-gfm-task-list-item": "^2.0.0", + "micromark-util-combine-extensions": "^2.0.0", + "micromark-util-types": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/micromark-extension-gfm-autolink-literal": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/micromark-extension-gfm-autolink-literal/-/micromark-extension-gfm-autolink-literal-2.0.0.tgz", + "integrity": "sha512-rTHfnpt/Q7dEAK1Y5ii0W8bhfJlVJFnJMHIPisfPK3gpVNuOP0VnRl96+YJ3RYWV/P4gFeQoGKNlT3RhuvpqAg==", + "dependencies": { + "micromark-util-character": "^2.0.0", + "micromark-util-sanitize-uri": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/micromark-extension-gfm-footnote": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/micromark-extension-gfm-footnote/-/micromark-extension-gfm-footnote-2.0.0.tgz", + "integrity": "sha512-6Rzu0CYRKDv3BfLAUnZsSlzx3ak6HAoI85KTiijuKIz5UxZxbUI+pD6oHgw+6UtQuiRwnGRhzMmPRv4smcz0fg==", + "dependencies": { + "devlop": "^1.0.0", + "micromark-core-commonmark": "^2.0.0", + "micromark-factory-space": "^2.0.0", + "micromark-util-character": "^2.0.0", + "micromark-util-normalize-identifier": "^2.0.0", + "micromark-util-sanitize-uri": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/micromark-extension-gfm-strikethrough": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/micromark-extension-gfm-strikethrough/-/micromark-extension-gfm-strikethrough-2.0.0.tgz", + "integrity": "sha512-c3BR1ClMp5fxxmwP6AoOY2fXO9U8uFMKs4ADD66ahLTNcwzSCyRVU4k7LPV5Nxo/VJiR4TdzxRQY2v3qIUceCw==", + "dependencies": { + "devlop": "^1.0.0", + "micromark-util-chunked": "^2.0.0", + "micromark-util-classify-character": "^2.0.0", + "micromark-util-resolve-all": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/micromark-extension-gfm-table": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/micromark-extension-gfm-table/-/micromark-extension-gfm-table-2.0.0.tgz", + "integrity": "sha512-PoHlhypg1ItIucOaHmKE8fbin3vTLpDOUg8KAr8gRCF1MOZI9Nquq2i/44wFvviM4WuxJzc3demT8Y3dkfvYrw==", + "dependencies": { + "devlop": "^1.0.0", + "micromark-factory-space": "^2.0.0", + 
"micromark-util-character": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/micromark-extension-gfm-tagfilter": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/micromark-extension-gfm-tagfilter/-/micromark-extension-gfm-tagfilter-2.0.0.tgz", + "integrity": "sha512-xHlTOmuCSotIA8TW1mDIM6X2O1SiX5P9IuDtqGonFhEK0qgRI4yeC6vMxEV2dgyr2TiD+2PQ10o+cOhdVAcwfg==", + "dependencies": { + "micromark-util-types": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/micromark-extension-gfm-task-list-item": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-extension-gfm-task-list-item/-/micromark-extension-gfm-task-list-item-2.0.1.tgz", + "integrity": "sha512-cY5PzGcnULaN5O7T+cOzfMoHjBW7j+T9D2sucA5d/KbsBTPcYdebm9zUd9zzdgJGCwahV+/W78Z3nbulBYVbTw==", + "dependencies": { + "devlop": "^1.0.0", + "micromark-factory-space": "^2.0.0", + "micromark-util-character": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" } }, "node_modules/micromark-factory-destination": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/micromark-factory-destination/-/micromark-factory-destination-1.1.0.tgz", - "integrity": "sha512-XaNDROBgx9SgSChd69pjiGKbV+nfHGDPVYFs5dOoDd7ZnMAE+Cuu91BCpsY8RT2NP9vo/B8pds2VQNCLiu0zhg==", + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/micromark-factory-destination/-/micromark-factory-destination-2.0.0.tgz", + "integrity": "sha512-j9DGrQLm/Uhl2tCzcbLhy5kXsgkHUrjJHg4fFAeoMRwJmJerT9aw4FEhIbZStWN8A3qMwOp1uzHr4UL8AInxtA==", "funding": [ { "type": "GitHub Sponsors", @@ -4472,15 +4913,15 @@ } ], "dependencies": { - "micromark-util-character": "^1.0.0", - "micromark-util-symbol": "^1.0.0", - "micromark-util-types": "^1.0.0" + "micromark-util-character": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" } }, "node_modules/micromark-factory-label": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/micromark-factory-label/-/micromark-factory-label-1.1.0.tgz", - "integrity": "sha512-OLtyez4vZo/1NjxGhcpDSbHQ+m0IIGnT8BoPamh+7jVlzLJBH98zzuCoUeMxvM6WsNeh8wx8cKvqLiPHEACn0w==", + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/micromark-factory-label/-/micromark-factory-label-2.0.0.tgz", + "integrity": "sha512-RR3i96ohZGde//4WSe/dJsxOX6vxIg9TimLAS3i4EhBAFx8Sm5SmqVfR8E87DPSR31nEAjZfbt91OMZWcNgdZw==", "funding": [ { "type": "GitHub Sponsors", @@ -4492,16 +4933,16 @@ } ], "dependencies": { - "micromark-util-character": "^1.0.0", - "micromark-util-symbol": "^1.0.0", - "micromark-util-types": "^1.0.0", - "uvu": "^0.5.0" + "devlop": "^1.0.0", + "micromark-util-character": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" } }, "node_modules/micromark-factory-space": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/micromark-factory-space/-/micromark-factory-space-1.1.0.tgz", - "integrity": "sha512-cRzEj7c0OL4Mw2v6nwzttyOZe8XY/Z8G0rzmWQZTBi/jjwyw/U4uqKtUORXQrR5bAZZnbTI/feRV/R7hc4jQYQ==", + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/micromark-factory-space/-/micromark-factory-space-2.0.0.tgz", + "integrity": 
"sha512-TKr+LIDX2pkBJXFLzpyPyljzYK3MtmllMUMODTQJIUfDGncESaqB90db9IAUcz4AZAJFdd8U9zOp9ty1458rxg==", "funding": [ { "type": "GitHub Sponsors", @@ -4513,14 +4954,14 @@ } ], "dependencies": { - "micromark-util-character": "^1.0.0", - "micromark-util-types": "^1.0.0" + "micromark-util-character": "^2.0.0", + "micromark-util-types": "^2.0.0" } }, "node_modules/micromark-factory-title": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/micromark-factory-title/-/micromark-factory-title-1.1.0.tgz", - "integrity": "sha512-J7n9R3vMmgjDOCY8NPw55jiyaQnH5kBdV2/UXCtZIpnHH3P6nHUKaH7XXEYuWwx/xUJcawa8plLBEjMPU24HzQ==", + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/micromark-factory-title/-/micromark-factory-title-2.0.0.tgz", + "integrity": "sha512-jY8CSxmpWLOxS+t8W+FG3Xigc0RDQA9bKMY/EwILvsesiRniiVMejYTE4wumNc2f4UbAa4WsHqe3J1QS1sli+A==", "funding": [ { "type": "GitHub Sponsors", @@ -4532,16 +4973,16 @@ } ], "dependencies": { - "micromark-factory-space": "^1.0.0", - "micromark-util-character": "^1.0.0", - "micromark-util-symbol": "^1.0.0", - "micromark-util-types": "^1.0.0" + "micromark-factory-space": "^2.0.0", + "micromark-util-character": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" } }, "node_modules/micromark-factory-whitespace": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/micromark-factory-whitespace/-/micromark-factory-whitespace-1.1.0.tgz", - "integrity": "sha512-v2WlmiymVSp5oMg+1Q0N1Lxmt6pMhIHD457whWM7/GUlEks1hI9xj5w3zbc4uuMKXGisksZk8DzP2UyGbGqNsQ==", + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/micromark-factory-whitespace/-/micromark-factory-whitespace-2.0.0.tgz", + "integrity": "sha512-28kbwaBjc5yAI1XadbdPYHX/eDnqaUFVikLwrO7FDnKG7lpgxnvk/XGRhX/PN0mOZ+dBSZ+LgunHS+6tYQAzhA==", "funding": [ { "type": "GitHub Sponsors", @@ -4553,16 +4994,16 @@ } ], "dependencies": { - "micromark-factory-space": "^1.0.0", - "micromark-util-character": "^1.0.0", - "micromark-util-symbol": "^1.0.0", - "micromark-util-types": "^1.0.0" + "micromark-factory-space": "^2.0.0", + "micromark-util-character": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" } }, "node_modules/micromark-util-character": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-1.2.0.tgz", - "integrity": "sha512-lXraTwcX3yH/vMDaFWCQJP1uIszLVebzUa3ZHdrgxr7KEU/9mL4mVgCpGbyhvNLNlauROiNUq7WN5u7ndbY6xg==", + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.1.0.tgz", + "integrity": "sha512-KvOVV+X1yLBfs9dCBSopq/+G1PcgT3lAK07mC4BzXi5E7ahzMAF8oIupDDJ6mievI6F+lAATkbQQlQixJfT3aQ==", "funding": [ { "type": "GitHub Sponsors", @@ -4574,14 +5015,14 @@ } ], "dependencies": { - "micromark-util-symbol": "^1.0.0", - "micromark-util-types": "^1.0.0" + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" } }, "node_modules/micromark-util-chunked": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/micromark-util-chunked/-/micromark-util-chunked-1.1.0.tgz", - "integrity": "sha512-Ye01HXpkZPNcV6FiyoW2fGZDUw4Yc7vT0E9Sad83+bEDiCJ1uXu0S3mr8WLpsz3HaG3x2q0HM6CTuPdcZcluFQ==", + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/micromark-util-chunked/-/micromark-util-chunked-2.0.0.tgz", + "integrity": "sha512-anK8SWmNphkXdaKgz5hJvGa7l00qmcaUQoMYsBwDlSKFKjc6gjGXPDw3FNL3Nbwq5L8gE+RCbGqTw49FK5Qyvg==", "funding": [ { "type": "GitHub Sponsors", @@ -4593,13 +5034,13 @@ } 
], "dependencies": { - "micromark-util-symbol": "^1.0.0" + "micromark-util-symbol": "^2.0.0" } }, "node_modules/micromark-util-classify-character": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/micromark-util-classify-character/-/micromark-util-classify-character-1.1.0.tgz", - "integrity": "sha512-SL0wLxtKSnklKSUplok1WQFoGhUdWYKggKUiqhX+Swala+BtptGCu5iPRc+xvzJ4PXE/hwM3FNXsfEVgoZsWbw==", + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/micromark-util-classify-character/-/micromark-util-classify-character-2.0.0.tgz", + "integrity": "sha512-S0ze2R9GH+fu41FA7pbSqNWObo/kzwf8rN/+IGlW/4tC6oACOs8B++bh+i9bVyNnwCcuksbFwsBme5OCKXCwIw==", "funding": [ { "type": "GitHub Sponsors", @@ -4611,15 +5052,15 @@ } ], "dependencies": { - "micromark-util-character": "^1.0.0", - "micromark-util-symbol": "^1.0.0", - "micromark-util-types": "^1.0.0" + "micromark-util-character": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" } }, "node_modules/micromark-util-combine-extensions": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/micromark-util-combine-extensions/-/micromark-util-combine-extensions-1.1.0.tgz", - "integrity": "sha512-Q20sp4mfNf9yEqDL50WwuWZHUrCO4fEyeDCnMGmG5Pr0Cz15Uo7KBs6jq+dq0EgX4DPwwrh9m0X+zPV1ypFvUA==", + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/micromark-util-combine-extensions/-/micromark-util-combine-extensions-2.0.0.tgz", + "integrity": "sha512-vZZio48k7ON0fVS3CUgFatWHoKbbLTK/rT7pzpJ4Bjp5JjkZeasRfrS9wsBdDJK2cJLHMckXZdzPSSr1B8a4oQ==", "funding": [ { "type": "GitHub Sponsors", @@ -4631,14 +5072,14 @@ } ], "dependencies": { - "micromark-util-chunked": "^1.0.0", - "micromark-util-types": "^1.0.0" + "micromark-util-chunked": "^2.0.0", + "micromark-util-types": "^2.0.0" } }, "node_modules/micromark-util-decode-numeric-character-reference": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/micromark-util-decode-numeric-character-reference/-/micromark-util-decode-numeric-character-reference-1.1.0.tgz", - "integrity": "sha512-m9V0ExGv0jB1OT21mrWcuf4QhP46pH1KkfWy9ZEezqHKAxkj4mPCy3nIH1rkbdMlChLHX531eOrymlwyZIf2iw==", + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-util-decode-numeric-character-reference/-/micromark-util-decode-numeric-character-reference-2.0.1.tgz", + "integrity": "sha512-bmkNc7z8Wn6kgjZmVHOX3SowGmVdhYS7yBpMnuMnPzDq/6xwVA604DuOXMZTO1lvq01g+Adfa0pE2UKGlxL1XQ==", "funding": [ { "type": "GitHub Sponsors", @@ -4650,13 +5091,13 @@ } ], "dependencies": { - "micromark-util-symbol": "^1.0.0" + "micromark-util-symbol": "^2.0.0" } }, "node_modules/micromark-util-decode-string": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/micromark-util-decode-string/-/micromark-util-decode-string-1.1.0.tgz", - "integrity": "sha512-YphLGCK8gM1tG1bd54azwyrQRjCFcmgj2S2GoJDNnh4vYtnL38JS8M4gpxzOPNyHdNEpheyWXCTnnTDY3N+NVQ==", + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/micromark-util-decode-string/-/micromark-util-decode-string-2.0.0.tgz", + "integrity": "sha512-r4Sc6leeUTn3P6gk20aFMj2ntPwn6qpDZqWvYmAG6NgvFTIlj4WtrAudLi65qYoaGdXYViXYw2pkmn7QnIFasA==", "funding": [ { "type": "GitHub Sponsors", @@ -4669,15 +5110,15 @@ ], "dependencies": { "decode-named-character-reference": "^1.0.0", - "micromark-util-character": "^1.0.0", - "micromark-util-decode-numeric-character-reference": "^1.0.0", - "micromark-util-symbol": "^1.0.0" + "micromark-util-character": "^2.0.0", + "micromark-util-decode-numeric-character-reference": "^2.0.0", + 
"micromark-util-symbol": "^2.0.0" } }, "node_modules/micromark-util-encode": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/micromark-util-encode/-/micromark-util-encode-1.1.0.tgz", - "integrity": "sha512-EuEzTWSTAj9PA5GOAs992GzNh2dGQO52UvAbtSOMvXTxv3Criqb6IOzJUBCmEqrrXSblJIJBbFFv6zPxpreiJw==", + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/micromark-util-encode/-/micromark-util-encode-2.0.0.tgz", + "integrity": "sha512-pS+ROfCXAGLWCOc8egcBvT0kf27GoWMqtdarNfDcjb6YLuV5cM3ioG45Ys2qOVqeqSbjaKg72vU+Wby3eddPsA==", "funding": [ { "type": "GitHub Sponsors", @@ -4690,9 +5131,9 @@ ] }, "node_modules/micromark-util-html-tag-name": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/micromark-util-html-tag-name/-/micromark-util-html-tag-name-1.2.0.tgz", - "integrity": "sha512-VTQzcuQgFUD7yYztuQFKXT49KghjtETQ+Wv/zUjGSGBioZnkA4P1XXZPT1FHeJA6RwRXSF47yvJ1tsJdoxwO+Q==", + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/micromark-util-html-tag-name/-/micromark-util-html-tag-name-2.0.0.tgz", + "integrity": "sha512-xNn4Pqkj2puRhKdKTm8t1YHC/BAjx6CEwRFXntTaRf/x16aqka6ouVoutm+QdkISTlT7e2zU7U4ZdlDLJd2Mcw==", "funding": [ { "type": "GitHub Sponsors", @@ -4705,9 +5146,9 @@ ] }, "node_modules/micromark-util-normalize-identifier": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/micromark-util-normalize-identifier/-/micromark-util-normalize-identifier-1.1.0.tgz", - "integrity": "sha512-N+w5vhqrBihhjdpM8+5Xsxy71QWqGn7HYNUvch71iV2PM7+E3uWGox1Qp90loa1ephtCxG2ftRV/Conitc6P2Q==", + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/micromark-util-normalize-identifier/-/micromark-util-normalize-identifier-2.0.0.tgz", + "integrity": "sha512-2xhYT0sfo85FMrUPtHcPo2rrp1lwbDEEzpx7jiH2xXJLqBuy4H0GgXk5ToU8IEwoROtXuL8ND0ttVa4rNqYK3w==", "funding": [ { "type": "GitHub Sponsors", @@ -4719,13 +5160,13 @@ } ], "dependencies": { - "micromark-util-symbol": "^1.0.0" + "micromark-util-symbol": "^2.0.0" } }, "node_modules/micromark-util-resolve-all": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/micromark-util-resolve-all/-/micromark-util-resolve-all-1.1.0.tgz", - "integrity": "sha512-b/G6BTMSg+bX+xVCshPTPyAu2tmA0E4X98NSR7eIbeC6ycCqCeE7wjfDIgzEbkzdEVJXRtOG4FbEm/uGbCRouA==", + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/micromark-util-resolve-all/-/micromark-util-resolve-all-2.0.0.tgz", + "integrity": "sha512-6KU6qO7DZ7GJkaCgwBNtplXCvGkJToU86ybBAUdavvgsCiG8lSSvYxr9MhwmQ+udpzywHsl4RpGJsYWG1pDOcA==", "funding": [ { "type": "GitHub Sponsors", @@ -4737,13 +5178,13 @@ } ], "dependencies": { - "micromark-util-types": "^1.0.0" + "micromark-util-types": "^2.0.0" } }, "node_modules/micromark-util-sanitize-uri": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/micromark-util-sanitize-uri/-/micromark-util-sanitize-uri-1.2.0.tgz", - "integrity": "sha512-QO4GXv0XZfWey4pYFndLUKEAktKkG5kZTdUNaTAkzbuJxn2tNBOr+QtxR2XpWaMhbImT2dPzyLrPXLlPhph34A==", + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/micromark-util-sanitize-uri/-/micromark-util-sanitize-uri-2.0.0.tgz", + "integrity": "sha512-WhYv5UEcZrbAtlsnPuChHUAsu/iBPOVaEVsntLBIdpibO0ddy8OzavZz3iL2xVvBZOpolujSliP65Kq0/7KIYw==", "funding": [ { "type": "GitHub Sponsors", @@ -4755,15 +5196,15 @@ } ], "dependencies": { - "micromark-util-character": "^1.0.0", - "micromark-util-encode": "^1.0.0", - "micromark-util-symbol": "^1.0.0" + "micromark-util-character": "^2.0.0", + "micromark-util-encode": "^2.0.0", + "micromark-util-symbol": "^2.0.0" } }, 
"node_modules/micromark-util-subtokenize": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/micromark-util-subtokenize/-/micromark-util-subtokenize-1.1.0.tgz", - "integrity": "sha512-kUQHyzRoxvZO2PuLzMt2P/dwVsTiivCK8icYTeR+3WgbuPqfHgPPy7nFKbeqRivBvn/3N3GBiNC+JRTMSxEC7A==", + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/micromark-util-subtokenize/-/micromark-util-subtokenize-2.0.0.tgz", + "integrity": "sha512-vc93L1t+gpR3p8jxeVdaYlbV2jTYteDje19rNSS/H5dlhxUYll5Fy6vJ2cDwP8RnsXi818yGty1ayP55y3W6fg==", "funding": [ { "type": "GitHub Sponsors", @@ -4775,16 +5216,16 @@ } ], "dependencies": { - "micromark-util-chunked": "^1.0.0", - "micromark-util-symbol": "^1.0.0", - "micromark-util-types": "^1.0.0", - "uvu": "^0.5.0" + "devlop": "^1.0.0", + "micromark-util-chunked": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" } }, "node_modules/micromark-util-symbol": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/micromark-util-symbol/-/micromark-util-symbol-1.1.0.tgz", - "integrity": "sha512-uEjpEYY6KMs1g7QfJ2eX1SQEV+ZT4rUD3UcF6l57acZvLNK7PBZL+ty82Z1qhK1/yXIY4bdx04FKMgR0g4IAag==", + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/micromark-util-symbol/-/micromark-util-symbol-2.0.0.tgz", + "integrity": "sha512-8JZt9ElZ5kyTnO94muPxIGS8oyElRJaiJO8EzV6ZSyGQ1Is8xwl4Q45qU5UOg+bGH4AikWziz0iN4sFLWs8PGw==", "funding": [ { "type": "GitHub Sponsors", @@ -4797,9 +5238,9 @@ ] }, "node_modules/micromark-util-types": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/micromark-util-types/-/micromark-util-types-1.1.0.tgz", - "integrity": "sha512-ukRBgie8TIAcacscVHSiddHjO4k/q3pnedmzMQ4iwDcK0FtFCohKOlFbaOL/mPgfnPsL3C1ZyxJa4sbWrBl3jg==", + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/micromark-util-types/-/micromark-util-types-2.0.0.tgz", + "integrity": "sha512-oNh6S2WMHWRZrmutsRmDDfkzKtxF+bc2VxLC9dvtrDIRFln627VsFP6fLMgTryGDljgLPjkrzQSDcPrjPyDJ5w==", "funding": [ { "type": "GitHub Sponsors", @@ -4867,14 +5308,6 @@ "resolved": "https://registry.npmjs.org/mkdirp-classic/-/mkdirp-classic-0.5.3.tgz", "integrity": "sha512-gKLcREMhtuZRwRAfqP3RFW+TK4JqApVBtOIftVgjuABpAtpxhPGaDcfvbhNvD0B8iD1oUr/txX35NjcaY6Ns/A==" }, - "node_modules/mri": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/mri/-/mri-1.2.0.tgz", - "integrity": "sha512-tzzskb3bG8LvYGFF/mDTpq3jpI6Q9wc3LEmBaghu+DdCssd1FakN7Bc0hVNmEyGq1bq3RgfkCb3cmQLpNPOroA==", - "engines": { - "node": ">=4" - } - }, "node_modules/ms": { "version": "2.1.2", "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", @@ -5219,6 +5652,30 @@ "node": ">=6" } }, + "node_modules/parse-entities": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/parse-entities/-/parse-entities-4.0.1.tgz", + "integrity": "sha512-SWzvYcSJh4d/SGLIOQfZ/CoNv6BTlI6YEQ7Nj82oDVnRpwe/Z/F1EMx42x3JAOwGBlCjeCH0BRJQbQ/opHL17w==", + "dependencies": { + "@types/unist": "^2.0.0", + "character-entities": "^2.0.0", + "character-entities-legacy": "^3.0.0", + "character-reference-invalid": "^2.0.0", + "decode-named-character-reference": "^1.0.0", + "is-alphanumerical": "^2.0.0", + "is-decimal": "^2.0.0", + "is-hexadecimal": "^2.0.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/parse-entities/node_modules/@types/unist": { + "version": "2.0.10", + "resolved": "https://registry.npmjs.org/@types/unist/-/unist-2.0.10.tgz", + "integrity": 
"sha512-IfYcSBWE3hLpBg8+X2SEa8LVkJdJEkT2Ese2aaLs3ptGdVtABxndrMaxuFlQ1qdFf9Q5rDvDpxI3WwgvKFAsQA==" + }, "node_modules/path-exists": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/path-exists/-/path-exists-4.0.0.tgz", @@ -5558,9 +6015,9 @@ "integrity": "sha512-SVtmxhRE/CGkn3eZY1T6pC8Nln6Fr/lu1mKSgRud0eC73whjGfoAogbn78LkD8aFL0zz3bAFerKSnOl7NlErBA==" }, "node_modules/property-information": { - "version": "6.4.0", - "resolved": "https://registry.npmjs.org/property-information/-/property-information-6.4.0.tgz", - "integrity": "sha512-9t5qARVofg2xQqKtytzt+lZ4d1Qvj8t5B8fEwXK6qOfgRLgH/b13QlgEyDh033NOS31nXeFbYv7CLUDG1CeifQ==", + "version": "6.4.1", + "resolved": "https://registry.npmjs.org/property-information/-/property-information-6.4.1.tgz", + "integrity": "sha512-OHYtXfu5aI2sS2LWFSN5rgJjrQ4pCy8i1jubJLe2QvMF8JJ++HXTUIVWFLfXJoaOfvYYjk2SN8J2wFUWIGXT4w==", "funding": { "type": "github", "url": "https://github.com/sponsors/wooorm" @@ -5725,40 +6182,30 @@ "integrity": "sha512-xWGDIW6x921xtzPkhiULtthJHoJvBbF3q26fzloPCK0hsvxtPVelvftw3zjbHWSkR2km9Z+4uxbDDK/6Zw9B8w==" }, "node_modules/react-markdown": { - "version": "8.0.7", - "resolved": "https://registry.npmjs.org/react-markdown/-/react-markdown-8.0.7.tgz", - "integrity": "sha512-bvWbzG4MtOU62XqBx3Xx+zB2raaFFsq4mYiAzfjXJMEz2sixgeAfraA3tvzULF02ZdOMUOKTBFFaZJDDrq+BJQ==", - "dependencies": { - "@types/hast": "^2.0.0", - "@types/prop-types": "^15.0.0", - "@types/unist": "^2.0.0", - "comma-separated-tokens": "^2.0.0", - "hast-util-whitespace": "^2.0.0", - "prop-types": "^15.0.0", - "property-information": "^6.0.0", - "react-is": "^18.0.0", - "remark-parse": "^10.0.0", - "remark-rehype": "^10.0.0", - "space-separated-tokens": "^2.0.0", - "style-to-object": "^0.4.0", - "unified": "^10.0.0", - "unist-util-visit": "^4.0.0", - "vfile": "^5.0.0" + "version": "9.0.1", + "resolved": "https://registry.npmjs.org/react-markdown/-/react-markdown-9.0.1.tgz", + "integrity": "sha512-186Gw/vF1uRkydbsOIkcGXw7aHq0sZOCRFFjGrr7b9+nVZg4UfA4enXCaxm4fUzecU38sWfrNDitGhshuU7rdg==", + "dependencies": { + "@types/hast": "^3.0.0", + "devlop": "^1.0.0", + "hast-util-to-jsx-runtime": "^2.0.0", + "html-url-attributes": "^3.0.0", + "mdast-util-to-hast": "^13.0.0", + "remark-parse": "^11.0.0", + "remark-rehype": "^11.0.0", + "unified": "^11.0.0", + "unist-util-visit": "^5.0.0", + "vfile": "^6.0.0" }, "funding": { "type": "opencollective", "url": "https://opencollective.com/unified" }, "peerDependencies": { - "@types/react": ">=16", - "react": ">=16" + "@types/react": ">=18", + "react": ">=18" } }, - "node_modules/react-markdown/node_modules/react-is": { - "version": "18.2.0", - "resolved": "https://registry.npmjs.org/react-is/-/react-is-18.2.0.tgz", - "integrity": "sha512-xWGDIW6x921xtzPkhiULtthJHoJvBbF3q26fzloPCK0hsvxtPVelvftw3zjbHWSkR2km9Z+4uxbDDK/6Zw9B8w==" - }, "node_modules/react-smooth": { "version": "2.0.5", "resolved": "https://registry.npmjs.org/react-smooth/-/react-smooth-2.0.5.tgz", @@ -5916,14 +6363,32 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/remark-gfm": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/remark-gfm/-/remark-gfm-4.0.0.tgz", + "integrity": "sha512-U92vJgBPkbw4Zfu/IiW2oTZLSL3Zpv+uI7My2eq8JxKgqraFdU8YUGicEJCEgSbeaG+QDFqIcwwfMTOEelPxuA==", + "dependencies": { + "@types/mdast": "^4.0.0", + "mdast-util-gfm": "^3.0.0", + "micromark-extension-gfm": "^3.0.0", + "remark-parse": "^11.0.0", + "remark-stringify": "^11.0.0", + "unified": "^11.0.0" + }, + "funding": { + "type": "opencollective", + "url": 
"https://opencollective.com/unified" + } + }, "node_modules/remark-parse": { - "version": "10.0.2", - "resolved": "https://registry.npmjs.org/remark-parse/-/remark-parse-10.0.2.tgz", - "integrity": "sha512-3ydxgHa/ZQzG8LvC7jTXccARYDcRld3VfcgIIFs7bI6vbRSxJJmzgLEIIoYKyrfhaY+ujuWaf/PJiMZXoiCXgw==", + "version": "11.0.0", + "resolved": "https://registry.npmjs.org/remark-parse/-/remark-parse-11.0.0.tgz", + "integrity": "sha512-FCxlKLNGknS5ba/1lmpYijMUzX2esxW5xQqjWxw2eHFfS2MSdaHVINFmhjo+qN1WhZhNimq0dZATN9pH0IDrpA==", "dependencies": { - "@types/mdast": "^3.0.0", - "mdast-util-from-markdown": "^1.0.0", - "unified": "^10.0.0" + "@types/mdast": "^4.0.0", + "mdast-util-from-markdown": "^2.0.0", + "micromark-util-types": "^2.0.0", + "unified": "^11.0.0" }, "funding": { "type": "opencollective", @@ -5931,14 +6396,29 @@ } }, "node_modules/remark-rehype": { - "version": "10.1.0", - "resolved": "https://registry.npmjs.org/remark-rehype/-/remark-rehype-10.1.0.tgz", - "integrity": "sha512-EFmR5zppdBp0WQeDVZ/b66CWJipB2q2VLNFMabzDSGR66Z2fQii83G5gTBbgGEnEEA0QRussvrFHxk1HWGJskw==", + "version": "11.1.0", + "resolved": "https://registry.npmjs.org/remark-rehype/-/remark-rehype-11.1.0.tgz", + "integrity": "sha512-z3tJrAs2kIs1AqIIy6pzHmAHlF1hWQ+OdY4/hv+Wxe35EhyLKcajL33iUEn3ScxtFox9nUvRufR/Zre8Q08H/g==", + "dependencies": { + "@types/hast": "^3.0.0", + "@types/mdast": "^4.0.0", + "mdast-util-to-hast": "^13.0.0", + "unified": "^11.0.0", + "vfile": "^6.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/remark-stringify": { + "version": "11.0.0", + "resolved": "https://registry.npmjs.org/remark-stringify/-/remark-stringify-11.0.0.tgz", + "integrity": "sha512-1OSmLd3awB/t8qdoEOMazZkNsfVTeY4fTsgzcQFdXNq8ToTN4ZGwrMnlda4K6smTFKD+GRV6O48i6Z4iKgPPpw==", "dependencies": { - "@types/hast": "^2.0.0", - "@types/mdast": "^3.0.0", - "mdast-util-to-hast": "^12.1.0", - "unified": "^10.0.0" + "@types/mdast": "^4.0.0", + "mdast-util-to-markdown": "^2.0.0", + "unified": "^11.0.0" }, "funding": { "type": "opencollective", @@ -6025,17 +6505,6 @@ "queue-microtask": "^1.2.2" } }, - "node_modules/sade": { - "version": "1.8.1", - "resolved": "https://registry.npmjs.org/sade/-/sade-1.8.1.tgz", - "integrity": "sha512-xal3CZX1Xlo/k4ApwCFrHVACi9fBqJ7V+mwhBsuf/1IOKbBy098Fex+Wa/5QMubw09pSZ/u8EY8PWgevJsXp1A==", - "dependencies": { - "mri": "^1.1.0" - }, - "engines": { - "node": ">=6" - } - }, "node_modules/safe-array-concat": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/safe-array-concat/-/safe-array-concat-1.0.1.tgz", @@ -6444,6 +6913,19 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/stringify-entities": { + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/stringify-entities/-/stringify-entities-4.0.3.tgz", + "integrity": "sha512-BP9nNHMhhfcMbiuQKCqMjhDP5yBCAxsPu4pHFFzJ6Alo9dZgY4VLDPutXqIjpRiMoKdp7Av85Gr73Q5uH9k7+g==", + "dependencies": { + "character-entities-html4": "^2.0.0", + "character-entities-legacy": "^3.0.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, "node_modules/strip-ansi": { "version": "6.0.1", "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", @@ -6491,11 +6973,11 @@ } }, "node_modules/style-to-object": { - "version": "0.4.4", - "resolved": "https://registry.npmjs.org/style-to-object/-/style-to-object-0.4.4.tgz", - "integrity": "sha512-HYNoHZa2GorYNyqiCaBgsxvcJIn7OHq6inEga+E6Ke3m5JkoqpQbnFssk4jwe+K7AhGa2fcha4wSOf1Kn01dMg==", + 
"version": "1.0.5", + "resolved": "https://registry.npmjs.org/style-to-object/-/style-to-object-1.0.5.tgz", + "integrity": "sha512-rDRwHtoDD3UMMrmZ6BzOW0naTjMsVZLIjsGleSKS/0Oz+cgCfAPRspaqJuE8rDzpKha/nEvnM0IF4seEAZUTKQ==", "dependencies": { - "inline-style-parser": "0.1.1" + "inline-style-parser": "0.2.2" } }, "node_modules/styled-components": { @@ -6995,50 +7477,54 @@ } }, "node_modules/unified": { - "version": "10.1.2", - "resolved": "https://registry.npmjs.org/unified/-/unified-10.1.2.tgz", - "integrity": "sha512-pUSWAi/RAnVy1Pif2kAoeWNBa3JVrx0MId2LASj8G+7AiHWoKZNTomq6LG326T68U7/e263X6fTdcXIy7XnF7Q==", + "version": "11.0.4", + "resolved": "https://registry.npmjs.org/unified/-/unified-11.0.4.tgz", + "integrity": "sha512-apMPnyLjAX+ty4OrNap7yumyVAMlKx5IWU2wlzzUdYJO9A8f1p9m/gywF/GM2ZDFcjQPrx59Mc90KwmxsoklxQ==", "dependencies": { - "@types/unist": "^2.0.0", + "@types/unist": "^3.0.0", "bail": "^2.0.0", + "devlop": "^1.0.0", "extend": "^3.0.0", - "is-buffer": "^2.0.0", "is-plain-obj": "^4.0.0", "trough": "^2.0.0", - "vfile": "^5.0.0" + "vfile": "^6.0.0" }, "funding": { "type": "opencollective", "url": "https://opencollective.com/unified" } }, - "node_modules/unist-util-generated": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/unist-util-generated/-/unist-util-generated-2.0.1.tgz", - "integrity": "sha512-qF72kLmPxAw0oN2fwpWIqbXAVyEqUzDHMsbtPvOudIlUzXYFIeQIuxXQCRCFh22B7cixvU0MG7m3MW8FTq/S+A==", + "node_modules/unist-util-is": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/unist-util-is/-/unist-util-is-6.0.0.tgz", + "integrity": "sha512-2qCTHimwdxLfz+YzdGfkqNlH0tLi9xjTnHddPmJwtIG9MGsdbutfTc4P+haPD7l7Cjxf/WZj+we5qfVPvvxfYw==", + "dependencies": { + "@types/unist": "^3.0.0" + }, "funding": { "type": "opencollective", "url": "https://opencollective.com/unified" } }, - "node_modules/unist-util-is": { - "version": "5.2.1", - "resolved": "https://registry.npmjs.org/unist-util-is/-/unist-util-is-5.2.1.tgz", - "integrity": "sha512-u9njyyfEh43npf1M+yGKDGVPbY/JWEemg5nH05ncKPfi+kBbKBJoTdsogMu33uhytuLlv9y0O7GH7fEdwLdLQw==", + "node_modules/unist-util-position": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/unist-util-position/-/unist-util-position-5.0.0.tgz", + "integrity": "sha512-fucsC7HjXvkB5R3kTCO7kUjRdrS0BJt3M/FPxmHMBOm8JQi2BsHAHFsy27E0EolP8rp0NzXsJ+jNPyDWvOJZPA==", "dependencies": { - "@types/unist": "^2.0.0" + "@types/unist": "^3.0.0" }, "funding": { "type": "opencollective", "url": "https://opencollective.com/unified" } }, - "node_modules/unist-util-position": { - "version": "4.0.4", - "resolved": "https://registry.npmjs.org/unist-util-position/-/unist-util-position-4.0.4.tgz", - "integrity": "sha512-kUBE91efOWfIVBo8xzh/uZQ7p9ffYRtUbMRZBNFYwf0RK8koUMx6dGUfwylLOKmaT2cs4wSW96QoYUSXAyEtpg==", + "node_modules/unist-util-remove-position": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/unist-util-remove-position/-/unist-util-remove-position-5.0.0.tgz", + "integrity": "sha512-Hp5Kh3wLxv0PHj9m2yZhhLt58KzPtEYKQQ4yxfYFEO7EvHwzyDYnduhHnY1mDxoqr7VUwVuHXk9RXKIiYS1N8Q==", "dependencies": { - "@types/unist": "^2.0.0" + "@types/unist": "^3.0.0", + "unist-util-visit": "^5.0.0" }, "funding": { "type": "opencollective", @@ -7046,11 +7532,11 @@ } }, "node_modules/unist-util-stringify-position": { - "version": "3.0.3", - "resolved": "https://registry.npmjs.org/unist-util-stringify-position/-/unist-util-stringify-position-3.0.3.tgz", - "integrity": 
"sha512-k5GzIBZ/QatR8N5X2y+drfpWG8IDBzdnVj6OInRNWm1oXrzydiaAT2OQiA8DPRRZyAKb9b6I2a6PxYklZD0gKg==", + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/unist-util-stringify-position/-/unist-util-stringify-position-4.0.0.tgz", + "integrity": "sha512-0ASV06AAoKCDkS2+xw5RXJywruurpbC4JZSm7nr7MOt1ojAzvyyaO+UxZf18j8FCF6kmzCZKcAgN/yu2gm2XgQ==", "dependencies": { - "@types/unist": "^2.0.0" + "@types/unist": "^3.0.0" }, "funding": { "type": "opencollective", @@ -7058,13 +7544,13 @@ } }, "node_modules/unist-util-visit": { - "version": "4.1.2", - "resolved": "https://registry.npmjs.org/unist-util-visit/-/unist-util-visit-4.1.2.tgz", - "integrity": "sha512-MSd8OUGISqHdVvfY9TPhyK2VdUrPgxkUtWSuMHF6XAAFuL4LokseigBnZtPnJMu+FbynTkFNnFlyjxpVKujMRg==", + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/unist-util-visit/-/unist-util-visit-5.0.0.tgz", + "integrity": "sha512-MR04uvD+07cwl/yhVuVWAtw+3GOR/knlL55Nd/wAdblk27GCVt3lqpTivy/tkJcZoNPzTwS1Y+KMojlLDhoTzg==", "dependencies": { - "@types/unist": "^2.0.0", - "unist-util-is": "^5.0.0", - "unist-util-visit-parents": "^5.1.1" + "@types/unist": "^3.0.0", + "unist-util-is": "^6.0.0", + "unist-util-visit-parents": "^6.0.0" }, "funding": { "type": "opencollective", @@ -7072,12 +7558,12 @@ } }, "node_modules/unist-util-visit-parents": { - "version": "5.1.3", - "resolved": "https://registry.npmjs.org/unist-util-visit-parents/-/unist-util-visit-parents-5.1.3.tgz", - "integrity": "sha512-x6+y8g7wWMyQhL1iZfhIPhDAs7Xwbn9nRosDXl7qoPTSCy0yNxnKc+hWokFifWQIDGi154rdUqKvbCa4+1kLhg==", + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/unist-util-visit-parents/-/unist-util-visit-parents-6.0.1.tgz", + "integrity": "sha512-L/PqWzfTP9lzzEa6CKs0k2nARxTdZduw3zyh8d2NVBnsyvHjSX4TWse388YrrQKbvI8w20fGjGlhgT96WwKykw==", "dependencies": { - "@types/unist": "^2.0.0", - "unist-util-is": "^5.0.0" + "@types/unist": "^3.0.0", + "unist-util-is": "^6.0.0" }, "funding": { "type": "opencollective", @@ -7135,32 +7621,14 @@ "resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz", "integrity": "sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==" }, - "node_modules/uvu": { - "version": "0.5.6", - "resolved": "https://registry.npmjs.org/uvu/-/uvu-0.5.6.tgz", - "integrity": "sha512-+g8ENReyr8YsOc6fv/NVJs2vFdHBnBNdfE49rshrTzDWOlUx4Gq7KOS2GD8eqhy2j+Ejq29+SbKH8yjkAqXqoA==", - "dependencies": { - "dequal": "^2.0.0", - "diff": "^5.0.0", - "kleur": "^4.0.3", - "sade": "^1.7.3" - }, - "bin": { - "uvu": "bin.js" - }, - "engines": { - "node": ">=8" - } - }, "node_modules/vfile": { - "version": "5.3.7", - "resolved": "https://registry.npmjs.org/vfile/-/vfile-5.3.7.tgz", - "integrity": "sha512-r7qlzkgErKjobAmyNIkkSpizsFPYiUPuJb5pNW1RB4JcYVZhs4lIbVqk8XPk033CV/1z8ss5pkax8SuhGpcG8g==", + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/vfile/-/vfile-6.0.1.tgz", + "integrity": "sha512-1bYqc7pt6NIADBJ98UiG0Bn/CHIVOoZ/IyEkqIruLg0mE1BKzkOXY2D6CSqQIcKqgadppE5lrxgWXJmXd7zZJw==", "dependencies": { - "@types/unist": "^2.0.0", - "is-buffer": "^2.0.0", - "unist-util-stringify-position": "^3.0.0", - "vfile-message": "^3.0.0" + "@types/unist": "^3.0.0", + "unist-util-stringify-position": "^4.0.0", + "vfile-message": "^4.0.0" }, "funding": { "type": "opencollective", @@ -7168,12 +7636,12 @@ } }, "node_modules/vfile-message": { - "version": "3.1.4", - "resolved": "https://registry.npmjs.org/vfile-message/-/vfile-message-3.1.4.tgz", - "integrity": 
"sha512-fa0Z6P8HUrQN4BZaX05SIVXic+7kE3b05PWAtPuYP9QLHsLKYR7/AlLW3NtOrpXRLeawpDLMsVkmk5DG0NXgWw==", + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/vfile-message/-/vfile-message-4.0.2.tgz", + "integrity": "sha512-jRDZ1IMLttGj41KcZvlrYAaI3CfqpLpfpf+Mfig13viT6NKvRzWZ+lXz0Y5D60w6uJIBAOGq9mSHf0gktF0duw==", "dependencies": { - "@types/unist": "^2.0.0", - "unist-util-stringify-position": "^3.0.0" + "@types/unist": "^3.0.0", + "unist-util-stringify-position": "^4.0.0" }, "funding": { "type": "opencollective", @@ -7437,6 +7905,15 @@ "funding": { "url": "https://github.com/sponsors/sindresorhus" } + }, + "node_modules/zwitch": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/zwitch/-/zwitch-2.0.4.tgz", + "integrity": "sha512-bXE4cR/kVZhKZX/RjPEflHaKVhUVl85noU3v6b8apfQEc1x4A+zBxjZ4lN8LqGd6WZ3dl98pY4o717VFmoPp+A==", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } } } } diff --git a/web/package.json b/web/package.json index 7089d2bf3d..4d3fc3794b 100644 --- a/web/package.json +++ b/web/package.json @@ -21,6 +21,7 @@ "autoprefixer": "^10.4.14", "formik": "^2.2.9", "js-cookie": "^3.0.5", + "mdast-util-find-and-replace": "^3.0.1", "next": "^14.1.0", "postcss": "^8.4.31", "react": "^18.2.0", @@ -28,7 +29,8 @@ "react-dropzone": "^14.2.3", "react-icons": "^4.8.0", "react-loader-spinner": "^5.4.5", - "react-markdown": "^8.0.7", + "react-markdown": "^9.0.1", + "remark-gfm": "^4.0.0", "semver": "^7.5.4", "sharp": "^0.32.6", "swr": "^2.1.5", diff --git a/web/src/app/chat/message/Messages.tsx b/web/src/app/chat/message/Messages.tsx index 514451f389..a4ab222d93 100644 --- a/web/src/app/chat/message/Messages.tsx +++ b/web/src/app/chat/message/Messages.tsx @@ -14,6 +14,7 @@ import { SearchSummary, ShowHideDocsButton } from "./SearchSummary"; import { SourceIcon } from "@/components/SourceIcon"; import { ThreeDots } from "react-loader-spinner"; import { SkippedSearch } from "./SkippedSearch"; +import remarkGfm from "remark-gfm"; export const Hoverable: React.FC<{ children: JSX.Element; @@ -132,6 +133,7 @@ export const AIMessage = ({ /> ), }} + remarkPlugins={[remarkGfm]} > {content} @@ -255,6 +257,7 @@ export const HumanMessage = ({ /> ), }} + remarkPlugins={[remarkGfm]} > {content} diff --git a/web/src/components/search/results/AnswerSection.tsx b/web/src/components/search/results/AnswerSection.tsx index db9d6ae05c..08ce5c6bfb 100644 --- a/web/src/components/search/results/AnswerSection.tsx +++ b/web/src/components/search/results/AnswerSection.tsx @@ -1,6 +1,7 @@ import { Quote } from "@/lib/search/interfaces"; import { ResponseSection, StatusOptions } from "./ResponseSection"; import ReactMarkdown from "react-markdown"; +import remarkGfm from "remark-gfm"; const TEMP_STRING = "__$%^TEMP$%^__"; @@ -40,7 +41,10 @@ export const AnswerSection = (props: AnswerSectionProps) => { header = <>AI answer; if (props.answer) { body = ( - + {replaceNewlines(props.answer)} ); @@ -62,7 +66,10 @@ export const AnswerSection = (props: AnswerSectionProps) => { status = "success"; header = <>AI answer; body = ( - + {replaceNewlines(props.answer)} ); From fbff5b5784ac8ea0579e8c5aa4fcd3a03dd8a28c Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Tue, 26 Mar 2024 22:48:40 -0700 Subject: [PATCH 21/58] Save Retrieved Docs for One Shot Flows (#1259) --- backend/danswer/one_shot_answer/answer_question.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/backend/danswer/one_shot_answer/answer_question.py 
b/backend/danswer/one_shot_answer/answer_question.py index e863f4ac09..320027c54b 100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -16,6 +16,7 @@ from danswer.configs.chat_configs import QA_TIMEOUT from danswer.configs.constants import MessageType from danswer.db.chat import create_chat_session +from danswer.db.chat import create_db_search_doc from danswer.db.chat import create_new_chat_message from danswer.db.chat import get_or_create_root_message from danswer.db.chat import get_prompt_by_id @@ -202,6 +203,11 @@ def stream_answer_objects( ) yield from answer.processed_streamed_output + reference_db_search_docs = [ + create_db_search_doc(server_search_doc=top_doc, db_session=db_session) + for top_doc in top_docs + ] + # Saving Gen AI answer and responding with message info gen_ai_response_message = create_new_chat_message( chat_session_id=chat_session.id, @@ -211,7 +217,7 @@ token_count=len(llm_tokenizer(answer.llm_answer)), message_type=MessageType.ASSISTANT, error=None, - reference_docs=None, # Don't need to save reference docs for one shot flow + reference_docs=reference_db_search_docs, db_session=db_session, commit=True, ) From 5a967322fd69798c9ffe99655f913ab860b82329 Mon Sep 17 00:00:00 2001 From: Weves Date: Tue, 26 Mar 2024 23:12:04 -0700 Subject: [PATCH 22/58] Add ability to specify custom embedding models --- backend/danswer/background/update.py | 12 +- backend/danswer/db/index_attempt.py | 9 +- backend/danswer/indexing/models.py | 14 +++ backend/danswer/server/manage/models.py | 9 +- .../danswer/server/manage/secondary_index.py | 19 +-- .../models/embedding/CustomModelForm.tsx | 116 ++++++++++++++++++ .../embedding/ModelSelectionConfirmation.tsx | 28 ++++- .../admin/models/embedding/ModelSelector.tsx | 38 +++--- .../embedding/ReindexingProgressTable.tsx | 14 ++- .../admin/models/embedding/embeddingModels.ts | 9 ++ web/src/app/admin/models/embedding/page.tsx | 99 ++++++++++----- 11 files changed, 289 insertions(+), 78 deletions(-) create mode 100644 web/src/app/admin/models/embedding/CustomModelForm.tsx diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py index b77ddee859..a7a20fca30 100755 --- a/backend/danswer/background/update.py +++ b/backend/danswer/background/update.py @@ -29,7 +29,9 @@ from danswer.db.engine import get_db_current_time from danswer.db.engine import get_sqlalchemy_engine from danswer.db.index_attempt import cancel_indexing_attempts_past_model -from danswer.db.index_attempt import count_unique_cc_pairs_with_index_attempts +from danswer.db.index_attempt import ( + count_unique_cc_pairs_with_successful_index_attempts, +) from danswer.db.index_attempt import create_index_attempt from danswer.db.index_attempt import get_index_attempt from danswer.db.index_attempt import get_inprogress_index_attempts @@ -365,9 +367,9 @@ def kickoff_indexing_jobs( def check_index_swap(db_session: Session) -> None: - """Get count of cc-pairs and count of index_attempts for the new model grouped by - connector + credential, if it's the same, then assume new index is done building. - This does not take into consideration if the attempt failed or not""" + """Get count of cc-pairs and count of successful index_attempts for the + new model, grouped by connector + credential; if the counts match, assume the + new index is done building.
If so, swap the indices and expire the old one.""" # Default CC-pair created for Ingestion API unused here all_cc_pairs = get_connector_credential_pairs(db_session) cc_pair_count = len(all_cc_pairs) - 1 @@ -376,7 +378,7 @@ def check_index_swap(db_session: Session) -> None: if not embedding_model: return - unique_cc_indexings = count_unique_cc_pairs_with_index_attempts( + unique_cc_indexings = count_unique_cc_pairs_with_successful_index_attempts( embedding_model_id=embedding_model.id, db_session=db_session ) diff --git a/backend/danswer/db/index_attempt.py b/backend/danswer/db/index_attempt.py index ce913098eb..4580140a5f 100644 --- a/backend/danswer/db/index_attempt.py +++ b/backend/danswer/db/index_attempt.py @@ -291,7 +291,7 @@ def cancel_indexing_attempts_past_model( db_session.commit() -def count_unique_cc_pairs_with_index_attempts( +def count_unique_cc_pairs_with_successful_index_attempts( embedding_model_id: int | None, db_session: Session, ) -> int: @@ -299,12 +299,7 @@ def count_unique_cc_pairs_with_index_attempts( db_session.query(IndexAttempt.connector_id, IndexAttempt.credential_id) .filter( IndexAttempt.embedding_model_id == embedding_model_id, - # Should not be able to hang since indexing jobs expire after a limit - # It will then be marked failed, and the next cycle it will be in a completed state - or_( - IndexAttempt.status == IndexingStatus.SUCCESS, - IndexAttempt.status == IndexingStatus.FAILED, - ), + IndexAttempt.status == IndexingStatus.SUCCESS, ) .distinct() .count() diff --git a/backend/danswer/indexing/models.py b/backend/danswer/indexing/models.py index c875c88bdd..68f9e3886a 100644 --- a/backend/danswer/indexing/models.py +++ b/backend/danswer/indexing/models.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from dataclasses import fields from datetime import datetime +from typing import TYPE_CHECKING from pydantic import BaseModel @@ -9,6 +10,9 @@ from danswer.connectors.models import Document from danswer.utils.logger import setup_logger +if TYPE_CHECKING: + from danswer.db.models import EmbeddingModel + logger = setup_logger() @@ -130,3 +134,13 @@ class EmbeddingModelDetail(BaseModel): normalize: bool query_prefix: str | None passage_prefix: str | None + + @classmethod + def from_model(cls, embedding_model: "EmbeddingModel") -> "EmbeddingModelDetail": + return cls( + model_name=embedding_model.model_name, + model_dim=embedding_model.model_dim, + normalize=embedding_model.normalize, + query_prefix=embedding_model.query_prefix, + passage_prefix=embedding_model.passage_prefix, + ) diff --git a/backend/danswer/server/manage/models.py b/backend/danswer/server/manage/models.py index a2ea4c7ab6..8857ffc55b 100644 --- a/backend/danswer/server/manage/models.py +++ b/backend/danswer/server/manage/models.py @@ -11,6 +11,7 @@ from danswer.db.models import ChannelConfig from danswer.db.models import SlackBotConfig as SlackBotConfigModel from danswer.db.models import SlackBotResponseType +from danswer.indexing.models import EmbeddingModelDetail from danswer.server.features.persona.models import PersonaSnapshot @@ -125,10 +126,6 @@ def from_model( ) -class ModelVersionResponse(BaseModel): - model_name: str | None # None only applicable to secondary index - - class FullModelVersionResponse(BaseModel): - current_model_name: str - secondary_model_name: str | None + current_model: EmbeddingModelDetail + secondary_model: EmbeddingModelDetail | None diff --git a/backend/danswer/server/manage/secondary_index.py b/backend/danswer/server/manage/secondary_index.py index 
c4c51c0e30..6f5adf752f 100644 --- a/backend/danswer/server/manage/secondary_index.py +++ b/backend/danswer/server/manage/secondary_index.py @@ -20,7 +20,6 @@ from danswer.document_index.factory import get_default_document_index from danswer.indexing.models import EmbeddingModelDetail from danswer.server.manage.models import FullModelVersionResponse -from danswer.server.manage.models import ModelVersionResponse from danswer.server.models import IdReturn from danswer.utils.logger import setup_logger @@ -115,21 +114,21 @@ def cancel_new_embedding( def get_current_embedding_model( _: User | None = Depends(current_user), db_session: Session = Depends(get_session), -) -> ModelVersionResponse: +) -> EmbeddingModelDetail: current_model = get_current_db_embedding_model(db_session) - return ModelVersionResponse(model_name=current_model.model_name) + return EmbeddingModelDetail.from_model(current_model) @router.get("/get-secondary-embedding-model") def get_secondary_embedding_model( _: User | None = Depends(current_user), db_session: Session = Depends(get_session), -) -> ModelVersionResponse: +) -> EmbeddingModelDetail | None: next_model = get_secondary_db_embedding_model(db_session) + if not next_model: + return None - return ModelVersionResponse( - model_name=next_model.model_name if next_model else None - ) + return EmbeddingModelDetail.from_model(next_model) @router.get("/get-embedding-models") @@ -140,6 +139,8 @@ def get_embedding_models( current_model = get_current_db_embedding_model(db_session) next_model = get_secondary_db_embedding_model(db_session) return FullModelVersionResponse( - current_model_name=current_model.model_name, - secondary_model_name=next_model.model_name if next_model else None, + current_model=EmbeddingModelDetail.from_model(current_model), + secondary_model=EmbeddingModelDetail.from_model(next_model) + if next_model + else None, ) diff --git a/web/src/app/admin/models/embedding/CustomModelForm.tsx b/web/src/app/admin/models/embedding/CustomModelForm.tsx new file mode 100644 index 0000000000..23676bc61b --- /dev/null +++ b/web/src/app/admin/models/embedding/CustomModelForm.tsx @@ -0,0 +1,116 @@ +import { + BooleanFormField, + TextFormField, +} from "@/components/admin/connectors/Field"; +import { Button, Divider, Text } from "@tremor/react"; +import { Form, Formik } from "formik"; + +import * as Yup from "yup"; +import { EmbeddingModelDescriptor } from "./embeddingModels"; + +export function CustomModelForm({ + onSubmit, +}: { + onSubmit: (model: EmbeddingModelDescriptor) => void; +}) { + return ( +
+ { + onSubmit({ ...values, model_dim: parseInt(values.model_dim) }); + }} + > + {({ isSubmitting, setFieldValue }) => ( +
+ + + { + const value = e.target.value; + // Allow only integer values + if (value === "" || /^[0-9]+$/.test(value)) { + setFieldValue("model_dim", value); + } + }} + /> + + + The prefix specified by the model creators which should be + prepended to queries before passing them to the model. + Many models do not have this, in which case this should be + left empty. + + } + placeholder="E.g. 'query: '" + autoCompleteDisabled={true} + /> + + + The prefix specified by the model creators which should be + prepended to passages before passing them to the model. + Many models do not have this, in which case this should be + left empty. + + } + placeholder="E.g. 'passage: '" + autoCompleteDisabled={true} + /> + + + +
+ +
+ + )} +
+
+ ); +} diff --git a/web/src/app/admin/models/embedding/ModelSelectionConfirmation.tsx b/web/src/app/admin/models/embedding/ModelSelectionConfirmation.tsx index 949c5d46da..7572ac2ce8 100644 --- a/web/src/app/admin/models/embedding/ModelSelectionConfirmation.tsx +++ b/web/src/app/admin/models/embedding/ModelSelectionConfirmation.tsx @@ -1,18 +1,21 @@ import { Modal } from "@/components/Modal"; -import { Button, Text } from "@tremor/react"; +import { Button, Text, Callout } from "@tremor/react"; +import { EmbeddingModelDescriptor } from "./embeddingModels"; export function ModelSelectionConfirmaion({ selectedModel, + isCustom, onConfirm, }: { - selectedModel: string; + selectedModel: EmbeddingModelDescriptor; + isCustom: boolean; onConfirm: () => void; }) { return (
- You have selected: {selectedModel}. Are you sure you want to - update to this new embedding model? + You have selected: {selectedModel.model_name}. Are you sure you + want to update to this new embedding model? We will re-index all your documents in the background so you will be @@ -25,6 +28,18 @@ normal. If you are self-hosting, we recommend that you allocate at least 16GB of RAM to Danswer during this process. + + {isCustom && ( + + We've detected that this is a custom-specified embedding model. + Since we have to download the model files before verifying the + configuration's correctness, we won't be able to let you + know if the configuration is valid until after we start + re-indexing your documents. If there is an issue, it will show up + as an indexing error on this page after clicking Confirm. + + )} +
@@ -61,17 +69,19 @@ export function ModelSelector({ setSelectedModel, }: { modelOptions: FullEmbeddingModelDescriptor[]; - setSelectedModel: (modelName: string) => void; + setSelectedModel: (model: EmbeddingModelDescriptor) => void; }) { return ( -
- {modelOptions.map((modelOption) => ( - - ))} +
+
+ {modelOptions.map((modelOption) => ( + + ))} +
); } diff --git a/web/src/app/admin/models/embedding/ReindexingProgressTable.tsx b/web/src/app/admin/models/embedding/ReindexingProgressTable.tsx index 3b366c1922..b1f91d24bb 100644 --- a/web/src/app/admin/models/embedding/ReindexingProgressTable.tsx +++ b/web/src/app/admin/models/embedding/ReindexingProgressTable.tsx @@ -1,14 +1,14 @@ import { PageSelector } from "@/components/PageSelector"; -import { CCPairStatus, IndexAttemptStatus } from "@/components/Status"; -import { ConnectorIndexingStatus, ValidStatuses } from "@/lib/types"; +import { IndexAttemptStatus } from "@/components/Status"; +import { ConnectorIndexingStatus } from "@/lib/types"; import { - Button, Table, TableBody, TableCell, TableHead, TableHeaderCell, TableRow, + Text, } from "@tremor/react"; import Link from "next/link"; import { useState } from "react"; @@ -30,6 +30,7 @@ export function ReindexingProgressTable({ Connector Name Status Docs Re-Indexed + Error Message @@ -58,6 +59,13 @@ export function ReindexingProgressTable({ {reindexingProgress?.latest_index_attempt ?.total_docs_indexed || "-"} + +
+ + {reindexingProgress.error_msg || "-"} + +
+
); })} diff --git a/web/src/app/admin/models/embedding/embeddingModels.ts b/web/src/app/admin/models/embedding/embeddingModels.ts index 64ccfff958..7c5d09180f 100644 --- a/web/src/app/admin/models/embedding/embeddingModels.ts +++ b/web/src/app/admin/models/embedding/embeddingModels.ts @@ -76,3 +76,12 @@ export function checkModelNameIsValid(modelName: string | undefined | null) { } return true; } + +export function fillOutEmeddingModelDescriptor( + embeddingModel: EmbeddingModelDescriptor | FullEmbeddingModelDescriptor +): FullEmbeddingModelDescriptor { + return { + ...embeddingModel, + description: "", + }; +} diff --git a/web/src/app/admin/models/embedding/page.tsx b/web/src/app/admin/models/embedding/page.tsx index 5f4cd1c93d..0612fe2c62 100644 --- a/web/src/app/admin/models/embedding/page.tsx +++ b/web/src/app/admin/models/embedding/page.tsx @@ -6,7 +6,7 @@ import { KeyIcon, TrashIcon } from "@/components/icons/icons"; import { ApiKeyForm } from "@/components/openai/ApiKeyForm"; import { GEN_AI_API_KEY_URL } from "@/components/openai/constants"; import { errorHandlingFetcher, fetcher } from "@/lib/fetcher"; -import { Button, Divider, Text, Title } from "@tremor/react"; +import { Button, Card, Divider, Text, Title } from "@tremor/react"; import { FiCpu, FiPackage } from "react-icons/fi"; import useSWR, { mutate } from "swr"; import { ModelOption, ModelSelector } from "./ModelSelector"; @@ -16,17 +16,18 @@ import { ReindexingProgressTable } from "./ReindexingProgressTable"; import { Modal } from "@/components/Modal"; import { AVAILABLE_MODELS, - EmbeddingModelResponse, + EmbeddingModelDescriptor, INVALID_OLD_MODEL, + fillOutEmeddingModelDescriptor, } from "./embeddingModels"; import { ErrorCallout } from "@/components/ErrorCallout"; import { Connector, ConnectorIndexingStatus } from "@/lib/types"; import Link from "next/link"; +import { CustomModelForm } from "./CustomModelForm"; function Main() { - const [tentativeNewEmbeddingModel, setTentativeNewEmbeddingModel] = useState< - string | null - >(null); + const [tentativeNewEmbeddingModel, setTentativeNewEmbeddingModel] = + useState(null); const [isCancelling, setIsCancelling] = useState(false); const [showAddConnectorPopup, setShowAddConnectorPopup] = useState(false); @@ -35,16 +36,16 @@ function Main() { data: currentEmeddingModel, isLoading: isLoadingCurrentModel, error: currentEmeddingModelError, - } = useSWR( + } = useSWR( "/api/secondary-index/get-current-embedding-model", errorHandlingFetcher, { refreshInterval: 5000 } // 5 seconds ); const { - data: futureEmeddingModel, + data: futureEmbeddingModel, isLoading: isLoadingFutureModel, error: futureEmeddingModelError, - } = useSWR( + } = useSWR( "/api/secondary-index/get-secondary-embedding-model", errorHandlingFetcher, { refreshInterval: 5000 } // 5 seconds @@ -63,24 +64,20 @@ function Main() { { refreshInterval: 5000 } // 5 seconds ); - const onSelect = async (modelName: string) => { + const onSelect = async (model: EmbeddingModelDescriptor) => { if (currentEmeddingModel?.model_name === INVALID_OLD_MODEL) { - await onConfirm(modelName); + await onConfirm(model); } else { - setTentativeNewEmbeddingModel(modelName); + setTentativeNewEmbeddingModel(model); } }; - const onConfirm = async (modelName: string) => { - const modelDescriptor = AVAILABLE_MODELS.find( - (model) => model.model_name === modelName - ); - + const onConfirm = async (model: EmbeddingModelDescriptor) => { const response = await fetch( "/api/secondary-index/set-new-embedding-model", { method: "POST", - body: 
JSON.stringify(modelDescriptor), + body: JSON.stringify(model), headers: { "Content-Type": "application/json", }, @@ -120,26 +117,33 @@ function Main() { if ( currentEmeddingModelError || !currentEmeddingModel || - futureEmeddingModelError || - !futureEmeddingModel + futureEmeddingModelError ) { return ; } const currentModelName = currentEmeddingModel.model_name; - const currentModel = AVAILABLE_MODELS.find( - (model) => model.model_name === currentModelName - ); + const currentModel = + AVAILABLE_MODELS.find((model) => model.model_name === currentModelName) || + fillOutEmeddingModelDescriptor(currentEmeddingModel); - const newModelSelection = AVAILABLE_MODELS.find( - (model) => model.model_name === futureEmeddingModel.model_name - ); + const newModelSelection = futureEmbeddingModel + ? AVAILABLE_MODELS.find( + (model) => model.model_name === futureEmbeddingModel.model_name + ) || fillOutEmeddingModelDescriptor(futureEmbeddingModel) + : null; return (
{tentativeNewEmbeddingModel && ( + model.model_name === tentativeNewEmbeddingModel.model_name + ) === undefined + } onConfirm={() => onConfirm(tentativeNewEmbeddingModel)} onCancel={() => setTentativeNewEmbeddingModel(null)} /> @@ -243,12 +247,49 @@ function Main() { )} + + Below are a curated selection of quality models that we recommend + you choose from. + + modelOption.model_name !== currentModelName )} setSelectedModel={onSelect} /> + + + Alternatively, (if you know what you're doing) you can + specify a{" "} + + SentenceTransformers + + -compatible model of your choice below. The rough list of + supported models can be found{" "} + + here + + . +
+ NOTE: not all models listed will work with Danswer, since + some have unique interfaces or special requirements. If in doubt, + reach out to the Danswer team. + + +
+ + + +
) : ( connectors && @@ -272,10 +313,10 @@ function Main() { The table below shows the re-indexing progress of all existing - connectors. Once all connectors have been re-indexed, the new - model will be used for all search queries. Until then, we will - use the old model so that no downtime is necessary during this - transition. + connectors. Once all connectors have been re-indexed + successfully, the new model will be used for all search + queries. Until then, we will use the old model so that no + downtime is necessary during this transition. {isLoadingOngoingReIndexingStatus ? ( From 9757fbee901fb15e5f647719b4f3b8a57640a439 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Wed, 27 Mar 2024 11:12:01 -0700 Subject: [PATCH 23/58] Axero Connector (#1253) --------- Co-authored-by: Weves --- backend/danswer/configs/constants.py | 1 + backend/danswer/connectors/axero/__init__.py | 0 backend/danswer/connectors/axero/connector.py | 186 ++++++++++++++ .../miscellaneous_utils.py | 14 + backend/danswer/connectors/factory.py | 2 + web/public/Axero.jpeg | Bin 0 -> 7977 bytes web/src/app/admin/connectors/axero/page.tsx | 240 ++++++++++++++++++ web/src/components/icons/icons.tsx | 12 + web/src/lib/sources.ts | 6 + web/src/lib/types.ts | 7 +- 10 files changed, 467 insertions(+), 1 deletion(-) create mode 100644 backend/danswer/connectors/axero/__init__.py create mode 100644 backend/danswer/connectors/axero/connector.py create mode 100644 web/public/Axero.jpeg create mode 100644 web/src/app/admin/connectors/axero/page.tsx diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index 356fc2831f..65b9f7945b 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -87,6 +87,7 @@ class DocumentSource(str, Enum): ZENDESK = "zendesk" LOOPIO = "loopio" SHAREPOINT = "sharepoint" + AXERO = "axero" class DocumentIndexType(str, Enum): diff --git a/backend/danswer/connectors/axero/__init__.py b/backend/danswer/connectors/axero/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/danswer/connectors/axero/connector.py b/backend/danswer/connectors/axero/connector.py new file mode 100644 index 0000000000..e19fb39eb9 --- /dev/null +++ b/backend/danswer/connectors/axero/connector.py @@ -0,0 +1,186 @@ +import time +from datetime import datetime +from datetime import timezone +from typing import Any + +import requests + +from danswer.configs.app_configs import INDEX_BATCH_SIZE +from danswer.configs.constants import DocumentSource +from danswer.connectors.cross_connector_utils.html_utils import parse_html_page_basic +from danswer.connectors.cross_connector_utils.miscellaneous_utils import ( + process_in_batches, +) +from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc +from danswer.connectors.interfaces import GenerateDocumentsOutput +from danswer.connectors.interfaces import PollConnector +from danswer.connectors.interfaces import SecondsSinceUnixEpoch +from danswer.connectors.models import ConnectorMissingCredentialError +from danswer.connectors.models import Document +from danswer.connectors.models import Section +from danswer.utils.logger import setup_logger + + +logger = setup_logger() + + +ENTITY_NAME_MAP = {1: "Forum", 3: "Article", 4: "Blog", 9: "Wiki"} + + +def _get_auth_header(api_key: str) -> dict[str, str]: + return {"Rest-Api-Key": api_key} + + +# https://my.axerosolutions.com/spaces/5/communifire-documentation/wiki/view/595/rest-api-get-content-list +def _get_entities( + 
entity_type: int, + api_key: str, + axero_base_url: str, + start: datetime, + end: datetime, +) -> list[dict]: + endpoint = axero_base_url + "api/content/list" + page_num = 1 + pages_fetched = 0 + pages_to_return = [] + break_out = False + while True: + params = { + "EntityType": str(entity_type), + "SortColumn": "DateUpdated", + "SortOrder": "1", # descending + "StartPage": str(page_num), + } + res = requests.get(endpoint, headers=_get_auth_header(api_key), params=params) + res.raise_for_status() + + # Axero limitations: + # There is no next-page token; we can paginate, but results may shift between + # page fetches (e.g. a doc Danswer has not yet read in gets updated and moves + # to the front of the list). Axero also does no rate limiting, though heavy + # API usage can add latency for the team, so we just fetch all the pages + # quickly to reduce the chance of missing a document due to an update (it + # will still get updated on the next pass). + # Assumes the volume of data isn't too big to store in memory (probably fine) + data = res.json() + total_records = data["TotalRecords"] + contents = data["ResponseData"] + pages_fetched += len(contents) + logger.debug(f"Fetched {pages_fetched} {ENTITY_NAME_MAP[entity_type]}") + + for page in contents: + update_time = time_str_to_utc(page["DateUpdated"]) + + if update_time > end: + continue + + if update_time < start: + break_out = True + break + + pages_to_return.append(page) + + if pages_fetched >= total_records: + break + + page_num += 1 + + if break_out: + break + + return pages_to_return + + +def _translate_content_to_doc(content: dict) -> Document: + page_text = "" + summary = content.get("ContentSummary") + body = content.get("ContentBody") + if summary: + page_text += f"{summary}\n" + + if body: + content_parsed = parse_html_page_basic(body) + page_text += content_parsed + + doc = Document( + id="AXERO_" + str(content["ContentID"]), + sections=[Section(link=content["ContentVersionURL"], text=page_text)], + source=DocumentSource.AXERO, + semantic_identifier=content["ContentTitle"], + doc_updated_at=time_str_to_utc(content["DateUpdated"]), + metadata={"space": content["SpaceName"]}, + ) + + return doc + + +class AxeroConnector(PollConnector): + def __init__( + self, + base_url: str, + include_article: bool = True, + include_blog: bool = True, + include_wiki: bool = True, + # Forums not supported at the moment + include_forum: bool = False, + batch_size: int = INDEX_BATCH_SIZE, + ) -> None: + self.include_article = include_article + self.include_blog = include_blog + self.include_wiki = include_wiki + self.include_forum = include_forum + self.batch_size = batch_size + self.axero_key = None + + if not base_url.endswith("/"): + base_url += "/" + self.base_url = base_url + + def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: + self.axero_key = credentials["axero_api_token"] + return None + + def poll_source( + self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch + ) -> GenerateDocumentsOutput: + if not self.axero_key: + raise ConnectorMissingCredentialError("Axero") + + start_datetime = datetime.utcfromtimestamp(start).replace(tzinfo=timezone.utc) + end_datetime = datetime.utcfromtimestamp(end).replace(tzinfo=timezone.utc) + + entity_types = [] + if self.include_article: + entity_types.append(3) + if self.include_blog: + entity_types.append(4) + if self.include_wiki: + entity_types.append(9) + if self.include_forum: + raise NotImplementedError("Forums for Axero not supported currently") + + for entity in
entity_types: + articles = _get_entities( + entity_type=entity, + api_key=self.axero_key, + axero_base_url=self.base_url, + start=start_datetime, + end=end_datetime, + ) + yield from process_in_batches( + objects=articles, + process_function=_translate_content_to_doc, + batch_size=self.batch_size, + ) + + +if __name__ == "__main__": + import os + + connector = AxeroConnector(base_url=os.environ["AXERO_BASE_URL"]) + connector.load_credentials({"axero_api_token": os.environ["AXERO_API_TOKEN"]}) + current = time.time() + + one_year_ago = current - 24 * 60 * 60 * 360 + latest_docs = connector.poll_source(one_year_ago, current) + + print(next(latest_docs)) diff --git a/backend/danswer/connectors/cross_connector_utils/miscellaneous_utils.py b/backend/danswer/connectors/cross_connector_utils/miscellaneous_utils.py index 10c8315601..8faf6bfada 100644 --- a/backend/danswer/connectors/cross_connector_utils/miscellaneous_utils.py +++ b/backend/danswer/connectors/cross_connector_utils/miscellaneous_utils.py @@ -1,5 +1,8 @@ +from collections.abc import Callable +from collections.abc import Iterator from datetime import datetime from datetime import timezone +from typing import TypeVar from dateutil.parser import parse @@ -43,3 +46,14 @@ def get_experts_stores_representations( reps = [basic_expert_info_representation(owner) for owner in experts] return [owner for owner in reps if owner is not None] + + +T = TypeVar("T") +U = TypeVar("U") + + +def process_in_batches( + objects: list[T], process_function: Callable[[T], U], batch_size: int +) -> Iterator[list[U]]: + for i in range(0, len(objects), batch_size): + yield [process_function(obj) for obj in objects[i : i + batch_size]] diff --git a/backend/danswer/connectors/factory.py b/backend/danswer/connectors/factory.py index f4a9ee2908..5e6438088b 100644 --- a/backend/danswer/connectors/factory.py +++ b/backend/danswer/connectors/factory.py @@ -2,6 +2,7 @@ from typing import Type from danswer.configs.constants import DocumentSource +from danswer.connectors.axero.connector import AxeroConnector from danswer.connectors.bookstack.connector import BookstackConnector from danswer.connectors.confluence.connector import ConfluenceConnector from danswer.connectors.danswer_jira.connector import JiraConnector @@ -70,6 +71,7 @@ def identify_connector_class( DocumentSource.ZENDESK: ZendeskConnector, DocumentSource.LOOPIO: LoopioConnector, DocumentSource.SHAREPOINT: SharepointConnector, + DocumentSource.AXERO: AxeroConnector, } connector_by_source = connector_map.get(source, {}) diff --git a/web/public/Axero.jpeg b/web/public/Axero.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..f6df99217274165fefcd215709b76c2663661df8 GIT binary patch literal 7977 zcmdsbc|6oz`~PQ#v5Xiah9oAkFCk-#Y}v!u_uV8*WhZ3cm5?PxMwFc)WM3oMvSn9s zlZX;Y%JTgT-S_?6&vSpj&-48L{GFNC`CR9m>;1m2^Euab&YU@zICuxpYN)8I01yZO zK)?@hFbzZiFcMNyQW6+g!C){l@?#X_U_nq(Q657g=;)9LB=R_lh4DB&GXoOIbef6z z)G0PLw&RTKXV_WKu&}bRLO>Wf895CF1q~}blAiUyE(h%Zk^%}q0TKv107XJbkdTAt z0EXx$3D^hxQ6MDb6i`w!7z7-F{^I}uAtQ&9!blFLfMZ}kl$L}R1Z_XMc4-4&zz0_( zRw)G1BMBQuY=$M`(S?fe6xsqlY}i|7&QP2HfDa*>BEaQQf+>k$O8{j8O=E8l1Asg@ zw3pokh7t?|DoadIKpLE)0Imt}#iHPGl=6xply->#ywJ`B7H1MK5Nj7%O+hFMDdZCj zgBPIq6nm6YLI8XkF8~Wm<4h?)KLU!OaJobW0nQL?FS_IssP?bcNM0paByA+6wKVz$!`Juc1aT!OCOdiqMn@5CSpMh2W9|1o&DBsGCdHOY6S5Zx9dqKP9rf zS%&=n*v@?+Q7`WqkVh`5O&oxF-u$q!8J;`u8J;T}WKp-)EGO~#>xkfk@26h1bT^fo zw@ikoqd~h=^ah~{46ro1Si4Y7E`mVi 
zYJj{pFm_U|XZbp$czyjZkRaogx{ZYvV>hR@{0tM{pW=LqNp@*e0%@ze6xSqmZTcmZC4N))NNUIhIaflJmmihi8%WN05D^nyFp8D)6{D`%>Ys}S}b__GX2LPNb50-DL_O8vmJI`j$ zK#K*VEru^l7Z-}V$Y2Yj;6*(?t&S~_*K2PjK+cZu2CsStS!!drz=VdpdJB7sNgM7n zlD+V*PynbFh_#`I1L$IijR0V3^N2I76%GS4K!9;_tMU6lgWrwiJ`hHtxpKRqpyryT z?W%99i2zbf6Bbu!L~44GxB})3oMAaAmPoWb7Q$*JF8x-nbai=#R3`L;q6P3_C09<+kX`p;#QIoQpaIz6pz&3A=|8uxL=vX-}c1TBxr#eDbrY4HRGba=b(9T?jy;zjoVYW!p=y zKj$963q_$!iYZKCsdkCbl7f&&dHwNUCEZu1oM!03xoJ2z-6<1~vc)|tchfU5jQIVf zl6UUqFMbP}DHfQ{s73P>LTsui3I(7Md_567X#VAWdj~J+h8gEfVUQ!MnSWajarJU{ zQ%ex>9VJG<>Rwx*s(THc}{r~yK_d7ZWO#$FuNAKATL{3blZQUfR>k?|f`$#MATZ)Bi2z1)tetVn3?lBM@J9^rzsKjnSO zvMkf4r%U?hP06;@Ti(0*kGy?Ec8=dVVe-OSWob2Kj(HN4xXW|z-yQTX(vJ`&WP$5V zMC|KN$F>vIE@8+QWltZ%H~c@o!sSzp5ed(hKjLO0+PKM9{3JnbAxfK;*iBy4P>km$ zs>n8rMAd*^4Mr=jeCS9%a>KZ+P_o0EpXuX{e?6l$?iT|`QKMJrk;CFA5*-SIOm9c0 z+SYbFy9ON;jKPnn>RXLD-Z*pJ@XdHl#vApMuYIf6p7rJzmveW%mhTmP#!YlN^8auiDg4nexfsV4S7D*dp}#SNVto)# zw#dl3q}R6Fp*&$le`l8S*-N=k_2e6;b6q6_8}izFu(~G%LsbJb#3i%Vz3y_}i@3$v z)84X}>+s<_1AUjm#g3#bAN`%tt&O^7t7f$>;<{Sukl)GEi=K^c8hKmmwg*GZH25d; zoOU7&Ff*%I7)^zy)M02}>FjlTUUBM`e=jKMdY);4cQuxW9c@dK9dGFW;+&=jODTJE zorW7CCsjv8FUV=sjTKv_jVuL&X>{H1hceI2iGbP2la$maX$t&meXq~*V6k6L%m0}2 zGf5^B5aPPtG=F`42Ib^astUhZD%8D^63bkmu8mwcKNL3M$%7+@Fj#t19=dDIU%dG4 z%hx#UPr`NsomEoaqz4SEm&wOgZv`YV2!v#eoxWcFj2G>y;q84@-(&d{K{vt4J@U^FCMq5>pjSvz}u=Aid>1c9d%? zelj=NQ!cE^etP+IcNSH~edb)VyVAWPGPZ);9sgqeTg0rde?T)p$#-27nIu}gCGGR2 zRgpr$kn6txy-5wnJ6q)g`EZxzzCBS6w4*WWq&PC{_Y_U2YE)xXl2p`S=DthwzkI<#)BfS-p$iuw>mn{n{+GS$GDlm zMQ?r{bt^s!-O(xMp*m`+voNfcnD3CO@5Cn)6uNt-v-aC&)kNhki(=xuRk(1G%($nWq- zD4hHmV>R6g4!Z6!W;~;6#fQ82bXlI)K7uaRmtRBFEOpabN^>-PEBVT9>S}Ril>|d2 zY00NJ!EEvdC-|Q8$YJy~TN|sg@e~jibR#Q|WxrCiny#0N(_)?CTm3Y@D2H12w9LsW zAB6{NRLvbLu`jQ@thp~Jb0+%hAd=}rmQnltX4-*LM;VfLB6B=jf|grOhX+~x$jF)e z$kBZG7fI|6oVCKZT~KaIo!86Ot>OSk{9LRp3?4s$3C*3+;uSOl$|P%7yv=KUN(#v* z4ckT7KXVk9z5ej+N$KI{@L|uA@{P22Ri@M3xdCJK-AH4KpTK;c2LR!Ym!aG3HV%L5==9g!+QN2o zJ6zFceW=`hccLF~mLWFJife*$`tPy@+j~~mwE54QzjPWnYa4cZ`@_b`f9W8^HIH;^ z<@%Z?1Uuh?`Y59QAy0GzjHo}GJY%_U`@Zd(&;cO4?eL5-_+9|PJieu`@epBkXaOPq zrxyqxO-95f$W8RCf#?_Fmyd^h{X{<8`!<(w(3gMU7TtEJZUYBi5)FxN+Yk-is3aPa znGkdfR^0(@{5d0WD$b6dw)<2wa@eY*T`t5=ifv_MiWxeI$KRb1J(;$pWi*~b-;^v^ z-dnCl)k^g>Oj5gfX82YzzgrUh?0{(-z`f~q3D4|pN*pJbN)I^w|BHI-?-68_X*<=( z;imQ`9Sk$%pQ7;ipE9z%+V3ddKye(=%)r31*ZbFf_SSRX_!}izrU;RaxvwZ(-+M5%VW{J z*GA@D*859o^k(y)kIz3aV|m=56IX7EbKhT`0&guC@qjW^`It}1z zl@&)g-^jxq0A2x2A**^-xAMe*>QJ}Lgk;&U;t0Ezz@(x5>5er9r!EEy4%$oE93{sP zR$;-Uqg!5nV$)qx+J<=nrj(S5dcqJeZyt{bs$W}fJC^&=c-IhIe8YBRf8*a634=Wqi ziyiJge$3rRO|k**q{WpN6CF%{W1~PiFY5@ERhsc<`g-Va6sNhY9zVlG`KVw!mN$U) z+l-!ElFZ9*(`;Qum>;$_+0RP`)3ryd*@Wq<%opY2{P+A`p$xq5>IiHczlC`{+3qT2 zl$0c((7+!bN;U%-$T!g`*xcvF-=9oJU({A_LvYoSq8~W+l4ML*L>vIymek}l?`e>I z0@6^HI(en4guGCB1|y>sImb&(OQnj3NGyl&hDKleKuc%1GEweaN>`qAVv znt}k-w{Y711nldMcgc>=!0#?WpM+z4T=L_>^6q zx%d;Jujq}sk;bVzu{&sIVDsbEhIkcQvNo^uc~g=TP@$V=+%nxGpC1-T z;_4>Y(Oe4mKDJ;>!LlNy-ma55W-2WtWbx2_NIT~d{oGWUkebgGUL6dFY2Lozb9*nf z+moNJah%L~e(jZ~JJT#3hmUzo*9Z5Ll`|?5N-?u(^CI;uDyz?sw+H7X9P4UtU5iCF z$o7LkNg|N*4wA5CmTBrIFy(o|F_$NkaPT&m7{v=A=kEwC)g$i~)68 zUOy?P;<&H~MnTWkJycYzjU79;PS2ruc!!5Vz&m{MzbVW+O4d9EVHOPt2$(!^hzFuCKBN z1gYTPpZ*vwaTU30z;*6<{or@KlX0Q#i9@-2SDgsowm5%R%j>u*Jm_?n9Yx+nlM*e$ zz2$fE=AM688Skb^Zn=dF&=}ssLPxq+7DLwG59qXzq-p+r%JCg3XF$Wm(DNVAM~kiw0LSpGxjV# z&82N>?2AXg-+8k)LZ;>9F6N!Rn0JSe#zV*$5whTJbiD0Jt;PLL3G!t>KkkWwuTDPe zY&RJTYAIrrxchq3eoT}yC~RsY?9Ehb_oa+@oUdmR_QE>agpE*bPau12JJJgKlg8vN zAK0@R3Af*1x(aW(wS7gpbY@A6Mcx>$vu%)<;iuNayj7?960fjmPExG_8~VCvkW(bi zMWL!}0d)+Y-QC1!Fz0CPsg#}F4rqCJxzHdij{#i>Q{z+6gD3t?gVc<2Y1DZh{E64F 
z^WHgC^oc=FhCUzeQsKXfn2O*Aq5ta{p zBXNbx%!t~S@~)V{-WuBsY3;jfb1za&e3dufTi&6ALKQ8o5bX#XD}CP$8F+Zdx;vof z97}jRk#RNc`Q*7%=Hs7cDI2dqRvdD_q;}v`5Fe zlH?w?7Cv=;(qHsIyvTLxP%H)$Tog=jC<%<56bgkNWjF*tLP^l1Fj^5^I(i0Sz0md3 zV)7LVw#ReWIYhDUhlx!>Oza$jt67?srrpcCTuNk@?x*8zMi*`i>T~4_oV@4$QdTg2 zeuBow1Q(B;2@7MY??pXc4}=D?hGjU0e!m*erK4$^#U9lm7xuM=Qx+y2WK%9f!<}g4 zUGE_5AQEKTI(p+;Hc9wKem|X!NN1wOWtHfGkGHx9cV^u)nAy&ko#Rki^6qlD68NdM z?=~W9{6tdDX8muXUKf1-R@LilEHeZ3$IEir%L-SYzDTi!A_^AoUq@VJYYlVt=xsGs zYkJ?dd~P(@`IT$x6}ERby;=v6W$WLDc6scC(@zw`mTHFIFBOSNuoL?lJH9353s3LvuvV+bU$YQ#hACF`S1DEV@dT?Blj=Bz z9{?8&6grKsm)>WS!;rc=LuD!-^^N!hm;~&-TVCgFcgy6a{M#!2oYv%;?dNaz(|LP5 zGV}kg!dN~WV?6+ZzcM;jN4HD`itao+6kIjrgjT;1wtwc02wCW}Pew?YN9N*7+lK~? z=B|9;cdPPZs5O&@&j_A^T|*^=J)=@p__qp0hlKcCt;tWepSMC@NPFT-)_?i9tx)ty z&?)?#iW!}ALx(9|85dggiuL(~-C!K~>4F;73UFlp6}WcN8m? zlHL2QQ6Xr$C`W0W!fsNGH*eCSX{qkyZ&Q(KjSE_=kMFEegwh;0xRT&NOCD2M^=9}T p_1Vz_K*OQ?$wj-EI { + const { popup, setPopup } = usePopup(); + + const { mutate } = useSWRConfig(); + const { + data: connectorIndexingStatuses, + isLoading: isConnectorIndexingStatusesLoading, + error: isConnectorIndexingStatusesError, + } = useSWR[]>( + "/api/manage/admin/connector/indexing-status", + fetcher, + { refreshInterval: 5000 } // 5 seconds + ); + const { + data: credentialsData, + isLoading: isCredentialsLoading, + error: isCredentialsError, + isValidating: isCredentialsValidating, + refreshCredentials, + } = usePublicCredentials(); + + if ( + isConnectorIndexingStatusesLoading || + isCredentialsLoading || + isCredentialsValidating + ) { + return ; + } + + if (isConnectorIndexingStatusesError || !connectorIndexingStatuses) { + return
Failed to load connectors
; + } + + if (isCredentialsError || !credentialsData) { + return
Failed to load credentials
; + } + + const axeroConnectorIndexingStatuses: ConnectorIndexingStatus< + {}, + AxeroCredentialJson + >[] = connectorIndexingStatuses.filter( + (connectorIndexingStatus) => + connectorIndexingStatus.connector.source === "axero" + ); + const axeroCredential: Credential | undefined = + credentialsData.find( + (credential) => credential.credential_json?.axero_api_token + ); + + return ( + <> + {popup} + + Step 1: Provide your Credentials + + + {axeroCredential ? ( + <> +
+ Existing API Key: + + {axeroCredential.credential_json?.axero_api_token} + + +
+ + ) : ( + <> + + To use the Axero connector, first follow the guide{" "} + + here + {" "} + to generate an API Key. + + + + formBody={ + <> + + + } + validationSchema={Yup.object().shape({ + axero_api_token: Yup.string().required( + "Please enter your Axero API Key!" + ), + })} + initialValues={{ + axero_api_token: "", + }} + onSubmit={(isSuccess) => { + if (isSuccess) { + refreshCredentials(); + } + }} + /> + + + )} + + + Step 2: Start indexing + + {axeroCredential ? ( + <> + {axeroConnectorIndexingStatuses.length > 0 ? ( + <> + + We pull the latest Articles, Blogs, and{" "} + Wikis every 10 minutes. + +
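+              {/* Table of the Axero connectors that already exist, showing their status and linked credential */}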
+ + connectorIndexingStatuses={axeroConnectorIndexingStatuses} + liveCredential={axeroCredential} + getCredential={(credential) => { + return ( +
+                      <div>
+                        <p>{credential.credential_json.axero_api_token}</p>
+                      </div>
+ ); + }} + onCredentialLink={async (connectorId) => { + if (axeroCredential) { + await linkCredential(connectorId, axeroCredential.id); + mutate("/api/manage/admin/connector/indexing-status"); + } + }} + onUpdate={() => + mutate("/api/manage/admin/connector/indexing-status") + } + /> +
+ + ) : ( + +

Create Connector

+

+                    Press connect below to start the connection to Axero. We pull the
+                    latest Articles, Blogs, and Wikis every{" "}
+                    10 minutes.

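+                  {/* This form creates the Axero connector itself; it polls using the credential from Step 1 */}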
+ + nameBuilder={() => "AxeroConnector"} + ccPairNameBuilder={() => "Axero"} + source="axero" + inputType="poll" + formBody={ + <> + + + } + validationSchema={Yup.object().shape({})} + initialValues={{ + base_url: "", + }} + refreshFreq={10 * 60} // 10 minutes + credentialId={axeroCredential.id} + /> +
+          )}
+
+        ) : (
+          <>
+
+            Please provide your access token in Step 1 first! Once done with
+            that, you can then start indexing Axero.
+
+
+        )}
+
+    );
+};
+
+export default function Page() {
+  return (
+
+
+ +
+ + } title="Axero" /> + +
+
+ ); +} diff --git a/web/src/components/icons/icons.tsx b/web/src/components/icons/icons.tsx index e847d1dd38..33605911fb 100644 --- a/web/src/components/icons/icons.tsx +++ b/web/src/components/icons/icons.tsx @@ -600,3 +600,15 @@ export const ZendeskIcon = ({ Logo
); + +export const AxeroIcon = ({ + size = 16, + className = defaultTailwindCSS, +}: IconProps) => ( +
+ Logo +
+); diff --git a/web/src/lib/sources.ts b/web/src/lib/sources.ts index bcd821121a..92250de5ba 100644 --- a/web/src/lib/sources.ts +++ b/web/src/lib/sources.ts @@ -1,4 +1,5 @@ import { + AxeroIcon, BookstackIcon, ConfluenceIcon, Document360Icon, @@ -154,6 +155,11 @@ const SOURCE_METADATA_MAP: SourceMap = { displayName: "Sharepoint", category: SourceCategory.AppConnection, }, + axero: { + icon: AxeroIcon, + displayName: "Axero", + category: SourceCategory.AppConnection, + }, requesttracker: { icon: RequestTrackerIcon, displayName: "Request Tracker", diff --git a/web/src/lib/types.ts b/web/src/lib/types.ts index 5c46af9d9d..f06e4515a2 100644 --- a/web/src/lib/types.ts +++ b/web/src/lib/types.ts @@ -33,7 +33,8 @@ export type ValidSources = | "google_sites" | "loopio" | "sharepoint" - | "zendesk"; + | "zendesk" + | "axero"; export type ValidInputTypes = "load_state" | "poll" | "event"; export type ValidStatuses = @@ -327,6 +328,10 @@ export interface SharepointCredentialJson { aad_directory_id: string; } +export interface AxeroCredentialJson { + axero_api_token: string; +} + // DELETION export interface DeletionAttemptSnapshot { From fd69203be8765c881c737c0eb4ba95b4d35542bd Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Thu, 28 Mar 2024 11:11:37 -0700 Subject: [PATCH 24/58] More accurate input token count for LLM (#1267) --- backend/danswer/chat/chat_utils.py | 2 +- backend/danswer/llm/utils.py | 73 +++++++++++++++++++++++++----- backend/requirements/default.txt | 2 +- 3 files changed, 64 insertions(+), 13 deletions(-) diff --git a/backend/danswer/chat/chat_utils.py b/backend/danswer/chat/chat_utils.py index ee2f582c95..d752095575 100644 --- a/backend/danswer/chat/chat_utils.py +++ b/backend/danswer/chat/chat_utils.py @@ -55,7 +55,7 @@ def create_chat_chain( id_to_msg = {msg.id: msg for msg in all_chat_messages} if not all_chat_messages: - raise ValueError("No messages in Chat Session") + raise RuntimeError("No messages in Chat Session") root_message = all_chat_messages[0] if root_message.parent_message is not None: diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py index c07b708bb5..50eb3cef85 100644 --- a/backend/danswer/llm/utils.py +++ b/backend/danswer/llm/utils.py @@ -241,6 +241,7 @@ def test_llm(llm: LLM) -> str | None: def get_llm_max_tokens( + model_map: dict, model_name: str | None = GEN_AI_MODEL_VERSION, model_provider: str = GEN_AI_MODEL_PROVIDER, ) -> int: @@ -250,18 +251,12 @@ def get_llm_max_tokens( return GEN_AI_MAX_TOKENS model_name = model_name or get_default_llm_version()[0] - # NOTE: we previously used `litellm.get_max_tokens()`, but despite the name, this actually - # returns the max OUTPUT tokens. Under the hood, this uses the `litellm.model_cost` dict, - # and there is no other interface to get what we want. 
This should be okay though, since the - # `model_cost` dict is a named public interface: - # https://litellm.vercel.app/docs/completion/token_usage#7-model_cost - litellm_model_map = litellm.model_cost try: if model_provider == "openai": - model_obj = litellm_model_map[model_name] + model_obj = model_map[model_name] else: - model_obj = litellm_model_map[f"{model_provider}/{model_name}"] + model_obj = model_map[f"{model_provider}/{model_name}"] if "max_tokens" in model_obj: return model_obj["max_tokens"] elif "max_input_tokens" in model_obj and "max_output_tokens" in model_obj: @@ -275,17 +270,73 @@ def get_llm_max_tokens( return 4096 +def get_llm_max_input_tokens( + output_tokens: int, + model_map: dict, + model_name: str | None = GEN_AI_MODEL_VERSION, + model_provider: str = GEN_AI_MODEL_PROVIDER, +) -> int | None: + try: + if model_provider == "openai": + model_obj = model_map[model_name] + else: + model_obj = model_map[f"{model_provider}/{model_name}"] + + max_in = model_obj.get("max_input_tokens") + max_out = model_obj.get("max_output_tokens") + if max_in is None or max_out is None: + # Can't calculate precisely, use the fallback method + return None + + # Some APIs may not actually work like this, but it's a safer approach generally speaking + # since worst case we remove some extra tokens from the input space + output_token_debt = 0 + if output_tokens > max_out: + logger.warning( + "More output tokens requested than model is likely able to handle" + ) + output_token_debt = output_tokens - max_out + + remaining_max_input_tokens = max_in - output_token_debt + return remaining_max_input_tokens + + except Exception: + # We can try the less accurate approach if this fails + return None + + def get_max_input_tokens( model_name: str | None = GEN_AI_MODEL_VERSION, model_provider: str = GEN_AI_MODEL_PROVIDER, output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS, ) -> int: + # NOTE: we previously used `litellm.get_max_tokens()`, but despite the name, this actually + # returns the max OUTPUT tokens. Under the hood, this uses the `litellm.model_cost` dict, + # and there is no other interface to get what we want. 
This should be okay though, since the + # `model_cost` dict is a named public interface: + # https://litellm.vercel.app/docs/completion/token_usage#7-model_cost + # model_map is litellm.model_cost + litellm_model_map = litellm.model_cost + model_name = model_name or get_default_llm_version()[0] - input_toks = ( - get_llm_max_tokens(model_name=model_name, model_provider=model_provider) - - output_tokens + + input_toks = get_llm_max_input_tokens( + output_tokens=output_tokens, + model_map=litellm_model_map, + model_name=model_name, + model_provider=model_provider, ) + if input_toks is None: + input_toks = ( + get_llm_max_tokens( + model_name=model_name, + model_provider=model_provider, + model_map=litellm_model_map, + ) + - output_tokens + ) + if input_toks <= 0: raise RuntimeError("No tokens for input for the LLM given settings") diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt index 008ee8480c..9ca893876b 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -24,7 +24,7 @@ httpx-oauth==0.11.2 huggingface-hub==0.20.1 jira==3.5.1 langchain==0.1.9 -litellm==1.27.10 +litellm==1.34.8 llama-index==0.9.45 Mako==1.2.4 msal==1.26.0 From d46b475410619fc2719daac23dbaefc3e2a128b3 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Thu, 28 Mar 2024 11:26:11 -0700 Subject: [PATCH 25/58] Make porting from persistent volumes optional (#1268) --- backend/danswer/main.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/backend/danswer/main.py b/backend/danswer/main.py index e770cc8abb..ad4584774a 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -169,7 +169,12 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}" ) - port_filesystem_to_postgres() + try: + port_filesystem_to_postgres() + except Exception: + logger.debug( + "Skipping port of persistent volumes. Maybe these have already been removed?" + ) with Session(engine) as db_session: db_embedding_model = get_current_db_embedding_model(db_session) From f46e65be92d9cd778f60ec8bc5c81cd753407c74 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Thu, 28 Mar 2024 12:48:01 -0700 Subject: [PATCH 26/58] Save One Shot Docs (#1269) --- backend/danswer/db/engine.py | 9 +-- backend/danswer/llm/utils.py | 63 ++++--------------- .../one_shot_answer/answer_question.py | 16 +++-- backend/danswer/search/models.py | 6 +- 4 files changed, 32 insertions(+), 62 deletions(-) diff --git a/backend/danswer/db/engine.py b/backend/danswer/db/engine.py index 22f1193fe8..1be57179c7 100644 --- a/backend/danswer/db/engine.py +++ b/backend/danswer/db/engine.py @@ -4,7 +4,6 @@ from datetime import datetime from typing import ContextManager -from ddtrace import tracer from sqlalchemy import text from sqlalchemy.engine import create_engine from sqlalchemy.engine import Engine @@ -77,9 +76,11 @@ def get_session_context_manager() -> ContextManager: def get_session() -> Generator[Session, None, None]: - with tracer.trace("db.get_session"): - with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session: - yield session + # The line below was added to monitor the latency caused by Postgres connections + # during API calls. 
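+    # It is left commented out; re-enabling it requires the ddtrace import removed above.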
+ # with tracer.trace("db.get_session"): + with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session: + yield session async def get_async_session() -> AsyncGenerator[AsyncSession, None]: diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py index 50eb3cef85..507bbbff62 100644 --- a/backend/danswer/llm/utils.py +++ b/backend/danswer/llm/utils.py @@ -257,10 +257,12 @@ def get_llm_max_tokens( model_obj = model_map[model_name] else: model_obj = model_map[f"{model_provider}/{model_name}"] + + if "max_input_tokens" in model_obj: + return model_obj["max_input_tokens"] + if "max_tokens" in model_obj: return model_obj["max_tokens"] - elif "max_input_tokens" in model_obj and "max_output_tokens" in model_obj: - return model_obj["max_input_tokens"] + model_obj["max_output_tokens"] raise RuntimeError("No max tokens found for LLM") except Exception: @@ -270,41 +272,6 @@ def get_llm_max_tokens( return 4096 -def get_llm_max_input_tokens( - output_tokens: int, - model_map: dict, - model_name: str | None = GEN_AI_MODEL_VERSION, - model_provider: str = GEN_AI_MODEL_PROVIDER, -) -> int | None: - try: - if model_provider == "openai": - model_obj = model_map[model_name] - else: - model_obj = model_map[f"{model_provider}/{model_name}"] - - max_in = model_obj.get("max_input_tokens") - max_out = model_obj.get("max_output_tokens") - if max_in is None or max_out is None: - # Can't calculate precisely, use the fallback method - return None - - # Some APIs may not actually work like this, but it's a safer approach generally speaking - # since worst case we remove some extra tokens from the input space - output_token_debt = 0 - if output_tokens > max_out: - logger.warning( - "More output tokens requested than model is likely able to handle" - ) - output_token_debt = output_tokens - max_out - - remaining_max_input_tokens = max_in - output_token_debt - return remaining_max_input_tokens - - except Exception: - # We can try the less accurate approach if this fails - return None - - def get_max_input_tokens( model_name: str | None = GEN_AI_MODEL_VERSION, model_provider: str = GEN_AI_MODEL_PROVIDER, @@ -320,22 +287,14 @@ def get_max_input_tokens( model_name = model_name or get_default_llm_version()[0] - input_toks = get_llm_max_input_tokens( - output_tokens=output_tokens, - model_map=litellm_model_map, - model_name=model_name, - model_provider=model_provider, - ) - - if input_toks is None: - input_toks = ( - get_llm_max_tokens( - model_name=model_name, - model_provider=model_provider, - model_map=litellm_model_map, - ) - - output_tokens + input_toks = ( + get_llm_max_tokens( + model_name=model_name, + model_provider=model_provider, + model_map=litellm_model_map, ) + - output_tokens + ) if input_toks <= 0: raise RuntimeError("No tokens for input for the LLM given settings") diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py index 320027c54b..e37cc0e435 100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -21,6 +21,7 @@ from danswer.db.chat import get_or_create_root_message from danswer.db.chat import get_prompt_by_id from danswer.db.chat import translate_db_message_to_chat_message_detail +from danswer.db.chat import translate_db_search_doc_to_server_search_doc from danswer.db.engine import get_session_context_manager from danswer.db.models import User from danswer.llm.answering.answer import Answer @@ -35,7 +36,6 @@ from danswer.one_shot_answer.qa_utils 
import combine_message_thread from danswer.search.models import RerankMetricsContainer from danswer.search.models import RetrievalMetricsContainer -from danswer.search.models import SavedSearchDoc from danswer.search.models import SearchRequest from danswer.search.pipeline import SearchPipeline from danswer.search.utils import chunks_to_search_docs @@ -135,12 +135,20 @@ def stream_answer_objects( # First fetch and return the top chunks so the user can immediately see some results top_chunks = search_pipeline.reranked_docs top_docs = chunks_to_search_docs(top_chunks) - fake_saved_docs = [SavedSearchDoc.from_search_doc(doc) for doc in top_docs] - # Since this is in the one shot answer flow, we don't need to actually save the docs to DB + reference_db_search_docs = [ + create_db_search_doc(server_search_doc=top_doc, db_session=db_session) + for top_doc in top_docs + ] + + response_docs = [ + translate_db_search_doc_to_server_search_doc(db_search_doc) + for db_search_doc in reference_db_search_docs + ] + initial_response = QADocsResponse( rephrased_query=rephrased_query, - top_documents=fake_saved_docs, + top_documents=response_docs, predicted_flow=search_pipeline.predicted_flow, predicted_search=search_pipeline.predicted_search_type, applied_source_filters=search_pipeline.search_query.filters.source_type, diff --git a/backend/danswer/search/models.py b/backend/danswer/search/models.py index d2ad74c34e..d199a3b6bb 100644 --- a/backend/danswer/search/models.py +++ b/backend/danswer/search/models.py @@ -138,9 +138,11 @@ class SavedSearchDoc(SearchDoc): def from_search_doc( cls, search_doc: SearchDoc, db_doc_id: int = 0 ) -> "SavedSearchDoc": - """IMPORTANT: careful using this and not providing a db_doc_id""" + """IMPORTANT: careful using this and not providing a db_doc_id If db_doc_id is not + provided, it won't be able to actually fetch the saved doc and info later on. 
So only skip
+        providing this if the SavedSearchDoc will not be used in the future."""
         search_doc_data = search_doc.dict()
-        search_doc_data["score"] = search_doc_data.get("score", 0.0)
+        search_doc_data["score"] = search_doc_data.get("score") or 0.0
         return cls(**search_doc_data, db_doc_id=db_doc_id)


From 055cab2944f3cd44fd553ddca5852b9d22246fb1 Mon Sep 17 00:00:00 2001
From: Yuhong Sun
Date: Thu, 28 Mar 2024 18:21:28 -0700
Subject: [PATCH 27/58] Public Slack Feedback Option (#1270)

---
 backend/danswer/configs/danswerbot_configs.py |  8 +++++
 backend/danswer/danswerbot/slack/constants.py |  8 +++++
 .../slack/handlers/handle_buttons.py          | 36 +++++++++++++++----
 backend/danswer/danswerbot/slack/utils.py     |  9 +++++
 .../search/postprocessing/postprocessing.py   |  2 +-
 .../docker_compose/docker-compose.dev.yml     |  1 +
 6 files changed, 56 insertions(+), 8 deletions(-)

diff --git a/backend/danswer/configs/danswerbot_configs.py b/backend/danswer/configs/danswerbot_configs.py
index 5935c9b999..192a0594d1 100644
--- a/backend/danswer/configs/danswerbot_configs.py
+++ b/backend/danswer/configs/danswerbot_configs.py
@@ -21,6 +21,14 @@
 DANSWER_REACT_EMOJI = os.environ.get("DANSWER_REACT_EMOJI") or "eyes"
 # When User needs more help, what should the emoji be
 DANSWER_FOLLOWUP_EMOJI = os.environ.get("DANSWER_FOLLOWUP_EMOJI") or "sos"
+# What kind of message should be shown when someone gives an AI answer feedback to DanswerBot
+# Defaults to Private if not provided or invalid
+# Private: Only visible to the user clicking the feedback
+# Anonymous: Public but anonymous
+# Public: Visible with the user name who submitted the feedback
+DANSWER_BOT_FEEDBACK_VISIBILITY = (
+    os.environ.get("DANSWER_BOT_FEEDBACK_VISIBILITY") or "private"
+)
 # Should DanswerBot send an apology message if it's not able to find an answer
 # That way the user isn't confused as to why DanswerBot reacted but then said nothing
 # Off by default to be less intrusive (don't want to give a notif that just says we couldn't help)
diff --git a/backend/danswer/danswerbot/slack/constants.py b/backend/danswer/danswerbot/slack/constants.py
index a4930b593c..1e524025fc 100644
--- a/backend/danswer/danswerbot/slack/constants.py
+++ b/backend/danswer/danswerbot/slack/constants.py
@@ -1,3 +1,5 @@
+from enum import Enum
+
 LIKE_BLOCK_ACTION_ID = "feedback-like"
 DISLIKE_BLOCK_ACTION_ID = "feedback-dislike"
 FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID = "feedback-doc-button"
@@ -6,3 +8,9 @@
 FOLLOWUP_BUTTON_RESOLVED_ACTION_ID = "followup-resolved-button"
 SLACK_CHANNEL_ID = "channel_id"
 VIEW_DOC_FEEDBACK_ID = "view-doc-feedback"
+
+
+class FeedbackVisibility(str, Enum):
+    PRIVATE = "private"
+    ANONYMOUS = "anonymous"
+    PUBLIC = "public"
diff --git a/backend/danswer/danswerbot/slack/handlers/handle_buttons.py b/backend/danswer/danswerbot/slack/handlers/handle_buttons.py
index 0ca030612f..bec1959e3c 100644
--- a/backend/danswer/danswerbot/slack/handlers/handle_buttons.py
+++ b/backend/danswer/danswerbot/slack/handlers/handle_buttons.py
@@ -15,6 +15,7 @@
 from danswer.danswerbot.slack.blocks import get_document_feedback_blocks
 from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
 from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
+from danswer.danswerbot.slack.constants import FeedbackVisibility
 from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
 from danswer.danswerbot.slack.constants import VIEW_DOC_FEEDBACK_ID
 from danswer.danswerbot.slack.utils import build_feedback_id
 from
danswer.danswerbot.slack.utils import fetch_groupids_from_names from danswer.danswerbot.slack.utils import fetch_userids_from_emails from danswer.danswerbot.slack.utils import get_channel_name_from_id +from danswer.danswerbot.slack.utils import get_feedback_visibility from danswer.danswerbot.slack.utils import respond_in_thread from danswer.danswerbot.slack.utils import update_emote_react from danswer.db.engine import get_sqlalchemy_engine @@ -120,13 +122,33 @@ def handle_slack_feedback( else: logger_base.error(f"Feedback type '{feedback_type}' not supported") - # post message to slack confirming that feedback was received - client.chat_postEphemeral( - channel=channel_id_to_post_confirmation, - user=user_id_to_post_confirmation, - thread_ts=thread_ts_to_post_confirmation, - text="Thanks for your feedback!", - ) + if get_feedback_visibility() == FeedbackVisibility.PRIVATE or feedback_type not in [ + LIKE_BLOCK_ACTION_ID, + DISLIKE_BLOCK_ACTION_ID, + ]: + client.chat_postEphemeral( + channel=channel_id_to_post_confirmation, + user=user_id_to_post_confirmation, + thread_ts=thread_ts_to_post_confirmation, + text="Thanks for your feedback!", + ) + else: + feedback_response_txt = ( + "liked" if feedback_type == LIKE_BLOCK_ACTION_ID else "disliked" + ) + + if get_feedback_visibility() == FeedbackVisibility.ANONYMOUS: + msg = f"A user has {feedback_response_txt} the AI Answer" + else: + msg = f"<@{user_id_to_post_confirmation}> has {feedback_response_txt} the AI Answer" + + respond_in_thread( + client=client, + channel=channel_id_to_post_confirmation, + text=msg, + thread_ts=thread_ts_to_post_confirmation, + unfurl=False, + ) def handle_followup_button( diff --git a/backend/danswer/danswerbot/slack/utils.py b/backend/danswer/danswerbot/slack/utils.py index 5d761dec0e..5895dc52f9 100644 --- a/backend/danswer/danswerbot/slack/utils.py +++ b/backend/danswer/danswerbot/slack/utils.py @@ -18,11 +18,13 @@ from danswer.configs.app_configs import DISABLE_TELEMETRY from danswer.configs.constants import ID_SEPARATOR from danswer.configs.constants import MessageType +from danswer.configs.danswerbot_configs import DANSWER_BOT_FEEDBACK_VISIBILITY from danswer.configs.danswerbot_configs import DANSWER_BOT_MAX_QPM from danswer.configs.danswerbot_configs import DANSWER_BOT_MAX_WAIT_TIME from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES from danswer.connectors.slack.utils import make_slack_api_rate_limited from danswer.connectors.slack.utils import SlackTextCleaner +from danswer.danswerbot.slack.constants import FeedbackVisibility from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID from danswer.danswerbot.slack.tokens import fetch_tokens from danswer.db.engine import get_sqlalchemy_engine @@ -449,3 +451,10 @@ def waiter(self, func_randid: int) -> None: self.refill() del self.waiting_questions[0] + + +def get_feedback_visibility() -> FeedbackVisibility: + try: + return FeedbackVisibility(DANSWER_BOT_FEEDBACK_VISIBILITY.lower()) + except ValueError: + return FeedbackVisibility.PRIVATE diff --git a/backend/danswer/search/postprocessing/postprocessing.py b/backend/danswer/search/postprocessing/postprocessing.py index e1cee4bd6d..8b9c5617cc 100644 --- a/backend/danswer/search/postprocessing/postprocessing.py +++ b/backend/danswer/search/postprocessing/postprocessing.py @@ -219,4 +219,4 @@ def search_postprocessing( if chunk.unique_id in llm_chunk_selection ] else: - yield [] + yield cast(list[str], []) diff --git a/deployment/docker_compose/docker-compose.dev.yml 
b/deployment/docker_compose/docker-compose.dev.yml index d8cc63116f..0dfcba7661 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ b/deployment/docker_compose/docker-compose.dev.yml @@ -151,6 +151,7 @@ services: - DANSWER_BOT_SLACK_APP_TOKEN=${DANSWER_BOT_SLACK_APP_TOKEN:-} - DANSWER_BOT_SLACK_BOT_TOKEN=${DANSWER_BOT_SLACK_BOT_TOKEN:-} - DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER=${DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER:-} + - DANSWER_BOT_FEEDBACK_VISIBILITY=${DANSWER_BOT_FEEDBACK_VISIBILITY:-} - DANSWER_BOT_DISPLAY_ERROR_MSGS=${DANSWER_BOT_DISPLAY_ERROR_MSGS:-} - DANSWER_BOT_RESPOND_EVERY_CHANNEL=${DANSWER_BOT_RESPOND_EVERY_CHANNEL:-} - DANSWER_BOT_DISABLE_COT=${DANSWER_BOT_DISABLE_COT:-} # Currently unused From 49acde0a8fac6024d287c42b661e2cc1af6c34ce Mon Sep 17 00:00:00 2001 From: Weves Date: Thu, 28 Mar 2024 23:53:42 -0700 Subject: [PATCH 28/58] URL-based chat sharing --- .../38eda64af7fe_add_chat_session_sharing.py | 41 ++ backend/danswer/db/chat.py | 55 +- backend/danswer/db/models.py | 10 + .../server/query_and_chat/chat_backend.py | 47 +- .../danswer/server/query_and_chat/models.py | 9 + web/package-lock.json | 580 +++++++++++++++++- web/package.json | 1 + web/src/app/chat/Chat.tsx | 39 +- web/src/app/chat/interfaces.ts | 9 + web/src/app/chat/message/Messages.tsx | 13 +- .../app/chat/modal/ShareChatSessionModal.tsx | 160 +++++ web/src/app/chat/page.tsx | 11 - .../app/chat/sessionSidebar/ChatSidebar.tsx | 12 +- .../chat/sessionSidebar/SessionDisplay.tsx | 174 ++++-- .../shared/[chatId]/SharedChatDisplay.tsx | 92 +++ web/src/app/chat/shared/[chatId]/page.tsx | 73 +++ web/src/components/BasicClickable.tsx | 4 +- web/src/components/CopyButton.tsx | 29 + web/src/components/openai/ApiKeyModal.tsx | 1 - web/src/components/popover/Popover.tsx | 38 ++ web/src/lib/time.ts | 16 + 21 files changed, 1297 insertions(+), 117 deletions(-) create mode 100644 backend/alembic/versions/38eda64af7fe_add_chat_session_sharing.py create mode 100644 web/src/app/chat/modal/ShareChatSessionModal.tsx create mode 100644 web/src/app/chat/shared/[chatId]/SharedChatDisplay.tsx create mode 100644 web/src/app/chat/shared/[chatId]/page.tsx create mode 100644 web/src/components/CopyButton.tsx create mode 100644 web/src/components/popover/Popover.tsx diff --git a/backend/alembic/versions/38eda64af7fe_add_chat_session_sharing.py b/backend/alembic/versions/38eda64af7fe_add_chat_session_sharing.py new file mode 100644 index 0000000000..e77ee186f4 --- /dev/null +++ b/backend/alembic/versions/38eda64af7fe_add_chat_session_sharing.py @@ -0,0 +1,41 @@ +"""Add chat session sharing + +Revision ID: 38eda64af7fe +Revises: 776b3bbe9092 +Create Date: 2024-03-27 19:41:29.073594 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. 
+revision = "38eda64af7fe" +down_revision = "776b3bbe9092" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "chat_session", + sa.Column( + "shared_status", + sa.Enum( + "PUBLIC", + "PRIVATE", + name="chatsessionsharedstatus", + native_enum=False, + ), + nullable=True, + ), + ) + op.execute("UPDATE chat_session SET shared_status='PRIVATE'") + op.alter_column( + "chat_session", + "shared_status", + nullable=False, + ) + + +def downgrade() -> None: + op.drop_column("chat_session", "shared_status") diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index 6dfa02c2f9..e2262f0f3d 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -18,6 +18,7 @@ from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import ChatMessage from danswer.db.models import ChatSession +from danswer.db.models import ChatSessionSharedStatus from danswer.db.models import DocumentSet as DBDocumentSet from danswer.db.models import Persona from danswer.db.models import Persona__User @@ -42,13 +43,17 @@ def get_chat_session_by_id( user_id: UUID | None, db_session: Session, include_deleted: bool = False, + is_shared: bool = False, ) -> ChatSession: stmt = select(ChatSession).where(ChatSession.id == chat_session_id) - # if user_id is None, assume this is an admin who should be able - # to view all chat sessions - if user_id is not None: - stmt = stmt.where(ChatSession.user_id == user_id) + if is_shared: + stmt = stmt.where(ChatSession.shared_status == ChatSessionSharedStatus.PUBLIC) + else: + # if user_id is None, assume this is an admin who should be able + # to view all chat sessions + if user_id is not None: + stmt = stmt.where(ChatSession.user_id == user_id) result = db_session.execute(stmt) chat_session = result.scalar_one_or_none() @@ -103,7 +108,11 @@ def create_chat_session( def update_chat_session( - user_id: UUID | None, chat_session_id: int, description: str, db_session: Session + db_session: Session, + user_id: UUID | None, + chat_session_id: int, + description: str | None = None, + sharing_status: ChatSessionSharedStatus | None = None, ) -> ChatSession: chat_session = get_chat_session_by_id( chat_session_id=chat_session_id, user_id=user_id, db_session=db_session @@ -112,7 +121,10 @@ def update_chat_session( if chat_session.deleted: raise ValueError("Trying to rename a deleted chat session") - chat_session.description = description + if description is not None: + chat_session.description = description + if sharing_status is not None: + chat_session.shared_status = sharing_status db_session.commit() @@ -745,6 +757,7 @@ def get_db_search_doc_by_id(doc_id: int, db_session: Session) -> DBSearchDoc | N def translate_db_search_doc_to_server_search_doc( db_search_doc: SearchDoc, + remove_doc_content: bool = False, ) -> SavedSearchDoc: return SavedSearchDoc( db_doc_id=db_search_doc.id, @@ -752,22 +765,30 @@ def translate_db_search_doc_to_server_search_doc( chunk_ind=db_search_doc.chunk_ind, semantic_identifier=db_search_doc.semantic_id, link=db_search_doc.link, - blurb=db_search_doc.blurb, + blurb=db_search_doc.blurb if not remove_doc_content else "", source_type=db_search_doc.source_type, boost=db_search_doc.boost, hidden=db_search_doc.hidden, - metadata=db_search_doc.doc_metadata, + metadata=db_search_doc.doc_metadata if not remove_doc_content else {}, score=db_search_doc.score, - match_highlights=db_search_doc.match_highlights, - updated_at=db_search_doc.updated_at, - primary_owners=db_search_doc.primary_owners, - 
secondary_owners=db_search_doc.secondary_owners, + match_highlights=db_search_doc.match_highlights + if not remove_doc_content + else [], + updated_at=db_search_doc.updated_at if not remove_doc_content else None, + primary_owners=db_search_doc.primary_owners if not remove_doc_content else [], + secondary_owners=db_search_doc.secondary_owners + if not remove_doc_content + else [], ) -def get_retrieval_docs_from_chat_message(chat_message: ChatMessage) -> RetrievalDocs: +def get_retrieval_docs_from_chat_message( + chat_message: ChatMessage, remove_doc_content: bool = False +) -> RetrievalDocs: top_documents = [ - translate_db_search_doc_to_server_search_doc(db_doc) + translate_db_search_doc_to_server_search_doc( + db_doc, remove_doc_content=remove_doc_content + ) for db_doc in chat_message.search_docs ] top_documents = sorted(top_documents, key=lambda doc: doc.score, reverse=True) # type: ignore @@ -775,7 +796,7 @@ def get_retrieval_docs_from_chat_message(chat_message: ChatMessage) -> Retrieval def translate_db_message_to_chat_message_detail( - chat_message: ChatMessage, + chat_message: ChatMessage, remove_doc_content: bool = False ) -> ChatMessageDetail: chat_msg_detail = ChatMessageDetail( message_id=chat_message.id, @@ -783,7 +804,9 @@ def translate_db_message_to_chat_message_detail( latest_child_message=chat_message.latest_child_message, message=chat_message.message, rephrased_query=chat_message.rephrased_query, - context_docs=get_retrieval_docs_from_chat_message(chat_message), + context_docs=get_retrieval_docs_from_chat_message( + chat_message, remove_doc_content=remove_doc_content + ), message_type=chat_message.message_type, time_sent=chat_message.time_sent, citations=chat_message.citations, diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index faafd2aedf..4a44882f2d 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -69,6 +69,11 @@ class IndexModelStatus(str, PyEnum): FUTURE = "FUTURE" +class ChatSessionSharedStatus(str, PyEnum): + PUBLIC = "public" + PRIVATE = "private" + + class Base(DeclarativeBase): pass @@ -586,6 +591,11 @@ class ChatSession(Base): one_shot: Mapped[bool] = mapped_column(Boolean, default=False) # Only ever set to True if system is set to not hard-delete chats deleted: Mapped[bool] = mapped_column(Boolean, default=False) + # controls whether or not this conversation is viewable by others + shared_status: Mapped[ChatSessionSharedStatus] = mapped_column( + Enum(ChatSessionSharedStatus, native_enum=False), + default=ChatSessionSharedStatus.PRIVATE, + ) time_updated: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index 4fb98c5a15..4558903f7a 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -35,6 +35,7 @@ from danswer.server.query_and_chat.models import ChatSessionDetailResponse from danswer.server.query_and_chat.models import ChatSessionDetails from danswer.server.query_and_chat.models import ChatSessionsResponse +from danswer.server.query_and_chat.models import ChatSessionUpdateRequest from danswer.server.query_and_chat.models import CreateChatMessageRequest from danswer.server.query_and_chat.models import CreateChatSessionID from danswer.server.query_and_chat.models import RenameChatSessionResponse @@ -64,6 +65,7 @@ def get_user_chat_sessions( name=chat.description, 
persona_id=chat.persona_id, time_created=chat.time_created.isoformat(), + shared_status=chat.shared_status, ) for chat in chat_sessions ] @@ -73,6 +75,7 @@ def get_user_chat_sessions( @router.get("/get-chat-session/{session_id}") def get_chat_session( session_id: int, + is_shared: bool = False, user: User | None = Depends(current_user), db_session: Session = Depends(get_session), ) -> ChatSessionDetailResponse: @@ -80,7 +83,10 @@ def get_chat_session( try: chat_session = get_chat_session_by_id( - chat_session_id=session_id, user_id=user_id, db_session=db_session + chat_session_id=session_id, + user_id=user_id, + db_session=db_session, + is_shared=is_shared, ) except ValueError: raise ValueError("Chat session does not exist or has been deleted") @@ -93,9 +99,15 @@ def get_chat_session( chat_session_id=session_id, description=chat_session.description, persona_id=chat_session.persona_id, + persona_name=chat_session.persona.name, messages=[ - translate_db_message_to_chat_message_detail(msg) for msg in session_messages + translate_db_message_to_chat_message_detail( + msg, remove_doc_content=is_shared # if shared, don't leak doc content + ) + for msg in session_messages ], + time_created=chat_session.time_created, + shared_status=chat_session.shared_status, ) @@ -133,7 +145,12 @@ def rename_chat_session( logger.info(f"Received rename request for chat session: {chat_session_id}") if name: - update_chat_session(user_id, chat_session_id, name, db_session) + update_chat_session( + db_session=db_session, + user_id=user_id, + chat_session_id=chat_session_id, + description=name, + ) return RenameChatSessionResponse(new_name=name) final_msg, history_msgs = create_chat_chain( @@ -143,11 +160,33 @@ def rename_chat_session( new_name = get_renamed_conversation_name(full_history=full_history) - update_chat_session(user_id, chat_session_id, new_name, db_session) + update_chat_session( + db_session=db_session, + user_id=user_id, + chat_session_id=chat_session_id, + description=new_name, + ) return RenameChatSessionResponse(new_name=new_name) +@router.patch("/chat-session/{session_id}") +def patch_chat_session( + session_id: int, + chat_session_update_req: ChatSessionUpdateRequest, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> None: + user_id = user.id if user is not None else None + update_chat_session( + db_session=db_session, + user_id=user_id, + chat_session_id=session_id, + sharing_status=chat_session_update_req.sharing_status, + ) + return None + + @router.delete("/delete-chat-session/{session_id}") def delete_chat_session_by_id( session_id: int, diff --git a/backend/danswer/server/query_and_chat/models.py b/backend/danswer/server/query_and_chat/models.py index 592a4bdf27..054cb3179e 100644 --- a/backend/danswer/server/query_and_chat/models.py +++ b/backend/danswer/server/query_and_chat/models.py @@ -8,6 +8,7 @@ from danswer.configs.constants import DocumentSource from danswer.configs.constants import MessageType from danswer.configs.constants import SearchFeedbackType +from danswer.db.models import ChatSessionSharedStatus from danswer.search.models import BaseFilters from danswer.search.models import RetrievalDetails from danswer.search.models import SearchDoc @@ -120,6 +121,10 @@ class ChatRenameRequest(BaseModel): name: str | None = None +class ChatSessionUpdateRequest(BaseModel): + sharing_status: ChatSessionSharedStatus + + class RenameChatSessionResponse(BaseModel): new_name: str # This is only really useful if the name is generated @@ -129,6 +134,7 
@@ class ChatSessionDetails(BaseModel): name: str persona_id: int time_created: str + shared_status: ChatSessionSharedStatus class ChatSessionsResponse(BaseModel): @@ -174,7 +180,10 @@ class ChatSessionDetailResponse(BaseModel): chat_session_id: int description: str persona_id: int + persona_name: str messages: list[ChatMessageDetail] + time_created: datetime + shared_status: ChatSessionSharedStatus class QueryValidationResponse(BaseModel): diff --git a/web/package-lock.json b/web/package-lock.json index 41b3655089..85323bd857 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -12,6 +12,7 @@ "@dnd-kit/modifiers": "^7.0.0", "@dnd-kit/sortable": "^8.0.0", "@phosphor-icons/react": "^2.0.8", + "@radix-ui/react-popover": "^1.0.7", "@tremor/react": "^3.9.2", "@types/js-cookie": "^3.0.3", "@types/node": "18.15.11", @@ -714,14 +715,19 @@ } }, "node_modules/@floating-ui/dom": { - "version": "1.5.3", - "resolved": "https://registry.npmjs.org/@floating-ui/dom/-/dom-1.5.3.tgz", - "integrity": "sha512-ClAbQnEqJAKCJOEbbLo5IUlZHkNszqhuxS4fHAVxRPXPya6Ysf2G8KypnYcOTpx6I8xcgF9bbHb6g/2KpbV8qA==", + "version": "1.6.3", + "resolved": "https://registry.npmjs.org/@floating-ui/dom/-/dom-1.6.3.tgz", + "integrity": "sha512-RnDthu3mzPlQ31Ss/BTwQ1zjzIhr3lk1gZB1OC56h/1vEtaXkESrOqL5fQVMfXpwGtRwX+YsZBdyHtJMQnkArw==", "dependencies": { - "@floating-ui/core": "^1.4.2", - "@floating-ui/utils": "^0.1.3" + "@floating-ui/core": "^1.0.0", + "@floating-ui/utils": "^0.2.0" } }, + "node_modules/@floating-ui/dom/node_modules/@floating-ui/utils": { + "version": "0.2.1", + "resolved": "https://registry.npmjs.org/@floating-ui/utils/-/utils-0.2.1.tgz", + "integrity": "sha512-9TANp6GPoMtYzQdt54kfAyMmz1+osLlXdg2ENroU7zzrtflTLrrC/lgrIfaSe+Wu0b89GKccT7vxXA0MoAIO+Q==" + }, "node_modules/@floating-ui/react": { "version": "0.19.2", "resolved": "https://registry.npmjs.org/@floating-ui/react/-/react-0.19.2.tgz", @@ -1148,6 +1154,441 @@ "node": ">=14" } }, + "node_modules/@radix-ui/primitive": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@radix-ui/primitive/-/primitive-1.0.1.tgz", + "integrity": "sha512-yQ8oGX2GVsEYMWGxcovu1uGWPCxV5BFfeeYxqPmuAzUyLT9qmaMXSAhXpb0WrspIeqYzdJpkh2vHModJPgRIaw==", + "dependencies": { + "@babel/runtime": "^7.13.10" + } + }, + "node_modules/@radix-ui/react-arrow": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-arrow/-/react-arrow-1.0.3.tgz", + "integrity": "sha512-wSP+pHsB/jQRaL6voubsQ/ZlrGBHHrOjmBnr19hxYgtS0WvAFwZhK2WP/YY5yF9uKECCEEDGxuLxq1NBK51wFA==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/react-primitive": "1.0.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-compose-refs": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-compose-refs/-/react-compose-refs-1.0.1.tgz", + "integrity": "sha512-fDSBgd44FKHa1FRMU59qBMPFcl2PZE+2nmqunj+BWFyYYjnhIDWL2ItDs3rrbJDQOtzt5nIebLCQc4QRfz6LJw==", + "dependencies": { + "@babel/runtime": "^7.13.10" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-context": { + "version": "1.0.1", + "resolved": 
"https://registry.npmjs.org/@radix-ui/react-context/-/react-context-1.0.1.tgz", + "integrity": "sha512-ebbrdFoYTcuZ0v4wG5tedGnp9tzcV8awzsxYph7gXUyvnNLuTIcCk1q17JEbnVhXAKG9oX3KtchwiMIAYp9NLg==", + "dependencies": { + "@babel/runtime": "^7.13.10" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-dismissable-layer": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/@radix-ui/react-dismissable-layer/-/react-dismissable-layer-1.0.5.tgz", + "integrity": "sha512-aJeDjQhywg9LBu2t/At58hCvr7pEm0o2Ke1x33B+MhjNmmZ17sy4KImo0KPLgsnc/zN7GPdce8Cnn0SWvwZO7g==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/primitive": "1.0.1", + "@radix-ui/react-compose-refs": "1.0.1", + "@radix-ui/react-primitive": "1.0.3", + "@radix-ui/react-use-callback-ref": "1.0.1", + "@radix-ui/react-use-escape-keydown": "1.0.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-focus-guards": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-focus-guards/-/react-focus-guards-1.0.1.tgz", + "integrity": "sha512-Rect2dWbQ8waGzhMavsIbmSVCgYxkXLxxR3ZvCX79JOglzdEy4JXMb98lq4hPxUbLr77nP0UOGf4rcMU+s1pUA==", + "dependencies": { + "@babel/runtime": "^7.13.10" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-focus-scope": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/@radix-ui/react-focus-scope/-/react-focus-scope-1.0.4.tgz", + "integrity": "sha512-sL04Mgvf+FmyvZeYfNu1EPAaaxD+aw7cYeIB9L9Fvq8+urhltTRaEo5ysKOpHuKPclsZcSUMKlN05x4u+CINpA==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/react-compose-refs": "1.0.1", + "@radix-ui/react-primitive": "1.0.3", + "@radix-ui/react-use-callback-ref": "1.0.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-id": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-id/-/react-id-1.0.1.tgz", + "integrity": "sha512-tI7sT/kqYp8p96yGWY1OAnLHrqDgzHefRBKQ2YAkBS5ja7QLcZ9Z/uY7bEjPUatf8RomoXM8/1sMj1IJaE5UzQ==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/react-use-layout-effect": "1.0.1" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-popover": { + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-popover/-/react-popover-1.0.7.tgz", + "integrity": "sha512-shtvVnlsxT6faMnK/a7n0wptwBD23xc1Z5mdrtKLwVEfsEMXodS0r5s0/g5P0hX//EKYZS2sxUjqfzlg52ZSnQ==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/primitive": "1.0.1", + "@radix-ui/react-compose-refs": "1.0.1", + "@radix-ui/react-context": "1.0.1", + 
"@radix-ui/react-dismissable-layer": "1.0.5", + "@radix-ui/react-focus-guards": "1.0.1", + "@radix-ui/react-focus-scope": "1.0.4", + "@radix-ui/react-id": "1.0.1", + "@radix-ui/react-popper": "1.1.3", + "@radix-ui/react-portal": "1.0.4", + "@radix-ui/react-presence": "1.0.1", + "@radix-ui/react-primitive": "1.0.3", + "@radix-ui/react-slot": "1.0.2", + "@radix-ui/react-use-controllable-state": "1.0.1", + "aria-hidden": "^1.1.1", + "react-remove-scroll": "2.5.5" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-popper": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-popper/-/react-popper-1.1.3.tgz", + "integrity": "sha512-cKpopj/5RHZWjrbF2846jBNacjQVwkP068DfmgrNJXpvVWrOvlAmE9xSiy5OqeE+Gi8D9fP+oDhUnPqNMY8/5w==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@floating-ui/react-dom": "^2.0.0", + "@radix-ui/react-arrow": "1.0.3", + "@radix-ui/react-compose-refs": "1.0.1", + "@radix-ui/react-context": "1.0.1", + "@radix-ui/react-primitive": "1.0.3", + "@radix-ui/react-use-callback-ref": "1.0.1", + "@radix-ui/react-use-layout-effect": "1.0.1", + "@radix-ui/react-use-rect": "1.0.1", + "@radix-ui/react-use-size": "1.0.1", + "@radix-ui/rect": "1.0.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-popper/node_modules/@floating-ui/react-dom": { + "version": "2.0.8", + "resolved": "https://registry.npmjs.org/@floating-ui/react-dom/-/react-dom-2.0.8.tgz", + "integrity": "sha512-HOdqOt3R3OGeTKidaLvJKcgg75S6tibQ3Tif4eyd91QnIJWr0NLvoXFpJA/j8HqkFSL68GDca9AuyWEHlhyClw==", + "dependencies": { + "@floating-ui/dom": "^1.6.1" + }, + "peerDependencies": { + "react": ">=16.8.0", + "react-dom": ">=16.8.0" + } + }, + "node_modules/@radix-ui/react-portal": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/@radix-ui/react-portal/-/react-portal-1.0.4.tgz", + "integrity": "sha512-Qki+C/EuGUVCQTOTD5vzJzJuMUlewbzuKyUy+/iHM2uwGiru9gZeBJtHAPKAEkB5KWGi9mP/CHKcY0wt1aW45Q==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/react-primitive": "1.0.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-presence": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-presence/-/react-presence-1.0.1.tgz", + "integrity": "sha512-UXLW4UAbIY5ZjcvzjfRFo5gxva8QirC9hF7wRE4U5gz+TP0DbRk+//qyuAQ1McDxBt1xNMBTaciFGvEmJvAZCg==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/react-compose-refs": "1.0.1", + "@radix-ui/react-use-layout-effect": "1.0.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": 
true + } + } + }, + "node_modules/@radix-ui/react-primitive": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-primitive/-/react-primitive-1.0.3.tgz", + "integrity": "sha512-yi58uVyoAcK/Nq1inRY56ZSjKypBNKTa/1mcL8qdl6oJeEaDbOldlzrGn7P6Q3Id5d+SYNGc5AJgc4vGhjs5+g==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/react-slot": "1.0.2" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-slot": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.0.2.tgz", + "integrity": "sha512-YeTpuq4deV+6DusvVUW4ivBgnkHwECUu0BiN43L5UCDFgdhsRUWAghhTF5MbvNTPzmiFOx90asDSUjWuCNapwg==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/react-compose-refs": "1.0.1" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-use-callback-ref": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-use-callback-ref/-/react-use-callback-ref-1.0.1.tgz", + "integrity": "sha512-D94LjX4Sp0xJFVaoQOd3OO9k7tpBYNOXdVhkltUbGv2Qb9OXdrg/CpsjlZv7ia14Sylv398LswWBVVu5nqKzAQ==", + "dependencies": { + "@babel/runtime": "^7.13.10" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-use-controllable-state": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-use-controllable-state/-/react-use-controllable-state-1.0.1.tgz", + "integrity": "sha512-Svl5GY5FQeN758fWKrjM6Qb7asvXeiZltlT4U2gVfl8Gx5UAv2sMR0LWo8yhsIZh2oQ0eFdZ59aoOOMV7b47VA==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/react-use-callback-ref": "1.0.1" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-use-escape-keydown": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-use-escape-keydown/-/react-use-escape-keydown-1.0.3.tgz", + "integrity": "sha512-vyL82j40hcFicA+M4Ex7hVkB9vHgSse1ZWomAqV2Je3RleKGO5iM8KMOEtfoSB0PnIelMd2lATjTGMYqN5ylTg==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/react-use-callback-ref": "1.0.1" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-use-layout-effect": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-use-layout-effect/-/react-use-layout-effect-1.0.1.tgz", + "integrity": "sha512-v/5RegiJWYdoCvMnITBkNNx6bCj20fiaJnWtRkU18yITptraXjffz5Qbn05uOiQnOvi+dbkznkoaMltz1GnszQ==", + "dependencies": { + "@babel/runtime": "^7.13.10" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-use-rect": { + "version": "1.0.1", + "resolved": 
"https://registry.npmjs.org/@radix-ui/react-use-rect/-/react-use-rect-1.0.1.tgz", + "integrity": "sha512-Cq5DLuSiuYVKNU8orzJMbl15TXilTnJKUCltMVQg53BQOF1/C5toAaGrowkgksdBQ9H+SRL23g0HDmg9tvmxXw==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/rect": "1.0.1" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-use-size": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-use-size/-/react-use-size-1.0.1.tgz", + "integrity": "sha512-ibay+VqrgcaI6veAojjofPATwledXiSmX+C0KrBk/xgpX9rBzPV3OsfwlhQdUOFbh+LKQorLYT+xTXW9V8yd0g==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/react-use-layout-effect": "1.0.1" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/rect": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@radix-ui/rect/-/rect-1.0.1.tgz", + "integrity": "sha512-fyrgCaedtvMg9NK3en0pnOYJdtfwxUcNolezkNPUsoX57X8oQk+NkqcvzHXD2uKNij6GXmWU9NDru2IWjrO4BQ==", + "dependencies": { + "@babel/runtime": "^7.13.10" + } + }, "node_modules/@rushstack/eslint-patch": { "version": "1.6.0", "resolved": "https://registry.npmjs.org/@rushstack/eslint-patch/-/eslint-patch-1.6.0.tgz", @@ -2475,6 +2916,11 @@ "node": ">=8" } }, + "node_modules/detect-node-es": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/detect-node-es/-/detect-node-es-1.1.0.tgz", + "integrity": "sha512-ypdmJU/TbBby2Dxibuv7ZLW3Bs1QEmM7nHjEANfohJLvE0XVujisn1qPJcZxg+qDucsr+bP6fLD1rPS3AhJ7EQ==" + }, "node_modules/devlop": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/devlop/-/devlop-1.1.0.tgz", @@ -3418,6 +3864,14 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/get-nonce": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/get-nonce/-/get-nonce-1.0.1.tgz", + "integrity": "sha512-FJhYRoDaiatfEkUK8HKlicmu/3SGFD51q3itKDGoSTysQJBnfOcxU5GxnhE1E6soB76MbT0MBtnKJuXyAx+96Q==", + "engines": { + "node": ">=6" + } + }, "node_modules/get-symbol-description": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/get-symbol-description/-/get-symbol-description-1.0.0.tgz", @@ -3789,6 +4243,14 @@ "node": ">=12" } }, + "node_modules/invariant": { + "version": "2.2.4", + "resolved": "https://registry.npmjs.org/invariant/-/invariant-2.2.4.tgz", + "integrity": "sha512-phJfQVBuaJM5raOpJjSfkiD6BpbCE4Ns//LaXl6wGYtUBY83nWS6Rf9tXm2e8VaK60JEjYldbPif/A2B1C2gNA==", + "dependencies": { + "loose-envify": "^1.0.0" + } + }, "node_modules/is-alphabetical": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/is-alphabetical/-/is-alphabetical-2.0.1.tgz", @@ -6206,6 +6668,51 @@ "react": ">=18" } }, + "node_modules/react-remove-scroll": { + "version": "2.5.5", + "resolved": "https://registry.npmjs.org/react-remove-scroll/-/react-remove-scroll-2.5.5.tgz", + "integrity": "sha512-ImKhrzJJsyXJfBZ4bzu8Bwpka14c/fQt0k+cyFp/PBhTfyDnU5hjOtM4AG/0AMyy8oKzOTR0lDgJIM7pYXI0kw==", + "dependencies": { + "react-remove-scroll-bar": "^2.3.3", + "react-style-singleton": "^2.2.1", + "tslib": "^2.1.0", + "use-callback-ref": "^1.3.0", + "use-sidecar": "^1.1.2" + }, + "engines": { + "node": ">=10" + }, + "peerDependencies": { + "@types/react": "^16.8.0 || ^17.0.0 || ^18.0.0", + "react": "^16.8.0 || ^17.0.0 || ^18.0.0" + }, + "peerDependenciesMeta": 
{ + "@types/react": { + "optional": true + } + } + }, + "node_modules/react-remove-scroll-bar": { + "version": "2.3.6", + "resolved": "https://registry.npmjs.org/react-remove-scroll-bar/-/react-remove-scroll-bar-2.3.6.tgz", + "integrity": "sha512-DtSYaao4mBmX+HDo5YWYdBWQwYIQQshUV/dVxFxK+KM26Wjwp1gZ6rv6OC3oujI6Bfu6Xyg3TwK533AQutsn/g==", + "dependencies": { + "react-style-singleton": "^2.2.1", + "tslib": "^2.0.0" + }, + "engines": { + "node": ">=10" + }, + "peerDependencies": { + "@types/react": "^16.8.0 || ^17.0.0 || ^18.0.0", + "react": "^16.8.0 || ^17.0.0 || ^18.0.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, "node_modules/react-smooth": { "version": "2.0.5", "resolved": "https://registry.npmjs.org/react-smooth/-/react-smooth-2.0.5.tgz", @@ -6243,6 +6750,28 @@ "react-dom": ">=15.0.0" } }, + "node_modules/react-style-singleton": { + "version": "2.2.1", + "resolved": "https://registry.npmjs.org/react-style-singleton/-/react-style-singleton-2.2.1.tgz", + "integrity": "sha512-ZWj0fHEMyWkHzKYUr2Bs/4zU6XLmq9HsgBURm7g5pAVfyn49DgUiNgY2d4lXRlYSiCif9YBGpQleewkcqddc7g==", + "dependencies": { + "get-nonce": "^1.0.0", + "invariant": "^2.2.4", + "tslib": "^2.0.0" + }, + "engines": { + "node": ">=10" + }, + "peerDependencies": { + "@types/react": "^16.8.0 || ^17.0.0 || ^18.0.0", + "react": "^16.8.0 || ^17.0.0 || ^18.0.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, "node_modules/react-transition-group": { "version": "4.4.5", "resolved": "https://registry.npmjs.org/react-transition-group/-/react-transition-group-4.4.5.tgz", @@ -7608,6 +8137,47 @@ "punycode": "^2.1.0" } }, + "node_modules/use-callback-ref": { + "version": "1.3.2", + "resolved": "https://registry.npmjs.org/use-callback-ref/-/use-callback-ref-1.3.2.tgz", + "integrity": "sha512-elOQwe6Q8gqZgDA8mrh44qRTQqpIHDcZ3hXTLjBe1i4ph8XpNJnO+aQf3NaG+lriLopI4HMx9VjQLfPQ6vhnoA==", + "dependencies": { + "tslib": "^2.0.0" + }, + "engines": { + "node": ">=10" + }, + "peerDependencies": { + "@types/react": "^16.8.0 || ^17.0.0 || ^18.0.0", + "react": "^16.8.0 || ^17.0.0 || ^18.0.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/use-sidecar": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/use-sidecar/-/use-sidecar-1.1.2.tgz", + "integrity": "sha512-epTbsLuzZ7lPClpz2TyryBfztm7m+28DlEv2ZCQ3MDr5ssiwyOwGH/e5F9CkfWjJ1t4clvI58yF822/GUkjjhw==", + "dependencies": { + "detect-node-es": "^1.1.0", + "tslib": "^2.0.0" + }, + "engines": { + "node": ">=10" + }, + "peerDependencies": { + "@types/react": "^16.9.0 || ^17.0.0 || ^18.0.0", + "react": "^16.8.0 || ^17.0.0 || ^18.0.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, "node_modules/use-sync-external-store": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/use-sync-external-store/-/use-sync-external-store-1.2.0.tgz", diff --git a/web/package.json b/web/package.json index 4d3fc3794b..37788280d4 100644 --- a/web/package.json +++ b/web/package.json @@ -13,6 +13,7 @@ "@dnd-kit/modifiers": "^7.0.0", "@dnd-kit/sortable": "^8.0.0", "@phosphor-icons/react": "^2.0.8", + "@radix-ui/react-popover": "^1.0.7", "@tremor/react": "^3.9.2", "@types/js-cookie": "^3.0.3", "@types/node": "18.15.11", diff --git a/web/src/app/chat/Chat.tsx b/web/src/app/chat/Chat.tsx index c2851c8be2..9430c4e668 100644 --- a/web/src/app/chat/Chat.tsx +++ b/web/src/app/chat/Chat.tsx @@ -1,12 +1,13 @@ "use client"; import { useEffect, useRef, useState } from 
"react"; -import { FiRefreshCcw, FiSend, FiStopCircle } from "react-icons/fi"; +import { FiSend, FiShare2, FiStopCircle } from "react-icons/fi"; import { AIMessage, HumanMessage } from "./message/Messages"; import { AnswerPiecePacket, DanswerDocument } from "@/lib/search/interfaces"; import { BackendChatSession, BackendMessage, + ChatSessionSharedStatus, DocumentsResponse, Message, RetrievalType, @@ -44,6 +45,7 @@ import { HEADER_PADDING } from "@/lib/constants"; import { computeAvailableFilters } from "@/lib/filters"; import { useDocumentSelection } from "./useDocumentSelection"; import { StarterMessage } from "./StarterMessage"; +import { ShareChatSessionModal } from "./modal/ShareChatSessionModal"; const MAX_INPUT_HEIGHT = 200; @@ -114,6 +116,7 @@ export const Chat = ({ setSelectedPersona(undefined); } setMessageHistory([]); + setChatSessionSharedStatus(ChatSessionSharedStatus.Private); return; } @@ -127,6 +130,7 @@ export const Chat = ({ (persona) => persona.id === chatSession.persona_id ) ); + const newMessageHistory = processRawChatHistory(chatSession.messages); setMessageHistory(newMessageHistory); @@ -136,6 +140,8 @@ export const Chat = ({ latestMessageId !== undefined ? latestMessageId : null ); + setChatSessionSharedStatus(chatSession.shared_status); + setIsFetchingChatMessages(false); } @@ -173,6 +179,9 @@ export const Chat = ({ ); const livePersona = selectedPersona || availablePersonas[0]; + const [chatSessionSharedStatus, setChatSessionSharedStatus] = + useState(ChatSessionSharedStatus.Private); + useEffect(() => { if (messageHistory.length === 0 && chatSessionId === null) { setSelectedPersona( @@ -225,6 +234,8 @@ export const Chat = ({ const [currentFeedback, setCurrentFeedback] = useState< [FeedbackType, number] | null >(null); + const [sharingModalVisible, setSharingModalVisible] = + useState(false); // auto scroll as message comes out const scrollableDivRef = useRef(null); @@ -503,6 +514,21 @@ export const Chat = ({ /> )} + {sharingModalVisible && chatSessionId !== null && ( + setSharingModalVisible(false)} + onShare={(shared) => + setChatSessionSharedStatus( + shared + ? ChatSessionSharedStatus.Public + : ChatSessionSharedStatus.Private + ) + } + /> + )} + {documentSidebarInitialWidth !== undefined ? ( <>
{livePersona && ( -
+
+
+                {chatSessionId !== null && (
+                  <div
+                    onClick={() => setSharingModalVisible(true)}
+                    className="ml-auto mr-6 my-auto border-border border p-2 rounded cursor-pointer hover:bg-hover-light"
+                  >
+                    <FiShare2 />
+                  </div>
+ )}
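+                {/*
+                  The modal opened by this button drives sharing through a
+                  single endpoint. A minimal sketch of the round trip (the
+                  `toggleSharing` helper is illustrative only, not part of
+                  this patch):
+
+                    async function toggleSharing(id: number, makePublic: boolean) {
+                      const res = await fetch(`/api/chat/chat-session/${id}`, {
+                        method: "PATCH",
+                        headers: { "Content-Type": "application/json" },
+                        body: JSON.stringify({
+                          sharing_status: makePublic ? "public" : "private",
+                        }),
+                      });
+                      return res.ok;
+                    }
+
+                  ShareChatSessionModal's generateShareLink and deleteShareLink
+                  are essentially this call with the two statuses fixed.
+                */}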
)} diff --git a/web/src/app/chat/interfaces.ts b/web/src/app/chat/interfaces.ts index 7eb9f50ce9..3ef716720b 100644 --- a/web/src/app/chat/interfaces.ts +++ b/web/src/app/chat/interfaces.ts @@ -6,6 +6,11 @@ export enum RetrievalType { SelectedDocs = "selectedDocs", } +export enum ChatSessionSharedStatus { + Private = "private", + Public = "public", +} + export interface RetrievalDetails { run_search: "always" | "never" | "auto"; real_time: boolean; @@ -20,6 +25,7 @@ export interface ChatSession { name: string; persona_id: number; time_created: string; + shared_status: ChatSessionSharedStatus; } export interface Message { @@ -36,7 +42,10 @@ export interface BackendChatSession { chat_session_id: number; description: string; persona_id: number; + persona_name: string; messages: BackendMessage[]; + time_created: string; + shared_status: ChatSessionSharedStatus; } export interface BackendMessage { diff --git a/web/src/app/chat/message/Messages.tsx b/web/src/app/chat/message/Messages.tsx index a4ab222d93..d90d3bfa30 100644 --- a/web/src/app/chat/message/Messages.tsx +++ b/web/src/app/chat/message/Messages.tsx @@ -15,6 +15,7 @@ import { SourceIcon } from "@/components/SourceIcon"; import { ThreeDots } from "react-loader-spinner"; import { SkippedSearch } from "./SkippedSearch"; import remarkGfm from "remark-gfm"; +import { CopyButton } from "@/components/CopyButton"; export const Hoverable: React.FC<{ children: JSX.Element; @@ -22,7 +23,7 @@ export const Hoverable: React.FC<{ }> = ({ children, onClick }) => { return (
      {children}
    </div>
  );
@@ -201,15 +202,7 @@
{handleFeedback && (
- { - navigator.clipboard.writeText(content.toString()); - setCopyClicked(true); - setTimeout(() => setCopyClicked(false), 3000); - }} - > - {copyClicked ? : } - + handleFeedback("like")}> diff --git a/web/src/app/chat/modal/ShareChatSessionModal.tsx b/web/src/app/chat/modal/ShareChatSessionModal.tsx new file mode 100644 index 0000000000..5a00c67390 --- /dev/null +++ b/web/src/app/chat/modal/ShareChatSessionModal.tsx @@ -0,0 +1,160 @@ +import { useState } from "react"; +import { ModalWrapper } from "./ModalWrapper"; +import { Button, Callout, Divider, Text } from "@tremor/react"; +import { Spinner } from "@/components/Spinner"; +import { ChatSessionSharedStatus } from "../interfaces"; +import { FiCopy, FiX } from "react-icons/fi"; +import { Hoverable } from "../message/Messages"; +import { CopyButton } from "@/components/CopyButton"; + +function buildShareLink(chatSessionId: number) { + const baseUrl = `${window.location.protocol}//${window.location.host}`; + return `${baseUrl}/chat/shared/${chatSessionId}`; +} + +async function generateShareLink(chatSessionId: number) { + const response = await fetch(`/api/chat/chat-session/${chatSessionId}`, { + method: "PATCH", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ sharing_status: "public" }), + }); + + if (response.ok) { + return buildShareLink(chatSessionId); + } + return null; +} + +async function deleteShareLink(chatSessionId: number) { + const response = await fetch(`/api/chat/chat-session/${chatSessionId}`, { + method: "PATCH", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ sharing_status: "private" }), + }); + + return response.ok; +} + +export function ShareChatSessionModal({ + chatSessionId, + existingSharedStatus, + onShare, + onClose, +}: { + chatSessionId: number; + existingSharedStatus: ChatSessionSharedStatus; + onShare?: (shared: boolean) => void; + onClose: () => void; +}) { + const [linkGenerating, setLinkGenerating] = useState(false); + const [shareLink, setShareLink] = useState( + existingSharedStatus === ChatSessionSharedStatus.Public + ? buildShareLink(chatSessionId) + : "" + ); + + return ( + + <> +
+

+ Share link to Chat +

+ +
+ +
+
+
+        {linkGenerating && <Spinner />}
+
+ {shareLink ? ( +
+ + This chat session is currently shared. Anyone at your + organization can view the message history using the following + link: + + + + + + + + Click the button below to make the chat private again. + + + +
+ ) : ( +
+ + Ensure that all content in the chat is safe to share with the + whole organization. The content of the retrieved documents will + not be visible, but the names of cited documents as well as the + AI and human messages will be visible. + + + +
+ )} +
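+        {/*
+          Both branches above lean on the helpers defined at the top of this
+          file: generateShareLink PATCHes { sharing_status: "public" } and
+          resolves to the share URL (null on a non-2xx response), while
+          deleteShareLink PATCHes { sharing_status: "private" } and resolves
+          to response.ok. An illustrative call, not part of this patch:
+
+            const link = await generateShareLink(chatSessionId);
+            if (link) {
+              navigator.clipboard.writeText(link);
+            }
+        */}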
+ +
+ ); +} diff --git a/web/src/app/chat/page.tsx b/web/src/app/chat/page.tsx index 3106319e74..57de9de805 100644 --- a/web/src/app/chat/page.tsx +++ b/web/src/app/chat/page.tsx @@ -43,7 +43,6 @@ export default async function Page({ fetchSS("/persona?include_default=true"), fetchSS("/chat/get-user-chat-sessions"), fetchSS("/query/valid-tags"), - fetchSS("/secondary-index/get-embedding-models"), ]; // catch cases where the backend is completely unreachable here @@ -68,7 +67,6 @@ export default async function Page({ const personasResponse = results[4] as Response | null; const chatSessionsResponse = results[5] as Response | null; const tagsResponse = results[6] as Response | null; - const embeddingModelResponse = results[7] as Response | null; const authDisabled = authTypeMetadata?.authType === "disabled"; if (!authDisabled && !user) { @@ -130,15 +128,6 @@ export default async function Page({ console.log(`Failed to fetch tags - ${tagsResponse?.status}`); } - const embeddingModelVersionInfo = - embeddingModelResponse && embeddingModelResponse.ok - ? ((await embeddingModelResponse.json()) as FullEmbeddingModelResponse) - : null; - const currentEmbeddingModelName = - embeddingModelVersionInfo?.current_model_name; - const nextEmbeddingModelName = - embeddingModelVersionInfo?.secondary_model_name; - const defaultPersonaIdRaw = searchParams["personaId"]; const defaultPersonaId = defaultPersonaIdRaw ? parseInt(defaultPersonaIdRaw) diff --git a/web/src/app/chat/sessionSidebar/ChatSidebar.tsx b/web/src/app/chat/sessionSidebar/ChatSidebar.tsx index 19f6eeaccf..199fab3e50 100644 --- a/web/src/app/chat/sessionSidebar/ChatSidebar.tsx +++ b/web/src/app/chat/sessionSidebar/ChatSidebar.tsx @@ -78,8 +78,8 @@ export const ChatSidebar = ({ return (
{ - const isSelected = currentChatId === chat.id; - return ( -
- -
- ); - })} */}
{ @@ -33,6 +49,14 @@ export function ChatSessionDisplay({ return ( <> + {isShareModalVisible && ( + setIsShareModalVisible(false)} + /> + )} + {isDeletionModalVisible && ( setIsDeletionModalVisible(false)} @@ -50,69 +74,107 @@ export function ChatSessionDisplay({ /> )} -
-
- -
{" "} - {isRenamingChat ? ( - setChatName(e.target.value)} - onKeyDown={(event) => { - if (event.key === "Enter") { - onRename(); - event.preventDefault(); - } - }} - className="-my-px px-1 mr-2 w-full rounded" - /> - ) : ( -

- {chatName || `Chat ${chatSession.id}`} -

- )} - {isSelected && - (isRenamingChat ? ( -
-
- -
-
-              <div
-                onClick={() => {
-                  setChatName(chatSession.name);
-                  setIsRenamingChat(false);
-                }}
-                className={`hover:bg-black/10 p-1 -m-1 rounded ml-2`}
-              >
-                <FiX size={16} />
-              </div>
-
+ <> +
+
+ +
{" "} + {isRenamingChat ? ( + setChatName(e.target.value)} + onKeyDown={(event) => { + if (event.key === "Enter") { + onRename(); + event.preventDefault(); + } + }} + className="-my-px px-1 mr-2 w-full rounded" + /> ) : ( -
-
setIsRenamingChat(true)} - className={`hover:bg-black/10 p-1 -m-1 rounded`} - > - +

+ {chatName || `Chat ${chatSession.id}`} +

+ )} + {isSelected && + (isRenamingChat ? ( +
+
+ +
+
+                      <div
+                        onClick={() => {
+                          setChatName(chatSession.name);
+                          setIsRenamingChat(false);
+                        }}
+                        className={`hover:bg-black/10 p-1 -m-1 rounded ml-2`}
+                      >
+                        <FiX size={16} />
+                      </div>
-
setIsDeletionModalVisible(true)} - className={`hover:bg-black/10 p-1 -m-1 rounded ml-2`} - > - + ) : ( +
+
+
{ + setIsMoreOptionsDropdownOpen( + !isMoreOptionsDropdownOpen + ); + }} + className={"-m-1"} + > + + setIsMoreOptionsDropdownOpen(open) + } + content={ +
+ +
+ } + popover={ +
+ setIsShareModalVisible(true)} + /> + setIsRenamingChat(true)} + /> +
+ } + /> +
+
+
setIsDeletionModalVisible(true)} + className={`hover:bg-black/10 p-1 -m-1 rounded ml-2`} + > + +
-
- ))} -
+ ))} +
+ {isSelected && !isRenamingChat && ( +
+ )} + {!isSelected && ( +
+ )} + diff --git a/web/src/app/chat/shared/[chatId]/SharedChatDisplay.tsx b/web/src/app/chat/shared/[chatId]/SharedChatDisplay.tsx new file mode 100644 index 0000000000..63d51f2398 --- /dev/null +++ b/web/src/app/chat/shared/[chatId]/SharedChatDisplay.tsx @@ -0,0 +1,92 @@ +"use client"; + +import { humanReadableFormat } from "@/lib/time"; +import { BackendChatSession } from "../../interfaces"; +import { getCitedDocumentsFromMessage, processRawChatHistory } from "../../lib"; +import { AIMessage, HumanMessage } from "../../message/Messages"; +import { Button, Callout, Divider } from "@tremor/react"; +import { useRouter } from "next/navigation"; + +function BackToDanswerButton() { + const router = useRouter(); + + return ( +
+
+ +
+
+ ); +} + +export function SharedChatDisplay({ + chatSession, +}: { + chatSession: BackendChatSession | null; +}) { + if (!chatSession) { + return ( +
+
+ + Did not find a shared chat with the specified ID. + +
+ + +
+ ); + } + + const messages = processRawChatHistory(chatSession.messages); + + return ( +
+
+
+
+
+

+ {chatSession.description || + `Chat ${chatSession.chat_session_id}`} +

+

+ {humanReadableFormat(chatSession.time_created)} +

+ + +
+ +
+ {messages.map((message) => { + if (message.type === "user") { + return ( + + ); + } else { + return ( + + ); + } + })} +
+
+
+
+ + +
+ ); +} diff --git a/web/src/app/chat/shared/[chatId]/page.tsx b/web/src/app/chat/shared/[chatId]/page.tsx new file mode 100644 index 0000000000..8587c211e6 --- /dev/null +++ b/web/src/app/chat/shared/[chatId]/page.tsx @@ -0,0 +1,73 @@ +import { User } from "@/lib/types"; +import { + AuthTypeMetadata, + getAuthTypeMetadataSS, + getCurrentUserSS, +} from "@/lib/userSS"; +import { fetchSS } from "@/lib/utilsSS"; +import { redirect } from "next/navigation"; +import { BackendChatSession } from "../../interfaces"; +import { Header } from "@/components/Header"; +import { SharedChatDisplay } from "./SharedChatDisplay"; + +async function getSharedChat(chatId: string) { + const response = await fetchSS( + `/chat/get-chat-session/${chatId}?is_shared=True` + ); + if (response.ok) { + return await response.json(); + } + return null; +} + +export default async function Page({ params }: { params: { chatId: string } }) { + const tasks = [ + getAuthTypeMetadataSS(), + getCurrentUserSS(), + getSharedChat(params.chatId), + ]; + + // catch cases where the backend is completely unreachable here + // without try / catch, will just raise an exception and the page + // will not render + let results: (User | AuthTypeMetadata | null)[] = [ + null, + null, + null, + null, + null, + null, + null, + null, + null, + ]; + try { + results = await Promise.all(tasks); + } catch (e) { + console.log(`Some fetch failed for the main search page - ${e}`); + } + const authTypeMetadata = results[0] as AuthTypeMetadata | null; + const user = results[1] as User | null; + const chatSession = results[2] as BackendChatSession | null; + + const authDisabled = authTypeMetadata?.authType === "disabled"; + if (!authDisabled && !user) { + return redirect("/auth/login"); + } + + if (user && !user.is_verified && authTypeMetadata?.requiresVerification) { + return redirect("/auth/waiting-on-verification"); + } + + return ( +
+    <div>
+      <Header user={user} />
+
+      <SharedChatDisplay chatSession={chatSession} />
+    </div>
+ ); +} diff --git a/web/src/components/BasicClickable.tsx b/web/src/components/BasicClickable.tsx index 8a6b9ce04c..9184035ab3 100644 --- a/web/src/components/BasicClickable.tsx +++ b/web/src/components/BasicClickable.tsx @@ -71,7 +71,7 @@ export function BasicSelectable({ fullWidth?: boolean; }) { return ( - +
); } diff --git a/web/src/components/CopyButton.tsx b/web/src/components/CopyButton.tsx new file mode 100644 index 0000000000..7adcb8a9af --- /dev/null +++ b/web/src/components/CopyButton.tsx @@ -0,0 +1,29 @@ +import { Hoverable } from "@/app/chat/message/Messages"; +import { useState } from "react"; +import { FiCheck, FiCopy } from "react-icons/fi"; + +export function CopyButton({ + content, + onClick, +}: { + content?: string; + onClick?: () => void; +}) { + const [copyClicked, setCopyClicked] = useState(false); + + return ( + { + if (content) { + navigator.clipboard.writeText(content.toString()); + } + onClick && onClick(); + + setCopyClicked(true); + setTimeout(() => setCopyClicked(false), 3000); + }} + > + {copyClicked ? : } + + ); +} diff --git a/web/src/components/openai/ApiKeyModal.tsx b/web/src/components/openai/ApiKeyModal.tsx index a0bc5dc56e..1c38160e9d 100644 --- a/web/src/components/openai/ApiKeyModal.tsx +++ b/web/src/components/openai/ApiKeyModal.tsx @@ -19,7 +19,6 @@ export const ApiKeyModal = () => { useEffect(() => { checkApiKey().then((error) => { - console.log(error); if (error) { setErrorMsg(error); } diff --git a/web/src/components/popover/Popover.tsx b/web/src/components/popover/Popover.tsx new file mode 100644 index 0000000000..ac5d5bcf2a --- /dev/null +++ b/web/src/components/popover/Popover.tsx @@ -0,0 +1,38 @@ +"use client"; + +import * as RadixPopover from "@radix-ui/react-popover"; + +export function Popover({ + open, + onOpenChange, + content, + popover, +}: { + open: boolean; + onOpenChange: (open: boolean) => void; + content: JSX.Element; + popover: JSX.Element; +}) { + /* + This Popover is needed when we want to put a popup / dropdown in a component + with `overflow-hidden`. This is because the Radix Popover uses `absolute` positioning + outside of the component's container. + */ + if (!open) { + return content; + } + + return ( + + + {/* NOTE: this weird `-mb-1.5` is needed to offset the Anchor, otherwise + the content will shift up by 1.5px when the Popover is open. */} + {open ?
<div className="-mb-1.5">{content}</div> : content}
+      </RadixPopover.Anchor>
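+      {/*
+        Usage sketch (illustrative only, not part of this patch): the
+        always-visible trigger element goes in `content`, the floating
+        body in `popover`:
+
+          <Popover
+            open={menuOpen}
+            onOpenChange={setMenuOpen}
+            content={<button onClick={() => setMenuOpen(true)}>Menu</button>}
+            popover={<div>menu items</div>}
+          />
+
+        where `menuOpen` is assumed to be local component state;
+        ChatSessionDisplay (earlier in this patch) wires it up the same
+        way with its `isMoreOptionsDropdownOpen` state.
+      */}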
+      <RadixPopover.Portal>
+        <RadixPopover.Content>{popover}</RadixPopover.Content>
+      </RadixPopover.Portal>
+    </RadixPopover.Root>
+ ); +} diff --git a/web/src/lib/time.ts b/web/src/lib/time.ts index a6b61c5add..0ec42b2ef7 100644 --- a/web/src/lib/time.ts +++ b/web/src/lib/time.ts @@ -59,3 +59,19 @@ export function localizeAndPrettify(dateString: string) { const date = new Date(dateString); return date.toLocaleString(); } + +export function humanReadableFormat(dateString: string): string { + // Create a Date object from the dateString + const date = new Date(dateString); + + // Use Intl.DateTimeFormat to format the date + // Specify the locale as 'en-US' and options for month, day, and year + const formatter = new Intl.DateTimeFormat("en-US", { + month: "long", // full month name + day: "numeric", // numeric day + year: "numeric", // numeric year + }); + + // Format the date and return it + return formatter.format(date); +} From 22477b1acaad93b5d12a831750dfa373bcf16c3e Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Sat, 30 Mar 2024 00:18:57 -0700 Subject: [PATCH 29/58] Chunker Gmail Issue Logging (#1274) --- backend/danswer/indexing/chunker.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/backend/danswer/indexing/chunker.py b/backend/danswer/indexing/chunker.py index 9be9348b9f..b6f59d1890 100644 --- a/backend/danswer/indexing/chunker.py +++ b/backend/danswer/indexing/chunker.py @@ -5,18 +5,22 @@ from danswer.configs.app_configs import BLURB_SIZE from danswer.configs.app_configs import CHUNK_OVERLAP from danswer.configs.app_configs import MINI_CHUNK_SIZE +from danswer.configs.constants import DocumentSource from danswer.configs.constants import SECTION_SEPARATOR from danswer.configs.constants import TITLE_SEPARATOR from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE from danswer.connectors.models import Document from danswer.indexing.models import DocAwareChunk from danswer.search.search_nlp_models import get_default_tokenizer +from danswer.utils.logger import setup_logger from danswer.utils.text_processing import shared_precompare_cleanup - if TYPE_CHECKING: from transformers import AutoTokenizer # type:ignore + +logger = setup_logger() + ChunkFunc = Callable[[Document], list[DocAwareChunk]] @@ -178,4 +182,7 @@ def chunk(self, document: Document) -> list[DocAwareChunk]: class DefaultChunker(Chunker): def chunk(self, document: Document) -> list[DocAwareChunk]: + # Specifically for reproducing an issue with gmail + if document.source == DocumentSource.GMAIL: + logger.debug(f"Chunking {document.semantic_identifier}") return chunk_document(document) From 783696a6714af909f1fe830bc09633ad6baa7d75 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Sun, 31 Mar 2024 14:45:20 -0700 Subject: [PATCH 30/58] Axero Spaces (#1276) --- ...er-build-push-backend-container-on-tag.yml | 1 + backend/danswer/connectors/axero/connector.py | 59 +++-- web/src/app/admin/connectors/axero/page.tsx | 235 +++++++++--------- web/src/lib/types.ts | 5 + 4 files changed, 168 insertions(+), 132 deletions(-) diff --git a/.github/workflows/docker-build-push-backend-container-on-tag.yml b/.github/workflows/docker-build-push-backend-container-on-tag.yml index e95c143fb4..82aa24e6ab 100644 --- a/.github/workflows/docker-build-push-backend-container-on-tag.yml +++ b/.github/workflows/docker-build-push-backend-container-on-tag.yml @@ -38,5 +38,6 @@ jobs: - name: Run Trivy vulnerability scanner uses: aquasecurity/trivy-action@master with: + # To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend image-ref: docker.io/danswer/danswer-backend:${{ github.ref_name }} severity: 'CRITICAL,HIGH' diff 
--git a/backend/danswer/connectors/axero/connector.py b/backend/danswer/connectors/axero/connector.py index e19fb39eb9..fcb4395589 100644 --- a/backend/danswer/connectors/axero/connector.py +++ b/backend/danswer/connectors/axero/connector.py @@ -38,6 +38,7 @@ def _get_entities( axero_base_url: str, start: datetime, end: datetime, + space_id: str | None = None, ) -> list[dict]: endpoint = axero_base_url + "api/content/list" page_num = 1 @@ -51,6 +52,10 @@ def _get_entities( "SortOrder": "1", # descending "StartPage": str(page_num), } + + if space_id is not None: + params["SpaceID"] = space_id + res = requests.get(endpoint, headers=_get_auth_header(api_key), params=params) res.raise_for_status() @@ -116,7 +121,8 @@ def _translate_content_to_doc(content: dict) -> Document: class AxeroConnector(PollConnector): def __init__( self, - base_url: str, + # Strings of the integer ids of the spaces + spaces: list[str] | None = None, include_article: bool = True, include_blog: bool = True, include_wiki: bool = True, @@ -129,20 +135,24 @@ def __init__( self.include_wiki = include_wiki self.include_forum = include_forum self.batch_size = batch_size + self.space_ids = spaces self.axero_key = None + self.base_url = None + def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: + self.axero_key = credentials["axero_api_token"] + # As the API key specifically applies to a particular deployment, this is + # included as part of the credential + base_url = credentials["base_url"] if not base_url.endswith("/"): base_url += "/" self.base_url = base_url - - def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: - self.axero_key = credentials["axero_api_token"] return None def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> GenerateDocumentsOutput: - if not self.axero_key: + if not self.axero_key or not self.base_url: raise ConnectorMissingCredentialError("Axero") start_datetime = datetime.utcfromtimestamp(start).replace(tzinfo=timezone.utc) @@ -158,26 +168,35 @@ def poll_source( if self.include_forum: raise NotImplementedError("Forums for Axero not supported currently") - for entity in entity_types: - articles = _get_entities( - entity_type=entity, - api_key=self.axero_key, - axero_base_url=self.base_url, - start=start_datetime, - end=end_datetime, - ) - yield from process_in_batches( - objects=articles, - process_function=_translate_content_to_doc, - batch_size=self.batch_size, - ) + iterable_space_ids = self.space_ids if self.space_ids else [None] + + for space_id in iterable_space_ids: + for entity in entity_types: + axero_obj = _get_entities( + entity_type=entity, + api_key=self.axero_key, + axero_base_url=self.base_url, + start=start_datetime, + end=end_datetime, + space_id=space_id, + ) + yield from process_in_batches( + objects=axero_obj, + process_function=_translate_content_to_doc, + batch_size=self.batch_size, + ) if __name__ == "__main__": import os - connector = AxeroConnector(base_url=os.environ["AXERO_BASE_URL"]) - connector.load_credentials({"axero_api_token": os.environ["AXERO_API_TOKEN"]}) + connector = AxeroConnector() + connector.load_credentials( + { + "axero_api_token": os.environ["AXERO_API_TOKEN"], + "base_url": os.environ["AXERO_BASE_URL"], + } + ) current = time.time() one_year_ago = current - 24 * 60 * 60 * 360 diff --git a/web/src/app/admin/connectors/axero/page.tsx b/web/src/app/admin/connectors/axero/page.tsx index 532345d832..b434f8528d 100644 --- a/web/src/app/admin/connectors/axero/page.tsx +++ 
b/web/src/app/admin/connectors/axero/page.tsx @@ -1,30 +1,32 @@ "use client"; import * as Yup from "yup"; -import { AxeroIcon, LinearIcon, TrashIcon } from "@/components/icons/icons"; -import { TextFormField } from "@/components/admin/connectors/Field"; +import { AxeroIcon, TrashIcon } from "@/components/icons/icons"; +import { fetcher } from "@/lib/fetcher"; +import useSWR, { useSWRConfig } from "swr"; +import { LoadingAnimation } from "@/components/Loading"; import { HealthCheckBanner } from "@/components/health/healthcheck"; -import { CredentialForm } from "@/components/admin/connectors/CredentialForm"; import { - Credential, - ConnectorIndexingStatus, - LinearCredentialJson, + AxeroConfig, AxeroCredentialJson, + ConnectorIndexingStatus, + Credential, } from "@/lib/types"; -import useSWR, { useSWRConfig } from "swr"; -import { fetcher } from "@/lib/fetcher"; -import { LoadingAnimation } from "@/components/Loading"; import { adminDeleteCredential, linkCredential } from "@/lib/credential"; -import { ConnectorForm } from "@/components/admin/connectors/ConnectorForm"; +import { CredentialForm } from "@/components/admin/connectors/CredentialForm"; +import { + TextFormField, + TextArrayFieldBuilder, + BooleanFormField, + TextArrayField, +} from "@/components/admin/connectors/Field"; import { ConnectorsTable } from "@/components/admin/connectors/table/ConnectorsTable"; -import { usePopup } from "@/components/admin/connectors/Popup"; +import { ConnectorForm } from "@/components/admin/connectors/ConnectorForm"; import { usePublicCredentials } from "@/lib/hooks"; -import { Card, Text, Title } from "@tremor/react"; +import { Button, Card, Divider, Text, Title } from "@tremor/react"; import { AdminPageTitle } from "@/components/admin/Title"; -const Main = () => { - const { popup, setPopup } = usePopup(); - +const MainSection = () => { const { mutate } = useSWRConfig(); const { data: connectorIndexingStatuses, @@ -32,21 +34,19 @@ const Main = () => { error: isConnectorIndexingStatusesError, } = useSWR[]>( "/api/manage/admin/connector/indexing-status", - fetcher, - { refreshInterval: 5000 } // 5 seconds + fetcher ); + const { data: credentialsData, isLoading: isCredentialsLoading, error: isCredentialsError, - isValidating: isCredentialsValidating, refreshCredentials, } = usePublicCredentials(); if ( - isConnectorIndexingStatusesLoading || - isCredentialsLoading || - isCredentialsValidating + (!connectorIndexingStatuses && isConnectorIndexingStatusesLoading) || + (!credentialsData && isCredentialsLoading) ) { return ; } @@ -60,7 +60,7 @@ const Main = () => { } const axeroConnectorIndexingStatuses: ConnectorIndexingStatus< - {}, + AxeroConfig, AxeroCredentialJson >[] = connectorIndexingStatuses.filter( (connectorIndexingStatus) => @@ -73,40 +73,32 @@ const Main = () => { return ( <> - {popup} - Step 1: Provide your Credentials + Step 1: Provide Axero API Key - {axeroCredential ? ( <>
- Existing API Key: - - {axeroCredential.credential_json?.axero_api_token} + Existing Axero API Key: + + {axeroCredential.credential_json.axero_api_token} - +
) : ( <> - +

To use the Axero connector, first follow the guide{" "} { here {" "} to generate an API Key. - - +

+ formBody={ <> + { } validationSchema={Yup.object().shape({ + base_url: Yup.string().required( + "Please enter the base URL of your Axero instance" + ), axero_api_token: Yup.string().required( - "Please enter your Axero API Key!" + "Please enter your Axero API Token" ), })} initialValues={{ + base_url: "", axero_api_token: "", }} onSubmit={(isSuccess) => { @@ -147,79 +144,93 @@ const Main = () => { )} - Step 2: Start indexing + Step 2: Which spaces do you want to connect? - {axeroCredential ? ( + + {axeroConnectorIndexingStatuses.length > 0 && ( <> - {axeroConnectorIndexingStatuses.length > 0 ? ( - <> - - We pull the latest Articles, Blogs, and{" "} - Wikis every 10 minutes. - -
- - connectorIndexingStatuses={axeroConnectorIndexingStatuses} - liveCredential={axeroCredential} - getCredential={(credential) => { - return ( -
-

{credential.credential_json.axero_api_token}

-
- ); - }} - onCredentialLink={async (connectorId) => { - if (axeroCredential) { - await linkCredential(connectorId, axeroCredential.id); - mutate("/api/manage/admin/connector/indexing-status"); - } - }} - onUpdate={() => - mutate("/api/manage/admin/connector/indexing-status") - } - /> -
- - ) : ( - -

Create Connector

-

- Press connect below to start the connection Axero. We pull the - latest Articles, Blogs, and Wikis every{" "} - 10 minutes. -

- - nameBuilder={() => "AxeroConnector"} - ccPairNameBuilder={() => "Axero"} - source="axero" - inputType="poll" - formBody={ - <> - - + + We pull the latest Articles, Blogs, and Wikis{" "} + every 10 minutes. + +
+ + connectorIndexingStatuses={axeroConnectorIndexingStatuses} + liveCredential={axeroCredential} + getCredential={(credential) => + credential.credential_json.axero_api_token + } + specialColumns={[ + { + header: "Space", + key: "spaces", + getValue: (ccPairStatus) => { + const connectorConfig = + ccPairStatus.connector.connector_specific_config; + return connectorConfig.spaces && + connectorConfig.spaces.length > 0 + ? connectorConfig.spaces.join(", ") + : ""; + }, + }, + ]} + onUpdate={() => + mutate("/api/manage/admin/connector/indexing-status") + } + onCredentialLink={async (connectorId) => { + if (axeroCredential) { + await linkCredential(connectorId, axeroCredential.id); + mutate("/api/manage/admin/connector/indexing-status"); } - validationSchema={Yup.object().shape({})} - initialValues={{ - base_url: "", - }} - refreshFreq={10 * 60} // 10 minutes - credentialId={axeroCredential.id} - /> - - )} + }} + /> +
+ + )} + + {axeroCredential ? ( + +

Configure an Axero Connector

+ + nameBuilder={(values) => + values.spaces + ? `AxeroConnector-${values.spaces.join("_")}` + : `AxeroConnector` + } + source="axero" + inputType="poll" + formBodyBuilder={(values) => { + return ( + <> + + {TextArrayFieldBuilder({ + name: "spaces", + label: "Space IDs:", + subtext: ` + Specify zero or more Spaces to index (by the Space IDs). If no Space IDs + are specified, all Spaces will be indexed.`, + })(values)} + + ); + }} + validationSchema={Yup.object().shape({ + spaces: Yup.array() + .of(Yup.string().required("Space Ids cannot be empty")) + .required(), + })} + initialValues={{ + spaces: [], + }} + refreshFreq={10 * 60} // 10 minutes + credentialId={axeroCredential.id} + /> +
) : ( - <> - - Please provide your access token in Step 1 first! Once done with - that, you can then start indexing Linear. - - + + Please provide your Axero API Token in Step 1 first! Once done with + that, you can then specify which spaces you want to connect. + )} ); @@ -234,7 +245,7 @@ export default function Page() { } title="Axero" /> -
+      <MainSection />
   );
 }
diff --git a/web/src/lib/types.ts b/web/src/lib/types.ts
index f06e4515a2..d09ad6c906 100644
--- a/web/src/lib/types.ts
+++ b/web/src/lib/types.ts
@@ -113,6 +113,10 @@ export interface SharepointConfig {
   sites?: string[];
 }
 
+export interface AxeroConfig {
+  spaces?: string[];
+}
+
 export interface ProductboardConfig {}
 
 export interface SlackConfig {
@@ -329,6 +333,7 @@ export interface SharepointCredentialJson {
 }
 
 export interface AxeroCredentialJson {
+  base_url: string;
   axero_api_token: string;
 }
 
From 29f251660b3ea439a090d1242fbe387f579a0a66 Mon Sep 17 00:00:00 2001
From: Yuhong Sun
Date: Sun, 31 Mar 2024 15:32:22 -0700
Subject: [PATCH 31/58] Trivy Security Scan (#1277)

---
 ...er-build-push-backend-container-on-tag.yml |  1 +
 backend/.trivyignore                          | 46 +++++++++++++++++++
 backend/Dockerfile                            |  7 ++-
 3 files changed, 52 insertions(+), 2 deletions(-)
 create mode 100644 backend/.trivyignore

diff --git a/.github/workflows/docker-build-push-backend-container-on-tag.yml b/.github/workflows/docker-build-push-backend-container-on-tag.yml
index 82aa24e6ab..6f4716ea39 100644
--- a/.github/workflows/docker-build-push-backend-container-on-tag.yml
+++ b/.github/workflows/docker-build-push-backend-container-on-tag.yml
@@ -41,3 +41,4 @@ jobs:
           # To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
           image-ref: docker.io/danswer/danswer-backend:${{ github.ref_name }}
           severity: 'CRITICAL,HIGH'
+          trivyignores: ../../backend/.trivyignore
diff --git a/backend/.trivyignore b/backend/.trivyignore
new file mode 100644
index 0000000000..e8351b4074
--- /dev/null
+++ b/backend/.trivyignore
@@ -0,0 +1,46 @@
+# https://github.com/madler/zlib/issues/868
+# Pulled in with base Debian image, it's part of the contrib folder but unused
+# zlib1g is fine
+# Will be gone with Debian image upgrade
+# No impact in our settings
+CVE-2023-45853
+
+# krb5 related, worst case is denial of service by resource exhaustion
+# Accept the risk
+CVE-2024-26458
+CVE-2024-26461
+CVE-2024-26462
+CVE-2024-26458
+CVE-2024-26461
+CVE-2024-26462
+CVE-2024-26458
+CVE-2024-26461
+CVE-2024-26462
+CVE-2024-26458
+CVE-2024-26461
+CVE-2024-26462
+
+# Specific to Firefox which we do not use
+# No impact in our settings
+CVE-2024-0743
+
+# bind9 related, worst case is denial of service by CPU resource exhaustion
+# Accept the risk
+CVE-2023-50387
+CVE-2023-50868
+CVE-2023-50387
+CVE-2023-50868
+
+# libexpat1, XML parsing resource exhaustion
+# We don't parse any user provided XMLs
+# No impact in our settings
+CVE-2023-52425
+CVE-2024-28757
+
+# sqlite, only used by NLTK library to grab word lemmatizer and stopwords
+# No impact in our settings
+CVE-2023-7104
+
+# libharfbuzz0b, O(n^2) growth, worst case is denial of service
+# Accept the risk
+CVE-2023-25193
diff --git a/backend/Dockerfile b/backend/Dockerfile
index a9bc852a5a..d18bd3ecdb 100644
--- a/backend/Dockerfile
+++ b/backend/Dockerfile
@@ -12,7 +12,9 @@ RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
 # zip for Vespa step futher down
 # ca-certificates for HTTPS
 RUN apt-get update && \
-    apt-get install -y cmake curl zip ca-certificates libgnutls30=3.7.9-2+deb12u2 && \
+    apt-get install -y cmake curl zip ca-certificates libgnutls30=3.7.9-2+deb12u2 \
+    libblkid1=2.38.1-5+deb12u1 libmount1=2.38.1-5+deb12u1 libsmartcols1=2.38.1-5+deb12u1 \
+    libuuid1=2.38.1-5+deb12u1 && \
     rm -rf /var/lib/apt/lists/* && \
     apt-get clean
@@ -29,7 +31,8 @@ RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt && \
 # xserver-common and xvfb included by playwright installation but
not needed after # perl-base is part of the base Python Debian image but not needed for Danswer functionality # perl-base could only be removed with --allow-remove-essential -RUN apt-get remove -y --allow-remove-essential perl-base xserver-common xvfb cmake libldap-2.5-0 libldap-2.5-0 && \ +RUN apt-get remove -y --allow-remove-essential perl-base xserver-common xvfb cmake \ + libldap-2.5-0 libldap-2.5-0 && \ apt-get autoremove -y && \ rm -rf /var/lib/apt/lists/* && \ rm /usr/local/lib/python3.11/site-packages/tornado/test/test.key From b8af1377ba7bef1c3218adbd005158d2acfc14b8 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Sun, 31 Mar 2024 16:22:48 -0700 Subject: [PATCH 32/58] Trivy Ignore Path (#1278) --- .../workflows/docker-build-push-backend-container-on-tag.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docker-build-push-backend-container-on-tag.yml b/.github/workflows/docker-build-push-backend-container-on-tag.yml index 6f4716ea39..a7d46a0973 100644 --- a/.github/workflows/docker-build-push-backend-container-on-tag.yml +++ b/.github/workflows/docker-build-push-backend-container-on-tag.yml @@ -41,4 +41,4 @@ jobs: # To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend image-ref: docker.io/danswer/danswer-backend:${{ github.ref_name }} severity: 'CRITICAL,HIGH' - trivyignores: ../../backend/.trivyignore + trivyignores: ./backend/.trivyignore From 32f55ddb8f619173147597808ea28128529752bb Mon Sep 17 00:00:00 2001 From: Weves Date: Sun, 31 Mar 2024 17:49:26 -0700 Subject: [PATCH 33/58] URL-based chat seeding --- backend/danswer/chat/process_message.py | 10 ++- .../slack/handlers/handle_message.py | 6 +- backend/danswer/llm/answering/answer.py | 26 ++++--- backend/danswer/llm/answering/doc_pruning.py | 14 ++-- backend/danswer/llm/answering/models.py | 66 ++++++++++++++++ .../llm/answering/prompts/citations_prompt.py | 76 ++++++++++--------- .../llm/answering/prompts/quotes_prompt.py | 8 +- backend/danswer/llm/factory.py | 11 ++- backend/danswer/llm/utils.py | 10 ++- .../one_shot_answer/answer_question.py | 6 +- backend/danswer/prompts/prompt_utils.py | 3 +- .../server/query_and_chat/chat_backend.py | 6 +- .../danswer/server/query_and_chat/models.py | 15 ++++ web/src/app/chat/Chat.tsx | 40 ++++++++-- web/src/app/chat/lib.tsx | 74 +++++++++++++++--- web/src/app/chat/searchParams.ts | 22 ++++++ .../app/chat/sessionSidebar/ChatSidebar.tsx | 2 +- 17 files changed, 307 insertions(+), 88 deletions(-) create mode 100644 web/src/app/chat/searchParams.ts diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 270afc67e2..d9e7f9b6c4 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -34,7 +34,9 @@ from danswer.llm.answering.models import AnswerStyleConfig from danswer.llm.answering.models import CitationConfig from danswer.llm.answering.models import DocumentPruningConfig +from danswer.llm.answering.models import LLMConfig from danswer.llm.answering.models import PreviousMessage +from danswer.llm.answering.models import PromptConfig from danswer.llm.exceptions import GenAIDisabledException from danswer.llm.factory import get_default_llm from danswer.llm.utils import get_default_llm_tokenizer @@ -343,8 +345,12 @@ def stream_chat_message_objects( ), document_pruning_config=document_pruning_config, ), - prompt=final_msg.prompt, - persona=persona, + prompt_config=PromptConfig.from_model( + final_msg.prompt, 
prompt_override=new_msg_req.prompt_override + ), + llm_config=LLMConfig.from_persona( + persona, llm_override=new_msg_req.llm_override + ), message_history=[ PreviousMessage.from_chat_message(msg) for msg in history_msgs ], diff --git a/backend/danswer/danswerbot/slack/handlers/handle_message.py b/backend/danswer/danswerbot/slack/handlers/handle_message.py index b3fdb79c88..6d21e28d0f 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_message.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_message.py @@ -38,7 +38,9 @@ from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import SlackBotConfig from danswer.db.models import SlackBotResponseType -from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens +from danswer.llm.answering.prompts.citations_prompt import ( + compute_max_document_tokens_for_persona, +) from danswer.llm.utils import check_number_of_tokens from danswer.llm.utils import get_default_llm_version from danswer.llm.utils import get_max_input_tokens @@ -247,7 +249,7 @@ def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse: query_text = new_message_request.messages[0].message if persona: - max_document_tokens = compute_max_document_tokens( + max_document_tokens = compute_max_document_tokens_for_persona( persona=persona, actual_user_input=query_text, max_llm_token_override=remaining_tokens, diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py index 76d399d8bd..44eae76df2 100644 --- a/backend/danswer/llm/answering/answer.py +++ b/backend/danswer/llm/answering/answer.py @@ -10,11 +10,11 @@ from danswer.chat.models import LlmDoc from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE from danswer.configs.chat_configs import QA_TIMEOUT -from danswer.db.models import Persona -from danswer.db.models import Prompt from danswer.llm.answering.doc_pruning import prune_documents from danswer.llm.answering.models import AnswerStyleConfig +from danswer.llm.answering.models import LLMConfig from danswer.llm.answering.models import PreviousMessage +from danswer.llm.answering.models import PromptConfig from danswer.llm.answering.models import StreamProcessor from danswer.llm.answering.prompts.citations_prompt import build_citations_prompt from danswer.llm.answering.prompts.quotes_prompt import ( @@ -51,8 +51,8 @@ def __init__( question: str, docs: list[LlmDoc], answer_style_config: AnswerStyleConfig, - prompt: Prompt, - persona: Persona, + llm_config: LLMConfig, + prompt_config: PromptConfig, # must be the same length as `docs`. 
If None, all docs are considered "relevant" doc_relevance_list: list[bool] | None = None, message_history: list[PreviousMessage] | None = None, @@ -72,16 +72,17 @@ def __init__( self.single_message_history = single_message_history self.answer_style_config = answer_style_config + self.llm_config = llm_config + self.prompt_config = prompt_config self.llm = get_default_llm( - gen_ai_model_version_override=persona.llm_model_version_override, + gen_ai_model_provider=self.llm_config.model_provider, + gen_ai_model_version_override=self.llm_config.model_version, timeout=timeout, + temperature=self.llm_config.temperature, ) self.llm_tokenizer = get_default_llm_tokenizer() - self.prompt = prompt - self.persona = persona - self.process_stream_fn = _get_stream_processor(docs, answer_style_config) self._final_prompt: list[BaseMessage] | None = None @@ -99,7 +100,8 @@ def pruned_docs(self) -> list[LlmDoc]: self._pruned_docs = prune_documents( docs=self.docs, doc_relevance_list=self.doc_relevance_list, - persona=self.persona, + prompt_config=self.prompt_config, + llm_config=self.llm_config, question=self.question, document_pruning_config=self.answer_style_config.document_pruning_config, ) @@ -114,8 +116,8 @@ def final_prompt(self) -> list[BaseMessage]: self._final_prompt = build_citations_prompt( question=self.question, message_history=self.message_history, - persona=self.persona, - prompt=self.prompt, + llm_config=self.llm_config, + prompt_config=self.prompt_config, context_docs=self.pruned_docs, all_doc_useful=self.answer_style_config.citation_config.all_docs_useful, llm_tokenizer_encode_func=self.llm_tokenizer.encode, @@ -126,7 +128,7 @@ def final_prompt(self) -> list[BaseMessage]: question=self.question, context_docs=self.pruned_docs, history_str=self.single_message_history or "", - prompt=self.prompt, + prompt=self.prompt_config, ) return cast(list[BaseMessage], self._final_prompt) diff --git a/backend/danswer/llm/answering/doc_pruning.py b/backend/danswer/llm/answering/doc_pruning.py index 29c913673d..f1007d19e5 100644 --- a/backend/danswer/llm/answering/doc_pruning.py +++ b/backend/danswer/llm/answering/doc_pruning.py @@ -6,9 +6,10 @@ ) from danswer.configs.constants import IGNORE_FOR_QA from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE -from danswer.db.models import Persona from danswer.indexing.models import InferenceChunk from danswer.llm.answering.models import DocumentPruningConfig +from danswer.llm.answering.models import LLMConfig +from danswer.llm.answering.models import PromptConfig from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens from danswer.llm.utils import get_default_llm_tokenizer from danswer.llm.utils import tokenizer_trim_content @@ -28,14 +29,15 @@ class PruningError(Exception): def _compute_limit( - persona: Persona, + prompt_config: PromptConfig, + llm_config: LLMConfig, question: str, max_chunks: int | None, max_window_percentage: float | None, max_tokens: int | None, ) -> int: llm_max_document_tokens = compute_max_document_tokens( - persona=persona, actual_user_input=question + prompt_config=prompt_config, llm_config=llm_config, actual_user_input=question ) window_percentage_based_limit = ( @@ -183,7 +185,8 @@ def _apply_pruning( def prune_documents( docs: list[LlmDoc], doc_relevance_list: list[bool] | None, - persona: Persona, + prompt_config: PromptConfig, + llm_config: LLMConfig, question: str, document_pruning_config: DocumentPruningConfig, ) -> list[LlmDoc]: @@ -191,7 +194,8 @@ def prune_documents( assert len(docs) 
== len(doc_relevance_list) doc_token_limit = _compute_limit( - persona=persona, + prompt_config=prompt_config, + llm_config=llm_config, question=question, max_chunks=document_pruning_config.max_chunks, max_window_percentage=document_pruning_config.max_window_percentage, diff --git a/backend/danswer/llm/answering/models.py b/backend/danswer/llm/answering/models.py index 360535ac80..f7f2bbad99 100644 --- a/backend/danswer/llm/answering/models.py +++ b/backend/danswer/llm/answering/models.py @@ -9,9 +9,15 @@ from danswer.chat.models import AnswerQuestionStreamReturn from danswer.configs.constants import MessageType +from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER +from danswer.llm.utils import get_default_llm_version +from danswer.server.query_and_chat.models import LLMOverride +from danswer.server.query_and_chat.models import PromptOverride if TYPE_CHECKING: from danswer.db.models import ChatMessage + from danswer.db.models import Prompt + from danswer.db.models import Persona StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn] @@ -75,3 +81,63 @@ def check_quotes_and_citation(cls, values: dict[str, Any]) -> dict[str, Any]: ) return values + + +class LLMConfig(BaseModel): + """Final representation of the LLM configuration passed into + the `Answer` object.""" + + model_provider: str + model_version: str + temperature: float + + @classmethod + def from_persona( + cls, persona: "Persona", llm_override: LLMOverride | None = None + ) -> "LLMConfig": + model_provider_override = llm_override.model_provider if llm_override else None + model_version_override = llm_override.model_version if llm_override else None + temperature_override = llm_override.temperature if llm_override else None + + return cls( + model_provider=model_provider_override or GEN_AI_MODEL_PROVIDER, + model_version=( + model_version_override + or persona.llm_model_version_override + or get_default_llm_version()[0] + ), + temperature=temperature_override or 0.0, + ) + + class Config: + frozen = True + + +class PromptConfig(BaseModel): + """Final representation of the Prompt configuration passed + into the `Answer` object.""" + + system_prompt: str + task_prompt: str + datetime_aware: bool + include_citations: bool + + @classmethod + def from_model( + cls, model: "Prompt", prompt_override: PromptOverride | None = None + ) -> "PromptConfig": + override_system_prompt = ( + prompt_override.system_prompt if prompt_override else None + ) + override_task_prompt = prompt_override.task_prompt if prompt_override else None + + return cls( + system_prompt=override_system_prompt or model.system_prompt, + task_prompt=override_task_prompt or model.task_prompt, + datetime_aware=model.datetime_aware, + include_citations=model.include_citations, + ) + + # needed so that this can be passed into lru_cache funcs + class Config: + frozen = True diff --git a/backend/danswer/llm/answering/prompts/citations_prompt.py b/backend/danswer/llm/answering/prompts/citations_prompt.py index 61c42c19c7..60f1e1098f 100644 --- a/backend/danswer/llm/answering/prompts/citations_prompt.py +++ b/backend/danswer/llm/answering/prompts/citations_prompt.py @@ -11,12 +11,12 @@ from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS from danswer.db.chat import get_default_prompt from danswer.db.models import Persona -from danswer.db.models import Prompt from danswer.indexing.models import InferenceChunk +from danswer.llm.answering.models import LLMConfig from danswer.llm.answering.models import PreviousMessage 
+from danswer.llm.answering.models import PromptConfig from danswer.llm.utils import check_number_of_tokens from danswer.llm.utils import get_default_llm_tokenizer -from danswer.llm.utils import get_default_llm_version from danswer.llm.utils import get_max_input_tokens from danswer.llm.utils import translate_history_to_basemessages from danswer.prompts.chat_prompts import ADDITIONAL_INFO @@ -92,16 +92,16 @@ def drop_messages_history_overflow( return prompt -def get_prompt_tokens(prompt: Prompt) -> int: +def get_prompt_tokens(prompt_config: PromptConfig) -> int: # Note: currently custom prompts do not allow datetime aware, only default prompts return ( - check_number_of_tokens(prompt.system_prompt) - + check_number_of_tokens(prompt.task_prompt) + check_number_of_tokens(prompt_config.system_prompt) + + check_number_of_tokens(prompt_config.task_prompt) + CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT + CITATION_STATEMENT_TOKEN_CNT + CITATION_REMINDER_TOKEN_CNT + (LANGUAGE_HINT_TOKEN_CNT if bool(MULTILINGUAL_QUERY_EXPANSION) else 0) - + (ADDITIONAL_INFO_TOKEN_CNT if prompt.datetime_aware else 0) + + (ADDITIONAL_INFO_TOKEN_CNT if prompt_config.datetime_aware else 0) ) @@ -111,7 +111,8 @@ def get_prompt_tokens(prompt: Prompt) -> int: def compute_max_document_tokens( - persona: Persona, + prompt_config: PromptConfig, + llm_config: LLMConfig, actual_user_input: str | None = None, max_llm_token_override: int | None = None, ) -> int: @@ -126,21 +127,13 @@ def compute_max_document_tokens( if we're trying to determine if the user should be able to select another document) then we just set an arbitrary "upper bound". """ - llm_name = get_default_llm_version()[0] - if persona.llm_model_version_override: - llm_name = persona.llm_model_version_override - # if we can't find a number of tokens, just assume some common default max_input_tokens = ( max_llm_token_override if max_llm_token_override - else get_max_input_tokens(model_name=llm_name) + else get_max_input_tokens(model_name=llm_config.model_version) ) - if persona.prompts: - # TODO this may not always be the first prompt - prompt_tokens = get_prompt_tokens(persona.prompts[0]) - else: - prompt_tokens = get_prompt_tokens(get_default_prompt()) + prompt_tokens = get_prompt_tokens(prompt_config) user_input_tokens = ( check_number_of_tokens(actual_user_input) @@ -151,31 +144,44 @@ def compute_max_document_tokens( return max_input_tokens - prompt_tokens - user_input_tokens - _MISC_BUFFER -def compute_max_llm_input_tokens(persona: Persona) -> int: +def compute_max_document_tokens_for_persona( + persona: Persona, + actual_user_input: str | None = None, + max_llm_token_override: int | None = None, +) -> int: + prompt = persona.prompts[0] if persona.prompts else get_default_prompt() + return compute_max_document_tokens( + prompt_config=PromptConfig.from_model(prompt), + llm_config=LLMConfig.from_persona(persona), + actual_user_input=actual_user_input, + max_llm_token_override=max_llm_token_override, + ) + + +def compute_max_llm_input_tokens(llm_config: LLMConfig) -> int: """Maximum tokens allows in the input to the LLM (of any type).""" - llm_name = get_default_llm_version()[0] - if persona.llm_model_version_override: - llm_name = persona.llm_model_version_override - input_tokens = get_max_input_tokens(model_name=llm_name) + input_tokens = get_max_input_tokens( + model_name=llm_config.model_version, model_provider=llm_config.model_provider + ) return input_tokens - _MISC_BUFFER @lru_cache() def build_system_message( - prompt: Prompt, + prompt_config: 
PromptConfig, context_exists: bool, llm_tokenizer_encode_func: Callable, citation_line: str = REQUIRE_CITATION_STATEMENT, no_citation_line: str = NO_CITATION_STATEMENT, ) -> tuple[SystemMessage | None, int]: - system_prompt = prompt.system_prompt.strip() - if prompt.include_citations: + system_prompt = prompt_config.system_prompt.strip() + if prompt_config.include_citations: if context_exists: system_prompt += citation_line else: system_prompt += no_citation_line - if prompt.datetime_aware: + if prompt_config.datetime_aware: if system_prompt: system_prompt += ADDITIONAL_INFO.format( datetime_info=get_current_llm_day_time() @@ -194,7 +200,7 @@ def build_system_message( def build_user_message( question: str, - prompt: Prompt, + prompt_config: PromptConfig, context_docs: list[LlmDoc] | list[InferenceChunk], all_doc_useful: bool, history_message: str, @@ -206,9 +212,9 @@ def build_user_message( # Simpler prompt for cases where there is no context user_prompt = ( CHAT_USER_CONTEXT_FREE_PROMPT.format( - task_prompt=prompt.task_prompt, user_query=question + task_prompt=prompt_config.task_prompt, user_query=question ) - if prompt.task_prompt + if prompt_config.task_prompt else question ) user_prompt = user_prompt.strip() @@ -219,7 +225,7 @@ def build_user_message( context_docs_str = build_complete_context_str(context_docs) optional_ignore = "" if all_doc_useful else DEFAULT_IGNORE_STATEMENT - task_prompt_with_reminder = build_task_prompt_reminders(prompt) + task_prompt_with_reminder = build_task_prompt_reminders(prompt_config) user_prompt = CITATIONS_PROMPT.format( optional_ignore_statement=optional_ignore, @@ -239,8 +245,8 @@ def build_user_message( def build_citations_prompt( question: str, message_history: list[PreviousMessage], - persona: Persona, - prompt: Prompt, + prompt_config: PromptConfig, + llm_config: LLMConfig, context_docs: list[LlmDoc] | list[InferenceChunk], all_doc_useful: bool, history_message: str, @@ -249,7 +255,7 @@ def build_citations_prompt( context_exists = len(context_docs) > 0 system_message_or_none, system_tokens = build_system_message( - prompt=prompt, + prompt_config=prompt_config, context_exists=context_exists, llm_tokenizer_encode_func=llm_tokenizer_encode_func, ) @@ -262,7 +268,7 @@ def build_citations_prompt( # Is the same as passed in later for extracting citations user_message, user_tokens = build_user_message( question=question, - prompt=prompt, + prompt_config=prompt_config, context_docs=context_docs, all_doc_useful=all_doc_useful, history_message=history_message, @@ -275,7 +281,7 @@ def build_citations_prompt( history_token_counts=history_token_counts, final_msg=user_message, final_msg_token_count=user_tokens, - max_allowed_tokens=compute_max_llm_input_tokens(persona), + max_allowed_tokens=compute_max_llm_input_tokens(llm_config), ) return final_prompt_msgs diff --git a/backend/danswer/llm/answering/prompts/quotes_prompt.py b/backend/danswer/llm/answering/prompts/quotes_prompt.py index c9e145e810..0824ffa646 100644 --- a/backend/danswer/llm/answering/prompts/quotes_prompt.py +++ b/backend/danswer/llm/answering/prompts/quotes_prompt.py @@ -4,8 +4,8 @@ from danswer.chat.models import LlmDoc from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE -from danswer.db.models import Prompt from danswer.indexing.models import InferenceChunk +from danswer.llm.answering.models import PromptConfig from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK from danswer.prompts.direct_qa_prompts 
import HISTORY_BLOCK from danswer.prompts.direct_qa_prompts import JSON_PROMPT @@ -18,7 +18,7 @@ def _build_weak_llm_quotes_prompt( question: str, context_docs: list[LlmDoc] | list[InferenceChunk], history_str: str, - prompt: Prompt, + prompt: PromptConfig, use_language_hint: bool, ) -> list[BaseMessage]: """Since Danswer supports a variety of LLMs, this less demanding prompt is provided @@ -43,7 +43,7 @@ def _build_strong_llm_quotes_prompt( question: str, context_docs: list[LlmDoc] | list[InferenceChunk], history_str: str, - prompt: Prompt, + prompt: PromptConfig, use_language_hint: bool, ) -> list[BaseMessage]: context_block = "" @@ -70,7 +70,7 @@ def build_quotes_prompt( question: str, context_docs: list[LlmDoc] | list[InferenceChunk], history_str: str, - prompt: Prompt, + prompt: PromptConfig, use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION), ) -> list[BaseMessage]: prompt_builder = ( diff --git a/backend/danswer/llm/factory.py b/backend/danswer/llm/factory.py index 19c6ac7327..f274aa7901 100644 --- a/backend/danswer/llm/factory.py +++ b/backend/danswer/llm/factory.py @@ -1,6 +1,7 @@ from danswer.configs.app_configs import DISABLE_GENERATIVE_AI from danswer.configs.chat_configs import QA_TIMEOUT from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER +from danswer.configs.model_configs import GEN_AI_TEMPERATURE from danswer.llm.chat_llm import DefaultMultiLLM from danswer.llm.custom_llm import CustomModelServer from danswer.llm.exceptions import GenAIDisabledException @@ -14,6 +15,7 @@ def get_default_llm( gen_ai_model_provider: str = GEN_AI_MODEL_PROVIDER, api_key: str | None = None, timeout: int = QA_TIMEOUT, + temperature: float = GEN_AI_TEMPERATURE, use_fast_llm: bool = False, gen_ai_model_version_override: str | None = None, ) -> LLM: @@ -34,8 +36,13 @@ def get_default_llm( return CustomModelServer(api_key=api_key, timeout=timeout) if gen_ai_model_provider.lower() == "gpt4all": - return DanswerGPT4All(model_version=model_version, timeout=timeout) + return DanswerGPT4All( + model_version=model_version, timeout=timeout, temperature=temperature + ) return DefaultMultiLLM( - model_version=model_version, api_key=api_key, timeout=timeout + model_version=model_version, + api_key=api_key, + timeout=timeout, + temperature=temperature, ) diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py index 507bbbff62..b41c85b9ed 100644 --- a/backend/danswer/llm/utils.py +++ b/backend/danswer/llm/utils.py @@ -4,6 +4,8 @@ from functools import lru_cache from typing import Any from typing import cast +from typing import TYPE_CHECKING +from typing import Union import litellm # type: ignore import tiktoken @@ -33,10 +35,12 @@ from danswer.dynamic_configs.factory import get_dynamic_config_store from danswer.dynamic_configs.interface import ConfigNotFoundError from danswer.indexing.models import InferenceChunk -from danswer.llm.answering.models import PreviousMessage from danswer.llm.interfaces import LLM from danswer.utils.logger import setup_logger +if TYPE_CHECKING: + from danswer.llm.answering.models import PreviousMessage + logger = setup_logger() _LLM_TOKENIZER: Any = None @@ -116,7 +120,7 @@ def tokenizer_trim_chunks( def translate_danswer_msg_to_langchain( - msg: ChatMessage | PreviousMessage, + msg: Union[ChatMessage, "PreviousMessage"], ) -> BaseMessage: if msg.message_type == MessageType.SYSTEM: raise ValueError("System messages are not currently part of history") @@ -129,7 +133,7 @@ def translate_danswer_msg_to_langchain( def 
translate_history_to_basemessages( - history: list[ChatMessage] | list[PreviousMessage], + history: list[ChatMessage] | list["PreviousMessage"], ) -> tuple[list[BaseMessage], list[int]]: history_basemessages = [ translate_danswer_msg_to_langchain(msg) diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py index e37cc0e435..8fd5a1c0dd 100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -28,6 +28,8 @@ from danswer.llm.answering.models import AnswerStyleConfig from danswer.llm.answering.models import CitationConfig from danswer.llm.answering.models import DocumentPruningConfig +from danswer.llm.answering.models import LLMConfig +from danswer.llm.answering.models import PromptConfig from danswer.llm.answering.models import QuotesConfig from danswer.llm.utils import get_default_llm_token_encode from danswer.one_shot_answer.models import DirectQARequest @@ -203,8 +205,8 @@ def stream_answer_objects( question=query_msg.message, docs=[llm_doc_from_inference_chunk(chunk) for chunk in top_chunks], answer_style_config=answer_config, - prompt=prompt, - persona=chat_session.persona, + prompt_config=PromptConfig.from_model(prompt), + llm_config=LLMConfig.from_persona(chat_session.persona), doc_relevance_list=search_pipeline.chunk_relevance_list, single_message_history=history_str, timeout=timeout, diff --git a/backend/danswer/prompts/prompt_utils.py b/backend/danswer/prompts/prompt_utils.py index dcc7c6f0f5..a681330153 100644 --- a/backend/danswer/prompts/prompt_utils.py +++ b/backend/danswer/prompts/prompt_utils.py @@ -6,6 +6,7 @@ from danswer.configs.constants import DocumentSource from danswer.db.models import Prompt from danswer.indexing.models import InferenceChunk +from danswer.llm.answering.models import PromptConfig from danswer.prompts.chat_prompts import CITATION_REMINDER from danswer.prompts.constants import CODE_BLOCK_PAT from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT @@ -20,7 +21,7 @@ def get_current_llm_day_time() -> str: def build_task_prompt_reminders( - prompt: Prompt, + prompt: Prompt | PromptConfig, use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION), citation_str: str = CITATION_REMINDER, language_hint_str: str = LANGUAGE_HINT, diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index 4558903f7a..8a7d9e2154 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -24,7 +24,9 @@ from danswer.db.models import User from danswer.document_index.document_index_utils import get_both_index_names from danswer.document_index.factory import get_default_document_index -from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens +from danswer.llm.answering.prompts.citations_prompt import ( + compute_max_document_tokens_for_persona, +) from danswer.secondary_llm_flows.chat_session_naming import ( get_renamed_conversation_name, ) @@ -303,5 +305,5 @@ def get_max_document_tokens( raise HTTPException(status_code=404, detail="Persona not found") return MaxSelectedDocumentTokens( - max_tokens=compute_max_document_tokens(persona), + max_tokens=compute_max_document_tokens_for_persona(persona), ) diff --git a/backend/danswer/server/query_and_chat/models.py b/backend/danswer/server/query_and_chat/models.py index 054cb3179e..d77d31684a 100644 --- 
a/backend/danswer/server/query_and_chat/models.py +++ b/backend/danswer/server/query_and_chat/models.py @@ -67,6 +67,17 @@ class DocumentSearchRequest(BaseModel): skip_rerank: bool = False +class LLMOverride(BaseModel): + model_provider: str | None = None + model_version: str | None = None + temperature: float | None = None + + +class PromptOverride(BaseModel): + system_prompt: str | None = None + task_prompt: str | None = None + + """ Currently the different branches are generated by changing the search query @@ -98,6 +109,10 @@ class CreateChatMessageRequest(BaseModel): query_override: str | None = None no_ai_answer: bool = False + # allows the caller to override the Persona / Prompt + llm_override: LLMOverride | None = None + prompt_override: PromptOverride | None = None + @root_validator def check_search_doc_ids_or_retrieval_options(cls: BaseModel, values: dict) -> dict: search_doc_ids, retrieval_options = values.get("search_doc_ids"), values.get( diff --git a/web/src/app/chat/Chat.tsx b/web/src/app/chat/Chat.tsx index 9430c4e668..e16e873b06 100644 --- a/web/src/app/chat/Chat.tsx +++ b/web/src/app/chat/Chat.tsx @@ -13,9 +13,10 @@ import { RetrievalType, StreamingError, } from "./interfaces"; -import { useRouter } from "next/navigation"; +import { useRouter, useSearchParams } from "next/navigation"; import { FeedbackType } from "./types"; import { + buildChatUrl, createChatSession, getCitedDocumentsFromMessage, getHumanAndAIMessageFromMessageNumber, @@ -46,6 +47,7 @@ import { computeAvailableFilters } from "@/lib/filters"; import { useDocumentSelection } from "./useDocumentSelection"; import { StarterMessage } from "./StarterMessage"; import { ShareChatSessionModal } from "./modal/ShareChatSessionModal"; +import { SEARCH_PARAM_NAMES, shouldSubmitOnLoad } from "./searchParams"; const MAX_INPUT_HEIGHT = 200; @@ -71,6 +73,13 @@ export const Chat = ({ shouldhideBeforeScroll?: boolean; }) => { const router = useRouter(); + const searchParams = useSearchParams(); + // used to track whether or not the initial "submit on load" has been performed + // this only applies if `?submit-on-load=true` or `?submit-on-load=1` is in the URL + // NOTE: this is required due to React strict mode, where all `useEffect` hooks + // are run twice on initial load during development + const submitOnLoadPerformed = useRef(false); + const { popup, setPopup } = usePopup(); // fetch messages for the chat session @@ -117,6 +126,16 @@ export const Chat = ({ } setMessageHistory([]); setChatSessionSharedStatus(ChatSessionSharedStatus.Private); + + // if we're supposed to submit on initial load, then do that here + if ( + shouldSubmitOnLoad(searchParams) && + !submitOnLoadPerformed.current + ) { + submitOnLoadPerformed.current = true; + onSubmit(); + } + return; } @@ -151,7 +170,9 @@ export const Chat = ({ const [chatSessionId, setChatSessionId] = useState( existingChatSessionId ); - const [message, setMessage] = useState(""); + const [message, setMessage] = useState( + searchParams.get(SEARCH_PARAM_NAMES.USER_MESSAGE) || "" + ); const [messageHistory, setMessageHistory] = useState([]); const [isStreaming, setIsStreaming] = useState(false); @@ -385,6 +406,13 @@ export const Chat = ({ .map((document) => document.db_doc_id as number), queryOverride, forceSearch, + modelVersion: + searchParams.get(SEARCH_PARAM_NAMES.MODEL_VERSION) || undefined, + temperature: + parseFloat(searchParams.get(SEARCH_PARAM_NAMES.TEMPERATURE) || "") || + undefined, + systemPromptOverride: + searchParams.get(SEARCH_PARAM_NAMES.SYSTEM_PROMPT) || 
undefined, })) { for (const packet of packetBunch) { if (Object.hasOwn(packet, "answer_piece")) { @@ -454,7 +482,7 @@ export const Chat = ({ currChatSessionId === urlChatSessionId.current || urlChatSessionId.current === null ) { - router.push(`/chat?chatId=${currChatSessionId}`, { + router.push(buildChatUrl(searchParams, currChatSessionId, null), { scroll: false, }); } @@ -550,7 +578,9 @@ export const Chat = ({ if (persona) { setSelectedPersona(persona); textareaRef.current?.focus(); - router.push(`/chat?personaId=${persona.id}`); + router.push( + buildChatUrl(searchParams, null, persona.id) + ); } }} /> @@ -577,7 +607,7 @@ export const Chat = ({ handlePersonaSelect={(persona) => { setSelectedPersona(persona); textareaRef.current?.focus(); - router.push(`/chat?personaId=${persona.id}`); + router.push(buildChatUrl(searchParams, null, persona.id)); }} /> )} diff --git a/web/src/app/chat/lib.tsx b/web/src/app/chat/lib.tsx index fcca6c0407..dbc0506fd0 100644 --- a/web/src/app/chat/lib.tsx +++ b/web/src/app/chat/lib.tsx @@ -15,6 +15,8 @@ import { StreamingError, } from "./interfaces"; import { Persona } from "../admin/personas/interfaces"; +import { ReadonlyURLSearchParams } from "next/navigation"; +import { SEARCH_PARAM_NAMES } from "./searchParams"; export async function createChatSession(personaId: number): Promise { const createChatSessionResponse = await fetch( @@ -39,17 +41,6 @@ export async function createChatSession(personaId: number): Promise { return chatSessionResponseJson.chat_session_id; } -export interface SendMessageRequest { - message: string; - parentMessageId: number | null; - chatSessionId: number; - promptId: number | null | undefined; - filters: Filters | null; - selectedDocumentIds: number[] | null; - queryOverride?: string; - forceSearch?: boolean; -} - export async function* sendMessage({ message, parentMessageId, @@ -59,7 +50,24 @@ export async function* sendMessage({ selectedDocumentIds, queryOverride, forceSearch, -}: SendMessageRequest) { + modelVersion, + temperature, + systemPromptOverride, +}: { + message: string; + parentMessageId: number | null; + chatSessionId: number; + promptId: number | null | undefined; + filters: Filters | null; + selectedDocumentIds: number[] | null; + queryOverride?: string; + forceSearch?: boolean; + // LLM overrides + modelVersion?: string; + temperature?: number; + // prompt overrides + systemPromptOverride?: string; +}) { const documentsAreSelected = selectedDocumentIds && selectedDocumentIds.length > 0; const sendMessageResponse = await fetch("/api/chat/send-message", { @@ -87,6 +95,13 @@ export async function* sendMessage({ } : null, query_override: queryOverride, + prompt_override: { + system_prompt: systemPromptOverride, + }, + llm_override: { + temperature, + model_version: modelVersion, + }, }), }); if (!sendMessageResponse.ok) { @@ -354,3 +369,38 @@ export function processRawChatHistory(rawMessages: BackendMessage[]) { export function personaIncludesRetrieval(selectedPersona: Persona) { return selectedPersona.num_chunks !== 0; } + +const PARAMS_TO_SKIP = [ + SEARCH_PARAM_NAMES.SUBMIT_ON_LOAD, + SEARCH_PARAM_NAMES.USER_MESSAGE, + // only use these if explicitly passed in + SEARCH_PARAM_NAMES.CHAT_ID, + SEARCH_PARAM_NAMES.PERSONA_ID, +]; + +export function buildChatUrl( + existingSearchParams: ReadonlyURLSearchParams, + chatSessionId: number | null, + personaId: number | null +) { + const finalSearchParams: string[] = []; + if (chatSessionId) { + finalSearchParams.push(`${SEARCH_PARAM_NAMES.CHAT_ID}=${chatSessionId}`); + } + if 
(personaId) { + finalSearchParams.push(`${SEARCH_PARAM_NAMES.PERSONA_ID}=${personaId}`); + } + + existingSearchParams.forEach((value, key) => { + if (!PARAMS_TO_SKIP.includes(key)) { + finalSearchParams.push(`${key}=${value}`); + } + }); + const finalSearchParamsString = finalSearchParams.join("&"); + + if (finalSearchParamsString) { + return `/chat?${finalSearchParamsString}`; + } + + return "/chat"; +} diff --git a/web/src/app/chat/searchParams.ts b/web/src/app/chat/searchParams.ts new file mode 100644 index 0000000000..9de5d9ec89 --- /dev/null +++ b/web/src/app/chat/searchParams.ts @@ -0,0 +1,22 @@ +import { ReadonlyURLSearchParams } from "next/navigation"; + +// search params +export const SEARCH_PARAM_NAMES = { + CHAT_ID: "chatId", + PERSONA_ID: "personaId", + // overrides + TEMPERATURE: "temperature", + MODEL_VERSION: "model-version", + SYSTEM_PROMPT: "system-prompt", + // user message + USER_MESSAGE: "user-message", + SUBMIT_ON_LOAD: "submit-on-load", +}; + +export function shouldSubmitOnLoad(searchParams: ReadonlyURLSearchParams) { + const rawSubmitOnLoad = searchParams.get(SEARCH_PARAM_NAMES.SUBMIT_ON_LOAD); + if (rawSubmitOnLoad === "true" || rawSubmitOnLoad === "1") { + return true; + } + return false; +} diff --git a/web/src/app/chat/sessionSidebar/ChatSidebar.tsx b/web/src/app/chat/sessionSidebar/ChatSidebar.tsx index 199fab3e50..175139c03d 100644 --- a/web/src/app/chat/sessionSidebar/ChatSidebar.tsx +++ b/web/src/app/chat/sessionSidebar/ChatSidebar.tsx @@ -117,7 +117,7 @@ export const ChatSidebar = ({ {chatSessions.map((chat) => { const isSelected = currentChatId === chat.id; return ( -
+
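A note on the override plumbing introduced in this patch: `LLMOverride` and `PromptOverride` are deliberately all-optional, so an unset field falls through to the persona's own configuration. The consuming side is not shown here, so the following is only a sketch of the intended precedence (request override, then persona override, then global default); the helper and the `persona_model_version` / `default_model_version` names are illustrative placeholders, not code from this series.

```python
# Sketch only - resolving an all-optional per-request override against
# persona-level settings and a global default.
from pydantic import BaseModel


class LLMOverride(BaseModel):
    model_provider: str | None = None
    model_version: str | None = None
    temperature: float | None = None


def resolve_model_version(
    llm_override: LLMOverride | None,  # per-request, from CreateChatMessageRequest
    persona_model_version: str | None,  # persona-level override, may be unset
    default_model_version: str,  # global default
) -> str:
    # The most specific setting wins; None means "no opinion at this level"
    if llm_override is not None and llm_override.model_version is not None:
        return llm_override.model_version
    if persona_model_version is not None:
        return persona_model_version
    return default_model_version
```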
Date: Sun, 31 Mar 2024 21:45:46 -0700 Subject: [PATCH 34/58] More Options for Search APIs (#1280) --- backend/danswer/configs/model_configs.py | 7 ++----- .../danswerbot/slack/handlers/handle_message.py | 2 ++ backend/danswer/db/chat.py | 3 ++- .../danswer/one_shot_answer/answer_question.py | 2 ++ backend/danswer/one_shot_answer/models.py | 3 +++ backend/danswer/search/models.py | 6 ++++-- .../search/postprocessing/postprocessing.py | 1 + .../search/preprocessing/preprocessing.py | 16 ++++++++++++++-- backend/danswer/search/search_nlp_models.py | 3 +++ backend/danswer/server/query_and_chat/models.py | 4 +++- backend/requirements/default.txt | 8 ++++---- 11 files changed, 40 insertions(+), 15 deletions(-) diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py index f6cd71f31d..ce79693731 100644 --- a/backend/danswer/configs/model_configs.py +++ b/backend/danswer/configs/model_configs.py @@ -48,11 +48,8 @@ ENABLE_RERANKING_REAL_TIME_FLOW = ( os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true" ) -# https://www.sbert.net/docs/pretrained-models/ce-msmarco.html -CROSS_ENCODER_MODEL_ENSEMBLE = [ - "cross-encoder/ms-marco-MiniLM-L-4-v2", - "cross-encoder/ms-marco-TinyBERT-L-2-v2", -] +# Only using one for now +CROSS_ENCODER_MODEL_ENSEMBLE = ["mixedbread-ai/mxbai-rerank-xsmall-v1"] # For score normalizing purposes, only way is to know the expected ranges CROSS_ENCODER_RANGE_MAX = 12 CROSS_ENCODER_RANGE_MIN = -12 diff --git a/backend/danswer/danswerbot/slack/handlers/handle_message.py b/backend/danswer/danswerbot/slack/handlers/handle_message.py index 6d21e28d0f..33d64d9eff 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_message.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_message.py @@ -22,6 +22,7 @@ from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI from danswer.configs.danswerbot_configs import DISABLE_DANSWER_BOT_FILTER_DETECT from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION +from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW from danswer.danswerbot.slack.blocks import build_documents_blocks from danswer.danswerbot.slack.blocks import build_follow_up_block from danswer.danswerbot.slack.blocks import build_qa_response_blocks @@ -310,6 +311,7 @@ def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse: persona_id=persona.id if persona is not None else 0, retrieval_options=retrieval_details, chain_of_thought=not disable_cot, + skip_rerank=not ENABLE_RERANKING_ASYNC_FLOW, ) ) except Exception as e: diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index e2262f0f3d..eb2d49b4c3 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -736,7 +736,8 @@ def create_db_search_doc( boost=server_search_doc.boost, hidden=server_search_doc.hidden, doc_metadata=server_search_doc.metadata, - score=server_search_doc.score, + # For docs further down that aren't reranked, we can't use the retrieval score + score=server_search_doc.score or 0.0, match_highlights=server_search_doc.match_highlights, updated_at=server_search_doc.updated_at, primary_owners=server_search_doc.primary_owners, diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py index 8fd5a1c0dd..17ff3186a0 100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -126,6 +126,8 @@ def stream_answer_objects( 
persona=chat_session.persona, offset=query_req.retrieval_options.offset, limit=query_req.retrieval_options.limit, + skip_rerank=query_req.skip_rerank, + skip_llm_chunk_filter=query_req.skip_llm_chunk_filter, ), user=user, db_session=db_session, diff --git a/backend/danswer/one_shot_answer/models.py b/backend/danswer/one_shot_answer/models.py index 0fefc5a7b3..c7f6dbe49d 100644 --- a/backend/danswer/one_shot_answer/models.py +++ b/backend/danswer/one_shot_answer/models.py @@ -27,6 +27,9 @@ class DirectQARequest(BaseModel): prompt_id: int | None persona_id: int retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails) + # This is to forcibly skip (or run) the step, if None it uses the system defaults + skip_rerank: bool | None = None + skip_llm_chunk_filter: bool | None = None chain_of_thought: bool = False return_contexts: bool = False diff --git a/backend/danswer/search/models.py b/backend/danswer/search/models.py index d199a3b6bb..c59dbf4dab 100644 --- a/backend/danswer/search/models.py +++ b/backend/danswer/search/models.py @@ -58,7 +58,9 @@ class SearchRequest(BaseModel): recency_bias_multiplier: float = 1.0 hybrid_alpha: float = HYBRID_ALPHA - skip_rerank: bool = True + # This is to forcibly skip (or run) the step, if None it uses the system defaults + skip_rerank: bool | None = None + skip_llm_chunk_filter: bool | None = None class Config: arbitrary_types_allowed = True @@ -72,9 +74,9 @@ class SearchQuery(BaseModel): offset: int = 0 search_type: SearchType = SearchType.HYBRID skip_rerank: bool = not ENABLE_RERANKING_REAL_TIME_FLOW + skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER # Only used if not skip_rerank num_rerank: int | None = NUM_RERANKED_RESULTS - skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER # Only used if not skip_llm_chunk_filter max_llm_filter_chunks: int = NUM_RERANKED_RESULTS diff --git a/backend/danswer/search/postprocessing/postprocessing.py b/backend/danswer/search/postprocessing/postprocessing.py index 8b9c5617cc..13303c3be9 100644 --- a/backend/danswer/search/postprocessing/postprocessing.py +++ b/backend/danswer/search/postprocessing/postprocessing.py @@ -158,6 +158,7 @@ def search_postprocessing( post_processing_tasks: list[FunctionCall] = [] rerank_task_id = None + chunks_yielded = False if should_rerank(search_query): post_processing_tasks.append( FunctionCall( diff --git a/backend/danswer/search/preprocessing/preprocessing.py b/backend/danswer/search/preprocessing/preprocessing.py index f35afe4389..4fb2665a83 100644 --- a/backend/danswer/search/preprocessing/preprocessing.py +++ b/backend/danswer/search/preprocessing/preprocessing.py @@ -5,6 +5,7 @@ from danswer.configs.chat_configs import DISABLE_LLM_FILTER_EXTRACTION from danswer.configs.chat_configs import FAVOR_RECENT_DECAY_MULTIPLIER from danswer.configs.chat_configs import NUM_RETURNED_HITS +from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW from danswer.db.models import User from danswer.search.enums import QueryFlow from danswer.search.enums import RecencyBiasSetting @@ -141,11 +142,22 @@ def retrieval_preprocessing( ) llm_chunk_filter = False - if persona: + if search_request.skip_llm_chunk_filter is not None: + llm_chunk_filter = not search_request.skip_llm_chunk_filter + elif persona: llm_chunk_filter = persona.llm_relevance_filter + if disable_llm_chunk_filter: + if llm_chunk_filter: + logger.info( + "LLM chunk filtering would have run but has been globally disabled" + ) llm_chunk_filter = False + skip_rerank = search_request.skip_rerank + 
if skip_rerank is None: + skip_rerank = not ENABLE_RERANKING_REAL_TIME_FLOW + # Decays at 1 / (1 + (multiplier * num years)) if persona and persona.recency_bias == RecencyBiasSetting.NO_DECAY: recency_bias_multiplier = 0.0 @@ -167,7 +179,7 @@ def retrieval_preprocessing( recency_bias_multiplier=recency_bias_multiplier, num_hits=limit if limit is not None else NUM_RETURNED_HITS, offset=offset or 0, - skip_rerank=search_request.skip_rerank, + skip_rerank=skip_rerank, skip_llm_chunk_filter=not llm_chunk_filter, ), predicted_search_type, diff --git a/backend/danswer/search/search_nlp_models.py b/backend/danswer/search/search_nlp_models.py index bc5a6fac42..783c2f8366 100644 --- a/backend/danswer/search/search_nlp_models.py +++ b/backend/danswer/search/search_nlp_models.py @@ -116,6 +116,9 @@ def get_local_reranking_model_ensemble( global _RERANK_MODELS if _RERANK_MODELS is None or max_context_length != _RERANK_MODELS[0].max_length: + del _RERANK_MODELS + gc.collect() + _RERANK_MODELS = [] for model_name in model_names: logger.info(f"Loading {model_name}") diff --git a/backend/danswer/server/query_and_chat/models.py b/backend/danswer/server/query_and_chat/models.py index d77d31684a..3ee701620a 100644 --- a/backend/danswer/server/query_and_chat/models.py +++ b/backend/danswer/server/query_and_chat/models.py @@ -64,7 +64,9 @@ class DocumentSearchRequest(BaseModel): search_type: SearchType retrieval_options: RetrievalDetails recency_bias_multiplier: float = 1.0 - skip_rerank: bool = False + # This is to forcibly skip (or run) the step, if None it uses the system defaults + skip_rerank: bool | None = None + skip_llm_chunk_filter: bool | None = None class LLMOverride(BaseModel): diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt index 9ca893876b..638375426f 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -53,11 +53,11 @@ requests==2.31.0 requests-oauthlib==1.3.1 retry==0.9.2 # This pulls in py which is in CVE-2022-42969, must remove py from image rfc3986==1.5.0 +rt==3.1.2 # need to pin `safetensors` version, since the latest versions requires # building from source using Rust -rt==3.1.2 -safetensors==0.3.1 -sentence-transformers==2.2.2 +safetensors==0.4.2 +sentence-transformers==2.6.1 slack-sdk==3.20.2 SQLAlchemy[mypy]==2.0.15 starlette==0.36.3 @@ -67,7 +67,7 @@ tiktoken==0.4.0 timeago==1.0.16 torch==2.0.1 torchvision==0.15.2 -transformers==4.36.2 +transformers==4.39.2 uvicorn==0.21.1 zulip==0.8.2 hubspot-api-client==8.1.0 From a4869b727d47acdc7b37ee8a9fa0bcc5818acd83 Mon Sep 17 00:00:00 2001 From: Weves Date: Sun, 31 Mar 2024 21:42:30 -0700 Subject: [PATCH 35/58] Add ability to control available pages --- backend/danswer/main.py | 4 + backend/danswer/server/settings/api.py | 30 ++++ backend/danswer/server/settings/models.py | 36 +++++ backend/danswer/server/settings/store.py | 23 +++ web/next.config.js | 8 +- web/src/app/admin/settings/SettingsForm.tsx | 149 ++++++++++++++++++++ web/src/app/admin/settings/interfaces.ts | 5 + web/src/app/admin/settings/page.tsx | 33 +++++ web/src/app/chat/ChatPage.tsx | 5 +- web/src/app/chat/page.tsx | 12 +- web/src/app/chat/shared/[chatId]/page.tsx | 18 +-- web/src/app/page.tsx | 16 +++ web/src/app/search/page.tsx | 11 +- web/src/components/Dropdown.tsx | 103 -------------- web/src/components/Header.tsx | 57 +++++--- web/src/components/admin/Layout.tsx | 31 +++- web/src/lib/settings.ts | 10 ++ 17 files changed, 399 insertions(+), 152 deletions(-) create mode 100644 
backend/danswer/server/settings/api.py create mode 100644 backend/danswer/server/settings/models.py create mode 100644 backend/danswer/server/settings/store.py create mode 100644 web/src/app/admin/settings/SettingsForm.tsx create mode 100644 web/src/app/admin/settings/interfaces.ts create mode 100644 web/src/app/admin/settings/page.tsx create mode 100644 web/src/app/page.tsx create mode 100644 web/src/lib/settings.ts diff --git a/backend/danswer/main.py b/backend/danswer/main.py index ad4584774a..90abab7372 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -76,6 +76,8 @@ admin_router as admin_query_router, ) from danswer.server.query_and_chat.query_backend import basic_router as query_router +from danswer.server.settings.api import admin_router as settings_admin_router +from danswer.server.settings.api import basic_router as settings_router from danswer.utils.logger import setup_logger from danswer.utils.telemetry import optional_telemetry from danswer.utils.telemetry import RecordType @@ -279,6 +281,8 @@ def get_application() -> FastAPI: include_router_with_global_prefix_prepended(application, state_router) include_router_with_global_prefix_prepended(application, danswer_api_router) include_router_with_global_prefix_prepended(application, gpts_router) + include_router_with_global_prefix_prepended(application, settings_router) + include_router_with_global_prefix_prepended(application, settings_admin_router) if AUTH_TYPE == AuthType.DISABLED: # Server logs this during auth setup verification step diff --git a/backend/danswer/server/settings/api.py b/backend/danswer/server/settings/api.py new file mode 100644 index 0000000000..422e268c13 --- /dev/null +++ b/backend/danswer/server/settings/api.py @@ -0,0 +1,30 @@ +from fastapi import APIRouter +from fastapi import Depends +from fastapi import HTTPException + +from danswer.auth.users import current_admin_user +from danswer.auth.users import current_user +from danswer.db.models import User +from danswer.server.settings.models import Settings +from danswer.server.settings.store import load_settings +from danswer.server.settings.store import store_settings + + +admin_router = APIRouter(prefix="/admin/settings") +basic_router = APIRouter(prefix="/settings") + + +@admin_router.put("") +def put_settings( + settings: Settings, _: User | None = Depends(current_admin_user) +) -> None: + try: + settings.check_validity() + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + store_settings(settings) + + +@basic_router.get("") +def fetch_settings(_: User | None = Depends(current_user)) -> Settings: + return load_settings() diff --git a/backend/danswer/server/settings/models.py b/backend/danswer/server/settings/models.py new file mode 100644 index 0000000000..041e360d72 --- /dev/null +++ b/backend/danswer/server/settings/models.py @@ -0,0 +1,36 @@ +from enum import Enum + +from pydantic import BaseModel + + +class PageType(str, Enum): + CHAT = "chat" + SEARCH = "search" + + +class Settings(BaseModel): + """General settings""" + + chat_page_enabled: bool = True + search_page_enabled: bool = True + default_page: PageType = PageType.SEARCH + + def check_validity(self) -> None: + chat_page_enabled = self.chat_page_enabled + search_page_enabled = self.search_page_enabled + default_page = self.default_page + + if chat_page_enabled is False and search_page_enabled is False: + raise ValueError( + "One of `search_page_enabled` and `chat_page_enabled` must be True." 
+ ) + + if default_page == PageType.CHAT and chat_page_enabled is False: + raise ValueError( + "The default page cannot be 'chat' if the chat page is disabled." + ) + + if default_page == PageType.SEARCH and search_page_enabled is False: + raise ValueError( + "The default page cannot be 'search' if the search page is disabled." + ) diff --git a/backend/danswer/server/settings/store.py b/backend/danswer/server/settings/store.py new file mode 100644 index 0000000000..ead1e3652a --- /dev/null +++ b/backend/danswer/server/settings/store.py @@ -0,0 +1,23 @@ +from typing import cast + +from danswer.dynamic_configs.factory import get_dynamic_config_store +from danswer.dynamic_configs.interface import ConfigNotFoundError +from danswer.server.settings.models import Settings + + +_SETTINGS_KEY = "danswer_settings" + + +def load_settings() -> Settings: + dynamic_config_store = get_dynamic_config_store() + try: + settings = Settings(**cast(dict, dynamic_config_store.load(_SETTINGS_KEY))) + except ConfigNotFoundError: + settings = Settings() + dynamic_config_store.store(_SETTINGS_KEY, settings.dict()) + + return settings + + +def store_settings(settings: Settings) -> None: + get_dynamic_config_store().store(_SETTINGS_KEY, settings.dict()) diff --git a/web/next.config.js b/web/next.config.js index 6f7de34ae4..d7fc7a551a 100644 --- a/web/next.config.js +++ b/web/next.config.js @@ -24,13 +24,7 @@ const nextConfig = { // In production, something else (nginx in the one box setup) should take // care of this redirect. TODO (chris): better support setups where // web_server and api_server are on different machines. - const defaultRedirects = [ - { - source: "/", - destination: "/search", - permanent: true, - }, - ]; + const defaultRedirects = []; if (process.env.NODE_ENV === "production") return defaultRedirects; diff --git a/web/src/app/admin/settings/SettingsForm.tsx b/web/src/app/admin/settings/SettingsForm.tsx new file mode 100644 index 0000000000..3be9e3cb7b --- /dev/null +++ b/web/src/app/admin/settings/SettingsForm.tsx @@ -0,0 +1,149 @@ +"use client"; + +import { Label, SubLabel } from "@/components/admin/connectors/Field"; +import { Title } from "@tremor/react"; +import { Settings } from "./interfaces"; +import { useRouter } from "next/navigation"; +import { DefaultDropdown, Option } from "@/components/Dropdown"; + +function Checkbox({ + label, + sublabel, + checked, + onChange, +}: { + label: string; + sublabel: string; + checked: boolean; + onChange: (e: React.ChangeEvent) => void; +}) { + return ( + + ); +} + +function Selector({ + label, + subtext, + options, + selected, + onSelect, +}: { + label: string; + subtext: string; + options: Option[]; + selected: string; + onSelect: (value: string | number | null) => void; +}) { + return ( +
+ {label && } + {subtext && {subtext}} + +
+ +
+
+ ); +} + +export function SettingsForm({ settings }: { settings: Settings }) { + const router = useRouter(); + + async function updateSettingField( + updateRequests: { fieldName: keyof Settings; newValue: any }[] + ) { + const newValues: any = {}; + updateRequests.forEach(({ fieldName, newValue }) => { + newValues[fieldName] = newValue; + }); + + const response = await fetch("/api/admin/settings", { + method: "PUT", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + ...settings, + ...newValues, + }), + }); + if (response.ok) { + router.refresh(); + } else { + const errorMsg = (await response.json()).detail; + alert(`Failed to update settings. ${errorMsg}`); + } + } + + return ( +
+ Page Visibility + + { + const updates: any[] = [ + { fieldName: "search_page_enabled", newValue: e.target.checked }, + ]; + if (!e.target.checked && settings.default_page === "search") { + updates.push({ fieldName: "default_page", newValue: "chat" }); + } + updateSettingField(updates); + }} + /> + + { + const updates: any[] = [ + { fieldName: "chat_page_enabled", newValue: e.target.checked }, + ]; + if (!e.target.checked && settings.default_page === "chat") { + updates.push({ fieldName: "default_page", newValue: "search" }); + } + updateSettingField(updates); + }} + /> + + { + value && + updateSettingField([ + { fieldName: "default_page", newValue: value }, + ]); + }} + /> +
+ ); +} diff --git a/web/src/app/admin/settings/interfaces.ts b/web/src/app/admin/settings/interfaces.ts new file mode 100644 index 0000000000..c62a392141 --- /dev/null +++ b/web/src/app/admin/settings/interfaces.ts @@ -0,0 +1,5 @@ +export interface Settings { + chat_page_enabled: boolean; + search_page_enabled: boolean; + default_page: "search" | "chat"; +} diff --git a/web/src/app/admin/settings/page.tsx b/web/src/app/admin/settings/page.tsx new file mode 100644 index 0000000000..1a30495b5f --- /dev/null +++ b/web/src/app/admin/settings/page.tsx @@ -0,0 +1,33 @@ +import { AdminPageTitle } from "@/components/admin/Title"; +import { FiSettings } from "react-icons/fi"; +import { Settings } from "./interfaces"; +import { fetchSS } from "@/lib/utilsSS"; +import { SettingsForm } from "./SettingsForm"; +import { Callout, Text } from "@tremor/react"; + +export default async function Page() { + const response = await fetchSS("/settings"); + + if (!response.ok) { + const errorMsg = await response.text(); + return {errorMsg}; + } + + const settings = (await response.json()) as Settings; + + return ( +
+ } + /> + + + Manage general Danswer settings applicable to all users in the + workspace. + + + +
+ ); +} diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index 7c132cc2c6..1fc1b4e5fb 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -9,9 +9,11 @@ import { Persona } from "../admin/personas/interfaces"; import { Header } from "@/components/Header"; import { HealthCheckBanner } from "@/components/health/healthcheck"; import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh"; +import { Settings } from "../admin/settings/interfaces"; export function ChatLayout({ user, + settings, chatSessions, availableSources, availableDocumentSets, @@ -21,6 +23,7 @@ export function ChatLayout({ documentSidebarInitialWidth, }: { user: User | null; + settings: Settings | null; chatSessions: ChatSession[]; availableSources: ValidSources[]; availableDocumentSets: DocumentSet[]; @@ -40,7 +43,7 @@ export function ChatLayout({ return ( <>
-
+
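Because both the admin `PUT` and the user-facing `GET` funnel through `store.py`, the settings round-trip can be exercised without the API layer at all. A quick usage sketch against the new module (this assumes the Postgres-backed dynamic config store from earlier in this series is reachable; the snippet itself is not part of the patch):

```python
# Sketch: exercising the new settings store directly.
from danswer.server.settings.models import PageType
from danswer.server.settings.store import load_settings
from danswer.server.settings.store import store_settings

settings = load_settings()  # first call seeds the store with default Settings

# check_validity() rejects inconsistent combinations before anything is
# persisted, mirroring what the PUT /admin/settings endpoint does
settings.default_page = PageType.CHAT
settings.chat_page_enabled = False
try:
    settings.check_validity()
except ValueError as e:
    print(e)  # The default page cannot be 'chat' if the chat page is disabled.
else:
    store_settings(settings)
```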
diff --git a/web/src/app/chat/page.tsx b/web/src/app/chat/page.tsx index 57de9de805..4e32d6ffcc 100644 --- a/web/src/app/chat/page.tsx +++ b/web/src/app/chat/page.tsx @@ -27,6 +27,8 @@ import { personaComparator } from "../admin/personas/lib"; import { ChatLayout } from "./ChatPage"; import { FullEmbeddingModelResponse } from "../admin/models/embedding/embeddingModels"; import { NoCompleteSourcesModal } from "@/components/initialSetup/search/NoCompleteSourceModal"; +import { getSettingsSS } from "@/lib/settings"; +import { Settings } from "../admin/settings/interfaces"; export default async function Page({ searchParams, @@ -43,6 +45,7 @@ export default async function Page({ fetchSS("/persona?include_default=true"), fetchSS("/chat/get-user-chat-sessions"), fetchSS("/query/valid-tags"), + getSettingsSS(), ]; // catch cases where the backend is completely unreachable here @@ -53,8 +56,9 @@ export default async function Page({ | Response | AuthTypeMetadata | FullEmbeddingModelResponse + | Settings | null - )[] = [null, null, null, null, null, null, null, null, null]; + )[] = [null, null, null, null, null, null, null, null, null, null]; try { results = await Promise.all(tasks); } catch (e) { @@ -67,6 +71,7 @@ export default async function Page({ const personasResponse = results[4] as Response | null; const chatSessionsResponse = results[5] as Response | null; const tagsResponse = results[6] as Response | null; + const settings = results[7] as Settings | null; const authDisabled = authTypeMetadata?.authType === "disabled"; if (!authDisabled && !user) { @@ -77,6 +82,10 @@ export default async function Page({ return redirect("/auth/waiting-on-verification"); } + if (settings && !settings.chat_page_enabled) { + return redirect("/search"); + } + let ccPairs: CCPairBasicInfo[] = []; if (ccPairsResponse?.ok) { ccPairs = await ccPairsResponse.json(); @@ -172,6 +181,7 @@ export default async function Page({
-
+
diff --git a/web/src/app/page.tsx b/web/src/app/page.tsx new file mode 100644 index 0000000000..c6b291d22d --- /dev/null +++ b/web/src/app/page.tsx @@ -0,0 +1,16 @@ +import { getSettingsSS } from "@/lib/settings"; +import { redirect } from "next/navigation"; + +export default async function Page() { + const settings = await getSettingsSS(); + + if (!settings) { + redirect("/search"); + } + + if (settings.default_page === "search") { + redirect("/search"); + } else { + redirect("/chat"); + } +} diff --git a/web/src/app/search/page.tsx b/web/src/app/search/page.tsx index fa72940332..2299ea7721 100644 --- a/web/src/app/search/page.tsx +++ b/web/src/app/search/page.tsx @@ -23,6 +23,8 @@ import { personaComparator } from "../admin/personas/lib"; import { FullEmbeddingModelResponse } from "../admin/models/embedding/embeddingModels"; import { NoSourcesModal } from "@/components/initialSetup/search/NoSourcesModal"; import { NoCompleteSourcesModal } from "@/components/initialSetup/search/NoCompleteSourceModal"; +import { getSettingsSS } from "@/lib/settings"; +import { Settings } from "../admin/settings/interfaces"; export default async function Home() { // Disable caching so we always get the up to date connector / document set / persona info @@ -38,6 +40,7 @@ export default async function Home() { fetchSS("/persona"), fetchSS("/query/valid-tags"), fetchSS("/secondary-index/get-embedding-models"), + getSettingsSS(), ]; // catch cases where the backend is completely unreachable here @@ -48,6 +51,7 @@ export default async function Home() { | Response | AuthTypeMetadata | FullEmbeddingModelResponse + | Settings | null )[] = [null, null, null, null, null, null, null]; try { @@ -62,6 +66,7 @@ export default async function Home() { const personaResponse = results[4] as Response | null; const tagsResponse = results[5] as Response | null; const embeddingModelResponse = results[6] as Response | null; + const settings = results[7] as Settings | null; const authDisabled = authTypeMetadata?.authType === "disabled"; if (!authDisabled && !user) { @@ -72,6 +77,10 @@ export default async function Home() { return redirect("/auth/waiting-on-verification"); } + if (settings && !settings.search_page_enabled) { + return redirect("/chat"); + } + let ccPairs: CCPairBasicInfo[] = []; if (ccPairsResponse?.ok) { ccPairs = await ccPairsResponse.json(); @@ -143,7 +152,7 @@ export default async function Home() { return ( <> -
+
diff --git a/web/src/components/Dropdown.tsx b/web/src/components/Dropdown.tsx index 6637b8eb68..3cb1ba70d4 100644 --- a/web/src/components/Dropdown.tsx +++ b/web/src/components/Dropdown.tsx @@ -1,7 +1,6 @@ import { ChangeEvent, FC, useEffect, useRef, useState } from "react"; import { ChevronDownIcon } from "./icons/icons"; import { FiCheck, FiChevronDown } from "react-icons/fi"; -import { FaRobot } from "react-icons/fa"; export interface Option { name: string; @@ -12,108 +11,6 @@ export interface Option { export type StringOrNumberOption = Option; -interface DropdownProps { - options: Option[]; - selected: string; - onSelect: (selected: Option | null) => void; -} - -export const Dropdown = ({ - options, - selected, - onSelect, -}: DropdownProps) => { - const [isOpen, setIsOpen] = useState(false); - const dropdownRef = useRef(null); - - const selectedName = options.find( - (option) => option.value === selected - )?.name; - - const handleSelect = (option: StringOrNumberOption) => { - onSelect(option); - setIsOpen(false); - }; - - useEffect(() => { - const handleClickOutside = (event: MouseEvent) => { - if ( - dropdownRef.current && - !dropdownRef.current.contains(event.target as Node) - ) { - setIsOpen(false); - } - }; - - document.addEventListener("mousedown", handleClickOutside); - return () => { - document.removeEventListener("mousedown", handleClickOutside); - }; - }, []); - - return ( -
-
- -
- - {isOpen ? ( -
-
- {options.map((option, index) => ( - - ))} -
-
- ) : null} -
- ); -}; - function StandardDropdownOption({ index, option, diff --git a/web/src/components/Header.tsx b/web/src/components/Header.tsx index a4a04244da..7a28a0aa7d 100644 --- a/web/src/components/Header.tsx +++ b/web/src/components/Header.tsx @@ -9,14 +9,15 @@ import React, { useEffect, useRef, useState } from "react"; import { CustomDropdown, DefaultDropdownElement } from "./Dropdown"; import { FiMessageSquare, FiSearch } from "react-icons/fi"; import { usePathname } from "next/navigation"; +import { Settings } from "@/app/admin/settings/interfaces"; interface HeaderProps { user: User | null; + settings: Settings | null; } -export const Header: React.FC = ({ user }) => { +export function Header({ user, settings }: HeaderProps) { const router = useRouter(); - const pathname = usePathname(); const [dropdownOpen, setDropdownOpen] = useState(false); const dropdownRef = useRef(null); @@ -56,7 +57,12 @@ export const Header: React.FC = ({ user }) => { return (
- +
Logo @@ -67,26 +73,31 @@ export const Header: React.FC = ({ user }) => {
- -
-
- -

Search

-
-
- + {(!settings || + (settings.search_page_enabled && settings.chat_page_enabled)) && ( + <> + +
+
+ +

Search

+
+
+ - -
-
- -

Chat

-
-
- + +
+
+ +

Chat

+
+
+ + + )}
@@ -124,7 +135,7 @@ export const Header: React.FC = ({ user }) => {
); -}; +} /* diff --git a/web/src/components/admin/Layout.tsx b/web/src/components/admin/Layout.tsx index fadeaee8d9..0221f5172d 100644 --- a/web/src/components/admin/Layout.tsx +++ b/web/src/components/admin/Layout.tsx @@ -1,3 +1,4 @@ +import { Settings } from "@/app/admin/settings/interfaces"; import { Header } from "@/components/Header"; import { AdminSidebar } from "@/components/admin/connectors/AdminSidebar"; import { @@ -12,6 +13,7 @@ import { ConnectorIcon, SlackIcon, } from "@/components/icons/icons"; +import { getSettingsSS } from "@/lib/settings"; import { User } from "@/lib/types"; import { AuthTypeMetadata, @@ -19,15 +21,21 @@ import { getCurrentUserSS, } from "@/lib/userSS"; import { redirect } from "next/navigation"; -import { FiCpu, FiLayers, FiPackage, FiSlack } from "react-icons/fi"; +import { + FiCpu, + FiLayers, + FiPackage, + FiSettings, + FiSlack, +} from "react-icons/fi"; export async function Layout({ children }: { children: React.ReactNode }) { - const tasks = [getAuthTypeMetadataSS(), getCurrentUserSS()]; + const tasks = [getAuthTypeMetadataSS(), getCurrentUserSS(), getSettingsSS()]; // catch cases where the backend is completely unreachable here // without try / catch, will just raise an exception and the page // will not render - let results: (User | AuthTypeMetadata | null)[] = [null, null]; + let results: (User | AuthTypeMetadata | Settings | null)[] = [null, null]; try { results = await Promise.all(tasks); } catch (e) { @@ -36,6 +44,7 @@ export async function Layout({ children }: { children: React.ReactNode }) { const authTypeMetadata = results[0] as AuthTypeMetadata | null; const user = results[1] as User | null; + const settings = results[2] as Settings | null; const authDisabled = authTypeMetadata?.authType === "disabled"; const requiresVerification = authTypeMetadata?.requiresVerification; @@ -54,7 +63,7 @@ export async function Layout({ children }: { children: React.ReactNode }) { return (
-
+
@@ -175,6 +184,20 @@ export async function Layout({ children }: { children: React.ReactNode }) { }, ], }, + { + name: "Settings", + items: [ + { + name: ( +
+ +
Workspace Settings
+
+ ), + link: "/admin/settings", + }, + ], + }, ]} />
diff --git a/web/src/lib/settings.ts b/web/src/lib/settings.ts new file mode 100644 index 0000000000..76cfd143ce --- /dev/null +++ b/web/src/lib/settings.ts @@ -0,0 +1,10 @@ +import { Settings } from "@/app/admin/settings/interfaces"; +import { buildUrl } from "./utilsSS"; + +export async function getSettingsSS(): Promise { + const response = await fetch(buildUrl("/settings")); + if (response.ok) { + return await response.json(); + } + return null; +} From 5b8cdd4eeebd4a67b6614398ee8614b41bea94c0 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Sun, 31 Mar 2024 21:58:00 -0700 Subject: [PATCH 36/58] Gpt-3.5-0125 Option (#1282) --- backend/danswer/server/features/persona/api.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/backend/danswer/server/features/persona/api.py b/backend/danswer/server/features/persona/api.py index d75ff69480..b4359f6a1f 100644 --- a/backend/danswer/server/features/persona/api.py +++ b/backend/danswer/server/features/persona/api.py @@ -174,8 +174,9 @@ def build_final_template_prompt( Putting here for now, since we have no other flows which use this.""" GPT_4_MODEL_VERSIONS = [ - "gpt-4-1106-preview", "gpt-4", + "gpt-4-turbo-preview", + "gpt-4-1106-preview", "gpt-4-32k", "gpt-4-0613", "gpt-4-32k-0613", @@ -183,8 +184,9 @@ def build_final_template_prompt( "gpt-4-32k-0314", ] GPT_3_5_TURBO_MODEL_VERSIONS = [ - "gpt-3.5-turbo-1106", "gpt-3.5-turbo", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-1106", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", From 0b0fc785a189f8a8c8a3b20cc99433a2ec4a2889 Mon Sep 17 00:00:00 2001 From: Weves Date: Sun, 31 Mar 2024 23:49:49 -0700 Subject: [PATCH 37/58] Fix fetch settings SS --- web/src/lib/settings.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/web/src/lib/settings.ts b/web/src/lib/settings.ts index 76cfd143ce..260f62bf1c 100644 --- a/web/src/lib/settings.ts +++ b/web/src/lib/settings.ts @@ -1,8 +1,8 @@ import { Settings } from "@/app/admin/settings/interfaces"; -import { buildUrl } from "./utilsSS"; +import { fetchSS } from "./utilsSS"; export async function getSettingsSS(): Promise { - const response = await fetch(buildUrl("/settings")); + const response = await fetchSS("/settings"); if (response.ok) { return await response.json(); } From e82061a5ece6578ff7c9ad9b0293e70053f23ca6 Mon Sep 17 00:00:00 2001 From: Weves Date: Mon, 1 Apr 2024 11:18:54 -0700 Subject: [PATCH 38/58] Add support for specifying title via search params --- .../danswer/server/query_and_chat/chat_backend.py | 3 ++- backend/danswer/server/query_and_chat/models.py | 1 + web/src/app/chat/Chat.tsx | 12 ++++++++++-- web/src/app/chat/lib.tsx | 7 ++++++- web/src/app/chat/searchParams.ts | 2 ++ 5 files changed, 21 insertions(+), 4 deletions(-) diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index 8a7d9e2154..19b7db23b3 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -123,7 +123,8 @@ def create_new_chat_session( try: new_chat_session = create_chat_session( db_session=db_session, - description="", # Leave the naming till later to prevent delay + description=chat_session_creation_request.description + or "", # Leave the naming till later to prevent delay user_id=user_id, persona_id=chat_session_creation_request.persona_id, ) diff --git a/backend/danswer/server/query_and_chat/models.py b/backend/danswer/server/query_and_chat/models.py index 
3ee701620a..8f86814772 100644 --- a/backend/danswer/server/query_and_chat/models.py +++ b/backend/danswer/server/query_and_chat/models.py @@ -31,6 +31,7 @@ class SimpleQueryRequest(BaseModel): class ChatSessionCreationRequest(BaseModel): # If not specified, use Danswer default persona persona_id: int = 0 + description: str | None = None class HelperResponse(BaseModel): diff --git a/web/src/app/chat/Chat.tsx b/web/src/app/chat/Chat.tsx index e16e873b06..cd2f1c4dac 100644 --- a/web/src/app/chat/Chat.tsx +++ b/web/src/app/chat/Chat.tsx @@ -334,8 +334,14 @@ export const Chat = ({ } = {}) => { let currChatSessionId: number; let isNewSession = chatSessionId === null; + const searchParamBasedChatSessionName = + searchParams.get(SEARCH_PARAM_NAMES.TITLE) || null; + if (isNewSession) { - currChatSessionId = await createChatSession(livePersona?.id || 0); + currChatSessionId = await createChatSession( + livePersona?.id || 0, + searchParamBasedChatSessionName + ); } else { currChatSessionId = chatSessionId as number; } @@ -475,7 +481,9 @@ export const Chat = ({ if (finalMessage) { setSelectedMessageForDocDisplay(finalMessage.message_id); } - await nameChatSession(currChatSessionId, currMessage); + if (!searchParamBasedChatSessionName) { + await nameChatSession(currChatSessionId, currMessage); + } // NOTE: don't switch pages if the user has navigated away from the chat if ( diff --git a/web/src/app/chat/lib.tsx b/web/src/app/chat/lib.tsx index dbc0506fd0..c5195abf83 100644 --- a/web/src/app/chat/lib.tsx +++ b/web/src/app/chat/lib.tsx @@ -18,7 +18,10 @@ import { Persona } from "../admin/personas/interfaces"; import { ReadonlyURLSearchParams } from "next/navigation"; import { SEARCH_PARAM_NAMES } from "./searchParams"; -export async function createChatSession(personaId: number): Promise { +export async function createChatSession( + personaId: number, + description: string | null +): Promise { const createChatSessionResponse = await fetch( "/api/chat/create-chat-session", { @@ -28,6 +31,7 @@ export async function createChatSession(personaId: number): Promise { }, body: JSON.stringify({ persona_id: personaId, + description, }), } ); @@ -373,6 +377,7 @@ export function personaIncludesRetrieval(selectedPersona: Persona) { const PARAMS_TO_SKIP = [ SEARCH_PARAM_NAMES.SUBMIT_ON_LOAD, SEARCH_PARAM_NAMES.USER_MESSAGE, + SEARCH_PARAM_NAMES.TITLE, // only use these if explicitly passed in SEARCH_PARAM_NAMES.CHAT_ID, SEARCH_PARAM_NAMES.PERSONA_ID, diff --git a/web/src/app/chat/searchParams.ts b/web/src/app/chat/searchParams.ts index 9de5d9ec89..0ebfbe4822 100644 --- a/web/src/app/chat/searchParams.ts +++ b/web/src/app/chat/searchParams.ts @@ -11,6 +11,8 @@ export const SEARCH_PARAM_NAMES = { // user message USER_MESSAGE: "user-message", SUBMIT_ON_LOAD: "submit-on-load", + // chat title + TITLE: "title", }; export function shouldSubmitOnLoad(searchParams: ReadonlyURLSearchParams) { From 87019fc18ee7cf0a4dfd45e8b9b10006ab44e9a8 Mon Sep 17 00:00:00 2001 From: Weves Date: Mon, 1 Apr 2024 17:35:35 -0700 Subject: [PATCH 39/58] Notion 404 graceful error handling --- .../danswer/connectors/notion/connector.py | 29 +++++++++++++++---- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/backend/danswer/connectors/notion/connector.py b/backend/danswer/connectors/notion/connector.py index 28fb47a44d..e0e307fc56 100644 --- a/backend/danswer/connectors/notion/connector.py +++ b/backend/danswer/connectors/notion/connector.py @@ -93,7 +93,9 @@ def __init__( self.recursive_index_enabled = recursive_index_enabled or 
self.root_page_id @retry(tries=3, delay=1, backoff=2) - def _fetch_blocks(self, block_id: str, cursor: str | None = None) -> dict[str, Any]: + def _fetch_child_blocks( + self, block_id: str, cursor: str | None = None + ) -> dict[str, Any] | None: """Fetch all child blocks via the Notion API.""" logger.debug(f"Fetching children of block with ID '{block_id}'") block_url = f"https://api.notion.com/v1/blocks/{block_id}/children" @@ -107,6 +109,15 @@ def _fetch_blocks(self, block_id: str, cursor: str | None = None) -> dict[str, A try: res.raise_for_status() except Exception as e: + if res.status_code == 404: + # this happens when a page is not shared with the integration + # in this case, we should just ignore the page + logger.error( + f"Unable to access block with ID '{block_id}'. " + f"This is likely due to the block not being shared " + f"with the Danswer integration. Exact exception:\n\n{e}" + ) + return None logger.exception(f"Error fetching blocks - {res.json()}") raise e return res.json() @@ -187,24 +198,30 @@ def _read_pages_from_database(self, database_id: str) -> list[str]: return result_pages def _read_blocks( - self, page_block_id: str + self, base_block_id: str ) -> tuple[list[tuple[str, str]], list[str]]: - """Reads blocks for a page""" + """Reads all child blocks for the specified block""" result_lines: list[tuple[str, str]] = [] child_pages: list[str] = [] cursor = None while True: - data = self._fetch_blocks(page_block_id, cursor) + data = self._fetch_child_blocks(base_block_id, cursor) + + # this happens when a block is not shared with the integration + if data is None: + return result_lines, child_pages for result in data["results"]: - logger.debug(f"Found block for page '{page_block_id}': {result}") + logger.debug( + f"Found child block for block with ID '{base_block_id}': {result}" + ) result_block_id = result["id"] result_type = result["type"] result_obj = result[result_type] if result_type == "ai_block": logger.warning( - f"Skipping 'ai_block' ('{result_block_id}') for page '{page_block_id}': " + f"Skipping 'ai_block' ('{result_block_id}') for base block '{base_block_id}': " f"Notion API does not currently support reading AI blocks (as of 24/02/09) " f"(discussion: https://github.com/danswer-ai/danswer/issues/1053)" ) From b0e0557d636b9bf5eeaf96350094b815d2f7f71b Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Mon, 1 Apr 2024 22:41:40 -0700 Subject: [PATCH 40/58] Update Contributing (#1288) --- CONTRIBUTING.md | 29 +++-------------------------- 1 file changed, 3 insertions(+), 26 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4d88752da2..f32f4fff30 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -112,25 +112,11 @@ docker compose -f docker-compose.dev.yml -p danswer-stack up -d index relational (index refers to Vespa and relational_db refers to Postgres) #### Running Danswer - -Setup a folder to store config. Navigate to `danswer/backend` and run: -```bash -mkdir dynamic_config_storage -``` - To start the frontend, navigate to `danswer/web` and run: ```bash npm run dev ``` -Package the Vespa schema. This will only need to be done when the Vespa schema is updated locally. - -Navigate to `danswer/backend/danswer/document_index/vespa/app_config` and run: -```bash -zip -r ../vespa-app.zip . -``` -- Note: If you don't have the `zip` utility, you will need to install it prior to running the above - The first time running Danswer, you will also need to run the DB migrations for Postgres. 
After the first time, this is no longer required unless the DB models change. @@ -149,17 +135,12 @@ python ./scripts/dev_run_background_jobs.py To run the backend API server, navigate back to `danswer/backend` and run: ```bash -AUTH_TYPE=disabled \ -DYNAMIC_CONFIG_DIR_PATH=./dynamic_config_storage \ -VESPA_DEPLOYMENT_ZIP=./danswer/document_index/vespa/vespa-app.zip \ -uvicorn danswer.main:app --reload --port 8080 +AUTH_TYPE=disabled uvicorn danswer.main:app --reload --port 8080 ``` _For Windows (for compatibility with both PowerShell and Command Prompt):_ ```bash powershell -Command " $env:AUTH_TYPE='disabled' - $env:DYNAMIC_CONFIG_DIR_PATH='./dynamic_config_storage' - $env:VESPA_DEPLOYMENT_ZIP='./danswer/document_index/vespa/vespa-app.zip' uvicorn danswer.main:app --reload --port 8080 " ``` @@ -178,20 +159,16 @@ pre-commit install Additionally, we use `mypy` for static type checking. Danswer is fully type-annotated, and we would like to keep it that way! -Right now, there is no automated type checking at the moment (coming soon), but we ask you to manually run it before -creating a pull requests with `python -m mypy .` from the `danswer/backend` directory. +To run the mypy checks manually, run `python -m mypy .` from the `danswer/backend` directory. #### Web We use `prettier` for formatting. The desired version (2.8.8) will be installed via a `npm i` from the `danswer/web` directory. To run the formatter, use `npx prettier --write .` from the `danswer/web` directory. -Like `mypy`, we have no automated formatting yet (coming soon), but we request that, for now, -you run this manually before creating a pull request. +Please double check that prettier passes before creating a pull request. ### Release Process Danswer follows the semver versioning standard. A set of Docker containers will be pushed automatically to DockerHub with every tag. You can see the containers [here](https://hub.docker.com/search?q=danswer%2F). - -As pre-1.0 software, even patch releases may contain breaking or non-backwards-compatible changes. 
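Since the migration step is now the only remaining backend setup command, it is worth noting that the same upgrade can also be driven from Python, which is convenient in test fixtures. A minimal sketch, assuming the working directory is `danswer/backend` where `alembic.ini` lives (the CLI command documented above remains the canonical flow):

```python
# Programmatic equivalent of `alembic upgrade head` - sketch only.
from alembic import command
from alembic.config import Config

alembic_cfg = Config("alembic.ini")  # assumes cwd is danswer/backend
command.upgrade(alembic_cfg, "head")
```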
From b06b95dc3a593c8e65350041d11786b521d8cdd3 Mon Sep 17 00:00:00 2001
From: Weves
Date: Tue, 2 Apr 2024 19:30:52 -0700
Subject: [PATCH 41/58] Bump litellm version to support latest Anthropic models

---
 backend/requirements/default.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt
index 638375426f..1d92e8f25e 100644
--- a/backend/requirements/default.txt
+++ b/backend/requirements/default.txt
@@ -24,7 +24,7 @@ httpx-oauth==0.11.2
 huggingface-hub==0.20.1
 jira==3.5.1
 langchain==0.1.9
-litellm==1.34.8
+litellm==1.34.21
 llama-index==0.9.45
 Mako==1.2.4
 msal==1.26.0

From d329061f92a23758155603a13a5b8a36ba6647bf Mon Sep 17 00:00:00 2001
From: ThomaciousD <2194608+ThomaciousD@users.noreply.github.com>
Date: Wed, 3 Apr 2024 08:17:53 +0200
Subject: [PATCH 42/58] Fixed: Web connector - documents deleted when no internet #1161 (#1292)

* fixing check connection before scrape in web connector #1161

* reformat

---------

Co-authored-by: ThomaciousD

---
 backend/danswer/connectors/web/connector.py | 18 +++++++-----------
 1 file changed, 7 insertions(+), 11 deletions(-)

diff --git a/backend/danswer/connectors/web/connector.py b/backend/danswer/connectors/web/connector.py
index 38f30a28ed..37b65f8da7 100644
--- a/backend/danswer/connectors/web/connector.py
+++ b/backend/danswer/connectors/web/connector.py
@@ -1,5 +1,4 @@
 import io
-import socket
 from enum import Enum
 from typing import Any
 from typing import cast
@@ -43,15 +42,12 @@ class WEB_CONNECTOR_VALID_SETTINGS(str, Enum):
     UPLOAD = "upload"
 
 
-def check_internet_connection() -> None:
-    dns_servers = [("1.1.1.1", 53), ("8.8.8.8", 53)]
-    for server in dns_servers:
-        try:
-            socket.create_connection(server, timeout=3)
-            return
-        except OSError:
-            continue
-    raise Exception("Unable to contact DNS server - check your internet connection")
+def check_internet_connection(url: str) -> None:
+    try:
+        response = requests.get(url, timeout=3)
+        response.raise_for_status()
+    except (requests.RequestException, ValueError):
+        raise Exception(f"Unable to reach {url} - check your internet connection")
 
 
 def is_valid_url(url: str) -> bool:
@@ -185,7 +181,6 @@ def load_from_state(self) -> GenerateDocumentsOutput:
             base_url = to_visit[0]  # For the recursive case
         doc_batch: list[Document] = []
 
-        check_internet_connection()
         playwright, context = start_playwright()
         restart_playwright = False
         while to_visit:
@@ -197,6 +192,7 @@ def load_from_state(self) -> GenerateDocumentsOutput:
             logger.info(f"Visiting {current_url}")
 
             try:
+                check_internet_connection(current_url)
                 if restart_playwright:
                     playwright, context = start_playwright()
                     restart_playwright = False

From c7efce3bde82fb7a486e94ced6bb0f31035e628c Mon Sep 17 00:00:00 2001
From: Weves
Date: Wed, 3 Apr 2024 23:13:49 -0700
Subject: [PATCH 43/58] Enable bedrock models in dev compose file

---
 deployment/docker_compose/docker-compose.dev.yml | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml
index 0dfcba7661..a6d7ef2946 100644
--- a/deployment/docker_compose/docker-compose.dev.yml
+++ b/deployment/docker_compose/docker-compose.dev.yml
@@ -44,6 +44,10 @@ services:
       - DISABLE_LLM_CHOOSE_SEARCH=${DISABLE_LLM_CHOOSE_SEARCH:-}
       - DISABLE_LLM_QUERY_REPHRASE=${DISABLE_LLM_QUERY_REPHRASE:-}
       - DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
+      # Enables the use of bedrock models
+      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID:-}
+      - 
AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY:-} + - AWS_REGION_NAME=${AWS_REGION_NAME:-} # Query Options - DOC_TIME_DECAY=${DOC_TIME_DECAY:-} # Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years) - HYBRID_ALPHA=${HYBRID_ALPHA:-} # Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector) From 4abf5f27a03a908701cf3746c335c88111a98d99 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Thu, 4 Apr 2024 03:51:10 -0700 Subject: [PATCH 44/58] Axero Forums Support (#1287) --- backend/danswer/connectors/axero/connector.py | 169 +++++++++++++++++- web/src/app/admin/connectors/axero/page.tsx | 6 +- 2 files changed, 167 insertions(+), 8 deletions(-) diff --git a/backend/danswer/connectors/axero/connector.py b/backend/danswer/connectors/axero/connector.py index fcb4395589..9ee7b2d6f0 100644 --- a/backend/danswer/connectors/axero/connector.py +++ b/backend/danswer/connectors/axero/connector.py @@ -4,6 +4,7 @@ from typing import Any import requests +from pydantic import BaseModel from danswer.configs.app_configs import INDEX_BATCH_SIZE from danswer.configs.constants import DocumentSource @@ -12,6 +13,10 @@ process_in_batches, ) from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc +from danswer.connectors.cross_connector_utils.rate_limit_wrapper import ( + rate_limit_builder, +) +from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder from danswer.connectors.interfaces import GenerateDocumentsOutput from danswer.connectors.interfaces import PollConnector from danswer.connectors.interfaces import SecondsSinceUnixEpoch @@ -31,6 +36,15 @@ def _get_auth_header(api_key: str) -> dict[str, str]: return {"Rest-Api-Key": api_key} +@retry_builder() +@rate_limit_builder(max_calls=5, period=1) +def _rate_limited_request( + endpoint: str, headers: dict, params: dict | None = None +) -> Any: + # https://my.axerosolutions.com/spaces/5/communifire-documentation/wiki/view/370/rest-api + return requests.get(endpoint, headers=headers, params=params) + + # https://my.axerosolutions.com/spaces/5/communifire-documentation/wiki/view/595/rest-api-get-content-list def _get_entities( entity_type: int, @@ -56,7 +70,9 @@ def _get_entities( if space_id is not None: params["SpaceID"] = space_id - res = requests.get(endpoint, headers=_get_auth_header(api_key), params=params) + res = _rate_limited_request( + endpoint, headers=_get_auth_header(api_key), params=params + ) res.raise_for_status() # Axero limitations: @@ -95,6 +111,126 @@ def _get_entities( return pages_to_return +def _get_obj_by_id(obj_id: int, api_key: str, axero_base_url: str) -> dict: + endpoint = axero_base_url + f"api/content/{obj_id}" + res = _rate_limited_request(endpoint, headers=_get_auth_header(api_key)) + res.raise_for_status() + + return res.json() + + +class AxeroForum(BaseModel): + doc_id: str + title: str + link: str + initial_content: str + responses: list[str] + last_update: datetime + + +def _map_post_to_parent( + posts: dict, + api_key: str, + axero_base_url: str, +) -> list[AxeroForum]: + """Cannot handle in batches since the posts aren't ordered or structured in any way + may need to map any number of them to the initial post""" + epoch_str = "1970-01-01T00:00:00.000" + post_map: dict[int, AxeroForum] = {} + + for ind, post in enumerate(posts): + if (ind + 1) % 25 == 0: + logger.debug(f"Processed {ind + 1} posts or responses") + + post_time = time_str_to_utc( + post.get("DateUpdated") or post.get("DateCreated") or epoch_str + ) + p_id = 
post.get("ParentContentID")
+        if p_id in post_map:
+            axero_forum = post_map[p_id]
+            axero_forum.responses.insert(0, post.get("ContentSummary"))
+            axero_forum.last_update = max(axero_forum.last_update, post_time)
+        else:
+            initial_post_d = _get_obj_by_id(p_id, api_key, axero_base_url)[
+                "ResponseData"
+            ]
+            initial_post_time = time_str_to_utc(
+                initial_post_d.get("DateUpdated")
+                or initial_post_d.get("DateCreated")
+                or epoch_str
+            )
+            post_map[p_id] = AxeroForum(
+                doc_id="AXERO_" + str(initial_post_d.get("ContentID")),
+                title=initial_post_d.get("ContentTitle"),
+                link=initial_post_d.get("ContentURL"),
+                initial_content=initial_post_d.get("ContentSummary"),
+                responses=[post.get("ContentSummary")],
+                last_update=max(post_time, initial_post_time),
+            )
+
+    return list(post_map.values())
+
+
+def _get_forums(
+    api_key: str,
+    axero_base_url: str,
+    space_id: str | None = None,
+) -> list[dict]:
+    endpoint = axero_base_url + "api/content/list"
+    page_num = 1
+    pages_fetched = 0
+    pages_to_return = []
+    break_out = False
+
+    while True:
+        params = {
+            "EntityType": "54",
+            "SortColumn": "DateUpdated",
+            "SortOrder": "1",  # descending
+            "StartPage": str(page_num),
+        }
+
+        if space_id is not None:
+            params["SpaceID"] = space_id
+
+        res = _rate_limited_request(
+            endpoint, headers=_get_auth_header(api_key), params=params
+        )
+        res.raise_for_status()
+
+        data = res.json()
+        total_records = data["TotalRecords"]
+        contents = data["ResponseData"]
+        pages_fetched += len(contents)
+        logger.debug(f"Fetched {pages_fetched} forums")
+
+        for page in contents:
+            pages_to_return.append(page)
+
+        if pages_fetched >= total_records:
+            break
+
+        page_num += 1
+
+        if break_out:
+            break
+
+    return pages_to_return
+
+
+def _translate_forum_to_doc(af: AxeroForum) -> Document:
+    doc = Document(
+        id=af.doc_id,
+        sections=[Section(link=af.link, text=reply) for reply in af.responses],
+        source=DocumentSource.AXERO,
+        semantic_identifier=af.title,
+        doc_updated_at=af.last_update,
+        metadata={},
+    )
+
+    return doc
+
+
 def _translate_content_to_doc(content: dict) -> Document:
     page_text = ""
     summary = content.get("ContentSummary")
@@ -126,8 +262,7 @@ def __init__(
         include_article: bool = True,
         include_blog: bool = True,
         include_wiki: bool = True,
-        # Forums not supported atm
-        include_forum: bool = False,
+        include_forum: bool = True,
         batch_size: int = INDEX_BATCH_SIZE,
     ) -> None:
         self.include_article = include_article
@@ -165,12 +300,11 @@ def poll_source(
             entity_types.append(4)
         if self.include_wiki:
             entity_types.append(9)
-        if self.include_forum:
-            raise NotImplementedError("Forums for Axero not supported currently")
 
         iterable_space_ids = self.space_ids if self.space_ids else [None]
 
         for space_id in iterable_space_ids:
             for entity in entity_types:
                 axero_obj = _get_entities(
                     entity_type=entity,
@@ -186,6 +320,31 @@ def poll_source(
                 batch_size=self.batch_size,
             )
 
+            if self.include_forum:
+                forums_posts = _get_forums(
+                    api_key=self.axero_key,
+                    axero_base_url=self.base_url,
+                    space_id=space_id,
+                )
+
+                all_axero_forums = _map_post_to_parent(
+                    posts=forums_posts,
+                    api_key=self.axero_key,
+                    axero_base_url=self.base_url,
+                )
+
+                filtered_forums = [
+                    f
+                    for f in all_axero_forums
+                    if f.last_update >= start_datetime and f.last_update <= end_datetime
+                ]
+
+                yield from process_in_batches(
+                    objects=filtered_forums,
+                    process_function=_translate_forum_to_doc,
+                    batch_size=self.batch_size,
+                )
+

 if __name__ == "__main__":
     import os
diff --git a/web/src/app/admin/connectors/axero/page.tsx
b/web/src/app/admin/connectors/axero/page.tsx index b434f8528d..ccabc380c8 100644 --- a/web/src/app/admin/connectors/axero/page.tsx +++ b/web/src/app/admin/connectors/axero/page.tsx @@ -150,8 +150,8 @@ const MainSection = () => { {axeroConnectorIndexingStatuses.length > 0 && ( <> - We pull the latest Articles, Blogs, and Wikis{" "} - every 10 minutes. + We pull the latest Articles, Blogs, Wikis and{" "} + Forums once per day.
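
Zooming out on the connector change above: `_map_post_to_parent` is the crux of the forum support. Posts come back from the Axero API unordered, so each response is folded into a thread keyed by its parent's `ContentID`, and the initial post is fetched lazily (one extra API call) the first time a parent is seen. Below is a condensed, hedged sketch of that grouping, with `fetch_parent` standing in for the real `_get_obj_by_id` call and only the field names mirroring the patch:

```python
from typing import Callable

from pydantic import BaseModel


class ForumThread(BaseModel):
    title: str
    responses: list[str]


def group_posts_by_parent(
    posts: list[dict], fetch_parent: Callable[[int], dict]
) -> dict[int, ForumThread]:
    threads: dict[int, ForumThread] = {}
    for post in posts:
        p_id = post["ParentContentID"]
        if p_id not in threads:
            # One extra lookup per thread, only on first sighting of the parent
            parent = fetch_parent(p_id)
            threads[p_id] = ForumThread(title=parent["ContentTitle"], responses=[])
        # Newest-first ordering, matching the insert(0, ...) in the patch
        threads[p_id].responses.insert(0, post["ContentSummary"])
    return threads
```

This keeps the pass over the posts linear with at most one parent lookup per thread, at the cost of holding every thread in memory, which is why the docstring notes it "cannot handle in batches".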
@@ -222,7 +222,7 @@ const MainSection = () => { initialValues={{ spaces: [], }} - refreshFreq={10 * 60} // 10 minutes + refreshFreq={60 * 60 * 24} // 1 day credentialId={axeroCredential.id} /> From 7298cc28356cd5419cd32eb11cc3660274339b9c Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Thu, 4 Apr 2024 05:30:23 -0700 Subject: [PATCH 45/58] Add verbose logging in case of query failure (#1297) --- backend/danswer/document_index/vespa/index.py | 72 ++++++++++++++----- 1 file changed, 55 insertions(+), 17 deletions(-) diff --git a/backend/danswer/document_index/vespa/index.py b/backend/danswer/document_index/vespa/index.py index 9f78f05c20..56c36d1e41 100644 --- a/backend/danswer/document_index/vespa/index.py +++ b/backend/danswer/document_index/vespa/index.py @@ -112,13 +112,13 @@ def _does_document_exist( """Returns whether the document already exists and the users/group whitelists Specifically in this case, document refers to a vespa document which is equivalent to a Danswer chunk. This checks for whether the chunk exists already in the index""" - doc_fetch_response = http_client.get( - f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}" - ) + doc_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}" + doc_fetch_response = http_client.get(doc_url) if doc_fetch_response.status_code == 404: return False if doc_fetch_response.status_code != 200: + logger.debug(f"Failed to check for document with URL {doc_url}") raise RuntimeError( f"Unexpected fetch document by ID value from Vespa " f"with error {doc_fetch_response.status_code}" @@ -157,7 +157,24 @@ def _get_vespa_chunk_ids_by_document_id( "hits": hits_per_page, } while True: - results = requests.post(SEARCH_ENDPOINT, json=params).json() + res = requests.post(SEARCH_ENDPOINT, json=params) + try: + res.raise_for_status() + except requests.HTTPError as e: + request_info = f"Headers: {res.request.headers}\nPayload: {params}" + response_info = ( + f"Status Code: {res.status_code}\nResponse Content: {res.text}" + ) + error_base = f"Error occurred getting chunk by Document ID {document_id}" + logger.error( + f"{error_base}:\n" + f"{request_info}\n" + f"{response_info}\n" + f"Exception: {e}" + ) + raise requests.HTTPError(error_base) from e + + results = res.json() hits = results["root"].get("children", []) doc_chunk_ids.extend( @@ -179,10 +196,14 @@ def _delete_vespa_doc_chunks( ) for chunk_id in doc_chunk_ids: - res = http_client.delete( - f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{chunk_id}" - ) - res.raise_for_status() + try: + res = http_client.delete( + f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{chunk_id}" + ) + res.raise_for_status() + except httpx.HTTPStatusError as e: + logger.error(f"Failed to delete chunk, details: {e.response.text}") + raise def _delete_vespa_docs( @@ -559,18 +580,35 @@ def _query_vespa(query_params: Mapping[str, str | int | float]) -> list[Inferenc if "query" in query_params and not cast(str, query_params["query"]).strip(): raise ValueError("No/empty query received") + params = dict( + **query_params, + **{ + "presentation.timing": True, + } + if LOG_VESPA_TIMING_INFORMATION + else {}, + ) + response = requests.post( SEARCH_ENDPOINT, - json=dict( - **query_params, - **{ - "presentation.timing": True, - } - if LOG_VESPA_TIMING_INFORMATION - else {}, - ), + json=params, ) - response.raise_for_status() + try: + response.raise_for_status() + except requests.HTTPError as e: + request_info = f"Headers: {response.request.headers}\nPayload: {params}" + 
response_info = (
+            f"Status Code: {response.status_code}\n"
+            f"Response Content: {response.text}"
+        )
+        error_base = "Failed to query Vespa"
+        logger.error(
+            f"{error_base}:\n"
+            f"{request_info}\n"
+            f"{response_info}\n"
+            f"Exception: {e}"
+        )
+        raise requests.HTTPError(error_base) from e
 
     response_json: dict[str, Any] = response.json()
     if LOG_VESPA_TIMING_INFORMATION:

From 58dc620c286633a5a8559f1e2d92cde126b59113 Mon Sep 17 00:00:00 2001
From: Yuhong Sun
Date: Thu, 4 Apr 2024 06:29:00 -0700
Subject: [PATCH 46/58] Add Check for Enabling Reranking (#1298)

---
 backend/danswer/search/search_nlp_models.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/backend/danswer/search/search_nlp_models.py b/backend/danswer/search/search_nlp_models.py
index 783c2f8366..a43b814065 100644
--- a/backend/danswer/search/search_nlp_models.py
+++ b/backend/danswer/search/search_nlp_models.py
@@ -14,6 +14,8 @@ from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
+from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW
+from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
 from danswer.configs.model_configs import INTENT_MODEL_VERSION
 from danswer.configs.model_configs import QUERY_MAX_CONTEXT_SIZE
 from danswer.utils.logger import setup_logger
@@ -261,6 +263,14 @@ def __init__(
         )
 
     def load_model(self) -> list["CrossEncoder"] | None:
+        if (
+            ENABLE_RERANKING_REAL_TIME_FLOW is False
+            and ENABLE_RERANKING_ASYNC_FLOW is False
+        ):
+            raise RuntimeError(
+                "Should not be loading rerankers, they have been globally disabled"
+            )
+
         if self.rerank_server_endpoint:
             return None
 
@@ -363,7 +373,7 @@ def predict(
 def warm_up_models(
     model_name: str,
     normalize: bool,
-    skip_cross_encoders: bool = False,
+    skip_cross_encoders: bool = True,
     indexer_only: bool = False,
 ) -> None:
     warm_up_str = (

From 33da86c8025804c5590a5eba10c35320ad22e58d Mon Sep 17 00:00:00 2001
From: Yuhong Sun
Date: Thu, 4 Apr 2024 06:59:48 -0700
Subject: [PATCH 47/58] Reranker Warning Log (#1299)

---
 backend/danswer/search/search_nlp_models.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/backend/danswer/search/search_nlp_models.py b/backend/danswer/search/search_nlp_models.py
index a43b814065..50decb92d2 100644
--- a/backend/danswer/search/search_nlp_models.py
+++ b/backend/danswer/search/search_nlp_models.py
@@ -267,8 +267,9 @@ def load_model(self) -> list["CrossEncoder"] | None:
             ENABLE_RERANKING_REAL_TIME_FLOW is False
             and ENABLE_RERANKING_ASYNC_FLOW is False
         ):
-            raise RuntimeError(
-                "Should not be loading rerankers, they have been globally disabled"
+            logger.warning(
+                "Running rerankers but they are globally disabled. "
+                "Was this specified explicitly via an API?"
) if self.rerank_server_endpoint: From 7ba7224929074b1b0aeed1238ddca3cc66def85e Mon Sep 17 00:00:00 2001 From: Weves Date: Tue, 2 Apr 2024 00:14:05 -0700 Subject: [PATCH 48/58] Allow seeding of chat sessions via POST --- ...f1a3b_add_overrides_to_the_chat_session.py | 40 ++++++++ backend/danswer/chat/process_message.py | 76 +++++++++------ backend/danswer/db/chat.py | 10 +- backend/danswer/db/enums.py | 35 +++++++ backend/danswer/db/models.py | 55 ++++------- backend/danswer/db/pydantic_type.py | 32 +++++++ backend/danswer/llm/answering/models.py | 4 +- backend/danswer/llm/override_models.py | 17 ++++ .../server/query_and_chat/chat_backend.py | 95 ++++++++++++++++++- .../danswer/server/query_and_chat/models.py | 18 ++-- web/src/app/chat/Chat.tsx | 20 +++- web/src/app/chat/lib.tsx | 24 +++-- 12 files changed, 339 insertions(+), 87 deletions(-) create mode 100644 backend/alembic/versions/ecab2b3f1a3b_add_overrides_to_the_chat_session.py create mode 100644 backend/danswer/db/enums.py create mode 100644 backend/danswer/db/pydantic_type.py create mode 100644 backend/danswer/llm/override_models.py diff --git a/backend/alembic/versions/ecab2b3f1a3b_add_overrides_to_the_chat_session.py b/backend/alembic/versions/ecab2b3f1a3b_add_overrides_to_the_chat_session.py new file mode 100644 index 0000000000..791d7e42e0 --- /dev/null +++ b/backend/alembic/versions/ecab2b3f1a3b_add_overrides_to_the_chat_session.py @@ -0,0 +1,40 @@ +"""Add overrides to the chat session + +Revision ID: ecab2b3f1a3b +Revises: 38eda64af7fe +Create Date: 2024-04-01 19:08:21.359102 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "ecab2b3f1a3b" +down_revision = "38eda64af7fe" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "chat_session", + sa.Column( + "llm_override", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + ), + ) + op.add_column( + "chat_session", + sa.Column( + "prompt_override", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + ), + ) + + +def downgrade() -> None: + op.drop_column("chat_session", "prompt_override") + op.drop_column("chat_session", "llm_override") diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index d9e7f9b6c4..f904f49638 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -95,6 +95,10 @@ def stream_chat_message_objects( # For flow with search, don't include as many chunks as possible since we need to leave space # for the chat history, for smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE, + # if specified, uses the last user message and does not create a new user message based + # on the `new_msg_req.message`. Currently, requires a state where the last message is a + # user message (e.g. this can only be used for the chat-seeding flow). + use_existing_user_message: bool = False, ) -> ChatPacketStream: """Streams in order: 1. 
[conditional] Retrieved documents if a search needs to be run
@@ -161,33 +165,43 @@ def stream_chat_message_objects(
     else:
         parent_message = root_message
 
-    # Create new message at the right place in the tree and update the parent's child pointer
-    # Don't commit yet until we verify the chat message chain
-    new_user_message = create_new_chat_message(
-        chat_session_id=chat_session_id,
-        parent_message=parent_message,
-        prompt_id=prompt_id,
-        message=message_text,
-        token_count=len(llm_tokenizer_encode_func(message_text)),
-        message_type=MessageType.USER,
-        db_session=db_session,
-        commit=False,
-    )
-
-    # Create linear history of messages
-    final_msg, history_msgs = create_chat_chain(
-        chat_session_id=chat_session_id, db_session=db_session
-    )
-
-    if final_msg.id != new_user_message.id:
-        db_session.rollback()
-        raise RuntimeError(
-            "The new message was not on the mainline. "
-            "Be sure to update the chat pointers before calling this."
+    if not use_existing_user_message:
+        # Create new message at the right place in the tree and update the parent's child pointer
+        # Don't commit yet until we verify the chat message chain
+        user_message = create_new_chat_message(
+            chat_session_id=chat_session_id,
+            parent_message=parent_message,
+            prompt_id=prompt_id,
+            message=message_text,
+            token_count=len(llm_tokenizer_encode_func(message_text)),
+            message_type=MessageType.USER,
+            db_session=db_session,
+            commit=False,
         )
+        # re-create linear history of messages
+        final_msg, history_msgs = create_chat_chain(
+            chat_session_id=chat_session_id, db_session=db_session
+        )
+        if final_msg.id != user_message.id:
+            db_session.rollback()
+            raise RuntimeError(
+                "The new message was not on the mainline. "
+                "Be sure to update the chat pointers before calling this."
+            )
 
-    # Save now to save the latest chat message
-    db_session.commit()
+        # Save now to save the latest chat message
+        db_session.commit()
+    else:
+        # re-create linear history of messages
+        final_msg, history_msgs = create_chat_chain(
+            chat_session_id=chat_session_id, db_session=db_session
+        )
+        if final_msg.message_type != MessageType.USER:
+            raise RuntimeError(
+                "The last message was not a user message. Cannot call "
+                "`stream_chat_message_objects` with `use_existing_user_message=True` "
+                "when the last message is not a user message."
+ ) run_search = False # Retrieval options are only None if reference_doc_ids are provided @@ -304,7 +318,7 @@ def stream_chat_message_objects( partial_response = partial( create_new_chat_message, chat_session_id=chat_session_id, - parent_message=new_user_message, + parent_message=final_msg, prompt_id=prompt_id, # message=, rephrased_query=rephrased_query, @@ -346,10 +360,14 @@ def stream_chat_message_objects( document_pruning_config=document_pruning_config, ), prompt_config=PromptConfig.from_model( - final_msg.prompt, prompt_override=new_msg_req.prompt_override + final_msg.prompt, + prompt_override=( + new_msg_req.prompt_override or chat_session.prompt_override + ), ), llm_config=LLMConfig.from_persona( - persona, llm_override=new_msg_req.llm_override + persona, + llm_override=(new_msg_req.llm_override or chat_session.llm_override), ), message_history=[ PreviousMessage.from_chat_message(msg) for msg in history_msgs @@ -399,12 +417,14 @@ def stream_chat_message_objects( def stream_chat_message( new_msg_req: CreateChatMessageRequest, user: User | None, + use_existing_user_message: bool = False, ) -> Iterator[str]: with get_session_context_manager() as db_session: objects = stream_chat_message_objects( new_msg_req=new_msg_req, user=user, db_session=db_session, + use_existing_user_message=use_existing_user_message, ) for obj in objects: yield get_json_line(obj.dict()) diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index eb2d49b4c3..738d02a165 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -28,6 +28,8 @@ from danswer.db.models import SearchDoc as DBSearchDoc from danswer.db.models import StarterMessage from danswer.db.models import User__UserGroup +from danswer.llm.override_models import LLMOverride +from danswer.llm.override_models import PromptOverride from danswer.search.enums import RecencyBiasSetting from danswer.search.models import RetrievalDocs from danswer.search.models import SavedSearchDoc @@ -53,7 +55,9 @@ def get_chat_session_by_id( # if user_id is None, assume this is an admin who should be able # to view all chat sessions if user_id is not None: - stmt = stmt.where(ChatSession.user_id == user_id) + stmt = stmt.where( + or_(ChatSession.user_id == user_id, ChatSession.user_id.is_(None)) + ) result = db_session.execute(stmt) chat_session = result.scalar_one_or_none() @@ -92,12 +96,16 @@ def create_chat_session( description: str, user_id: UUID | None, persona_id: int | None = None, + llm_override: LLMOverride | None = None, + prompt_override: PromptOverride | None = None, one_shot: bool = False, ) -> ChatSession: chat_session = ChatSession( user_id=user_id, persona_id=persona_id, description=description, + llm_override=llm_override, + prompt_override=prompt_override, one_shot=one_shot, ) diff --git a/backend/danswer/db/enums.py b/backend/danswer/db/enums.py new file mode 100644 index 0000000000..2a02e078c6 --- /dev/null +++ b/backend/danswer/db/enums.py @@ -0,0 +1,35 @@ +from enum import Enum as PyEnum + + +class IndexingStatus(str, PyEnum): + NOT_STARTED = "not_started" + IN_PROGRESS = "in_progress" + SUCCESS = "success" + FAILED = "failed" + + +# these may differ in the future, which is why we're okay with this duplication +class DeletionStatus(str, PyEnum): + NOT_STARTED = "not_started" + IN_PROGRESS = "in_progress" + SUCCESS = "success" + FAILED = "failed" + + +# Consistent with Celery task statuses +class TaskStatus(str, PyEnum): + PENDING = "PENDING" + STARTED = "STARTED" + SUCCESS = "SUCCESS" + FAILURE = "FAILURE" + + +class 
IndexModelStatus(str, PyEnum): + PAST = "PAST" + PRESENT = "PRESENT" + FUTURE = "FUTURE" + + +class ChatSessionSharedStatus(str, PyEnum): + PUBLIC = "public" + PRIVATE = "private" diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 4a44882f2d..7fb6bbaa77 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -35,45 +35,18 @@ from danswer.configs.constants import MessageType from danswer.configs.constants import SearchFeedbackType from danswer.connectors.models import InputType +from danswer.db.enums import ChatSessionSharedStatus +from danswer.db.enums import IndexingStatus +from danswer.db.enums import IndexModelStatus +from danswer.db.enums import TaskStatus +from danswer.db.pydantic_type import PydanticType from danswer.dynamic_configs.interface import JSON_ro +from danswer.llm.override_models import LLMOverride +from danswer.llm.override_models import PromptOverride from danswer.search.enums import RecencyBiasSetting from danswer.search.enums import SearchType -class IndexingStatus(str, PyEnum): - NOT_STARTED = "not_started" - IN_PROGRESS = "in_progress" - SUCCESS = "success" - FAILED = "failed" - - -# these may differ in the future, which is why we're okay with this duplication -class DeletionStatus(str, PyEnum): - NOT_STARTED = "not_started" - IN_PROGRESS = "in_progress" - SUCCESS = "success" - FAILED = "failed" - - -# Consistent with Celery task statuses -class TaskStatus(str, PyEnum): - PENDING = "PENDING" - STARTED = "STARTED" - SUCCESS = "SUCCESS" - FAILURE = "FAILURE" - - -class IndexModelStatus(str, PyEnum): - PAST = "PAST" - PRESENT = "PRESENT" - FUTURE = "FUTURE" - - -class ChatSessionSharedStatus(str, PyEnum): - PUBLIC = "public" - PRIVATE = "private" - - class Base(DeclarativeBase): pass @@ -596,6 +569,20 @@ class ChatSession(Base): Enum(ChatSessionSharedStatus, native_enum=False), default=ChatSessionSharedStatus.PRIVATE, ) + + # the latest "overrides" specified by the user. These take precedence over + # the attached persona. However, overrides specified directly in the + # `send-message` call will take precedence over these. 
+    # NOTE: currently only used by the chat seeding flow, will be used in the
+    # future once we allow users to override default values via the Chat UI
+    # itself
+    llm_override: Mapped[LLMOverride | None] = mapped_column(
+        PydanticType(LLMOverride), nullable=True
+    )
+    prompt_override: Mapped[PromptOverride | None] = mapped_column(
+        PydanticType(PromptOverride), nullable=True
+    )
+
     time_updated: Mapped[datetime.datetime] = mapped_column(
         DateTime(timezone=True),
         server_default=func.now(),
diff --git a/backend/danswer/db/pydantic_type.py b/backend/danswer/db/pydantic_type.py
new file mode 100644
index 0000000000..1f37152a85
--- /dev/null
+++ b/backend/danswer/db/pydantic_type.py
@@ -0,0 +1,32 @@
+import json
+from typing import Any
+from typing import Optional
+from typing import Type
+
+from pydantic import BaseModel
+from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.types import TypeDecorator
+
+
+class PydanticType(TypeDecorator):
+    impl = JSONB
+
+    def __init__(
+        self, pydantic_model: Type[BaseModel], *args: Any, **kwargs: Any
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self.pydantic_model = pydantic_model
+
+    def process_bind_param(
+        self, value: Optional[BaseModel], dialect: Any
+    ) -> Optional[dict]:
+        if value is not None:
+            return json.loads(value.json())
+        return None
+
+    def process_result_value(
+        self, value: Optional[dict], dialect: Any
+    ) -> Optional[BaseModel]:
+        if value is not None:
+            return self.pydantic_model.parse_obj(value)
+        return None
diff --git a/backend/danswer/llm/answering/models.py b/backend/danswer/llm/answering/models.py
index f7f2bbad99..71ea66661a 100644
--- a/backend/danswer/llm/answering/models.py
+++ b/backend/danswer/llm/answering/models.py
@@ -10,9 +10,9 @@ from danswer.chat.models import AnswerQuestionStreamReturn
 from danswer.configs.constants import MessageType
 from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
+from danswer.llm.override_models import LLMOverride
+from danswer.llm.override_models import PromptOverride
 from danswer.llm.utils import get_default_llm_version
-from danswer.server.query_and_chat.models import LLMOverride
-from danswer.server.query_and_chat.models import PromptOverride
 
 if TYPE_CHECKING:
     from danswer.db.models import ChatMessage
diff --git a/backend/danswer/llm/override_models.py b/backend/danswer/llm/override_models.py
new file mode 100644
index 0000000000..1ecb3192f0
--- /dev/null
+++ b/backend/danswer/llm/override_models.py
@@ -0,0 +1,17 @@
+"""Overrides sent over the wire / stored in the DB
+
+NOTE: these models are used in many places, so have to be
+kept in a separate file to avoid circular imports.
+""" +from pydantic import BaseModel + + +class LLMOverride(BaseModel): + model_provider: str | None = None + model_version: str | None = None + temperature: float | None = None + + +class PromptOverride(BaseModel): + system_prompt: str | None = None + task_prompt: str | None = None diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index 19b7db23b3..52d6414e22 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -8,12 +8,16 @@ from danswer.auth.users import current_user from danswer.chat.chat_utils import create_chat_chain from danswer.chat.process_message import stream_chat_message +from danswer.configs.app_configs import WEB_DOMAIN +from danswer.configs.constants import MessageType from danswer.db.chat import create_chat_session +from danswer.db.chat import create_new_chat_message from danswer.db.chat import delete_chat_session from danswer.db.chat import get_chat_message from danswer.db.chat import get_chat_messages_by_session from danswer.db.chat import get_chat_session_by_id from danswer.db.chat import get_chat_sessions_by_user +from danswer.db.chat import get_or_create_root_message from danswer.db.chat import get_persona_by_id from danswer.db.chat import set_as_latest_chat_message from danswer.db.chat import translate_db_message_to_chat_message_detail @@ -27,6 +31,7 @@ from danswer.llm.answering.prompts.citations_prompt import ( compute_max_document_tokens_for_persona, ) +from danswer.llm.utils import get_default_llm_tokenizer from danswer.secondary_llm_flows.chat_session_naming import ( get_renamed_conversation_name, ) @@ -40,6 +45,8 @@ from danswer.server.query_and_chat.models import ChatSessionUpdateRequest from danswer.server.query_and_chat.models import CreateChatMessageRequest from danswer.server.query_and_chat.models import CreateChatSessionID +from danswer.server.query_and_chat.models import LLMOverride +from danswer.server.query_and_chat.models import PromptOverride from danswer.server.query_and_chat.models import RenameChatSessionResponse from danswer.server.query_and_chat.models import SearchFeedbackRequest from danswer.utils.logger import setup_logger @@ -93,6 +100,13 @@ def get_chat_session( except ValueError: raise ValueError("Chat session does not exist or has been deleted") + # for chat-seeding: if the session is unassigned, assign it now. 
This is done here + # to avoid another back and forth between FE -> BE before starting the first + # message generation + if chat_session.user_id is None and user_id is not None: + chat_session.user_id = user_id + db_session.commit() + session_messages = get_chat_messages_by_session( chat_session_id=session_id, user_id=user_id, db_session=db_session ) @@ -209,15 +223,24 @@ def handle_new_chat_message( - Sending a new message in the session - Regenerating a message in the session (just send the same one again) - Editing a message (similar to regenerating but sending a different message) + - Kicking off a seeded chat session (set `use_existing_user_message`) To avoid extra overhead/latency, this assumes (and checks) that previous messages on the path have already been set as latest""" logger.info(f"Received new chat message: {chat_message_req.message}") - if not chat_message_req.message and chat_message_req.prompt_id is not None: + if ( + not chat_message_req.message + and chat_message_req.prompt_id is not None + and not chat_message_req.use_existing_user_message + ): raise HTTPException(status_code=400, detail="Empty chat message is invalid") - packets = stream_chat_message(new_msg_req=chat_message_req, user=user) + packets = stream_chat_message( + new_msg_req=chat_message_req, + user=user, + use_existing_user_message=chat_message_req.use_existing_user_message, + ) return StreamingResponse(packets, media_type="application/json") @@ -308,3 +331,71 @@ def get_max_document_tokens( return MaxSelectedDocumentTokens( max_tokens=compute_max_document_tokens_for_persona(persona), ) + + +"""Endpoints for chat seeding""" + + +class ChatSeedRequest(BaseModel): + # standard chat session stuff + persona_id: int + prompt_id: int | None = None + + # overrides / seeding + llm_override: LLMOverride | None = None + prompt_override: PromptOverride | None = None + description: str | None = None + message: str | None = None + + # TODO: support this + # initial_message_retrieval_options: RetrievalDetails | None = None + + +class ChatSeedResponse(BaseModel): + redirect_url: str + + +@router.post("/seed-chat-session") +def seed_chat( + chat_seed_request: ChatSeedRequest, + # NOTE: realistically, this will be an API key not an actual user + _: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> ChatSeedResponse: + try: + new_chat_session = create_chat_session( + db_session=db_session, + description=chat_seed_request.description or "", + user_id=None, # this chat session is "unassigned" until a user visits the web UI + persona_id=chat_seed_request.persona_id, + llm_override=chat_seed_request.llm_override, + prompt_override=chat_seed_request.prompt_override, + ) + except Exception as e: + logger.exception(e) + raise HTTPException(status_code=400, detail="Invalid Persona provided.") + + if chat_seed_request.message is not None: + root_message = get_or_create_root_message( + chat_session_id=new_chat_session.id, db_session=db_session + ) + create_new_chat_message( + chat_session_id=new_chat_session.id, + parent_message=root_message, + prompt_id=chat_seed_request.prompt_id + or ( + new_chat_session.persona.prompts[0].id + if new_chat_session.persona.prompts + else None + ), + message=chat_seed_request.message, + token_count=len( + get_default_llm_tokenizer().encode(chat_seed_request.message) + ), + message_type=MessageType.USER, + db_session=db_session, + ) + + return ChatSeedResponse( + redirect_url=f"{WEB_DOMAIN}/chat?chatId={new_chat_session.id}" + ) diff --git 
a/backend/danswer/server/query_and_chat/models.py b/backend/danswer/server/query_and_chat/models.py index 8f86814772..90be759ad7 100644 --- a/backend/danswer/server/query_and_chat/models.py +++ b/backend/danswer/server/query_and_chat/models.py @@ -8,7 +8,9 @@ from danswer.configs.constants import DocumentSource from danswer.configs.constants import MessageType from danswer.configs.constants import SearchFeedbackType -from danswer.db.models import ChatSessionSharedStatus +from danswer.db.enums import ChatSessionSharedStatus +from danswer.llm.override_models import LLMOverride +from danswer.llm.override_models import PromptOverride from danswer.search.models import BaseFilters from danswer.search.models import RetrievalDetails from danswer.search.models import SearchDoc @@ -70,17 +72,6 @@ class DocumentSearchRequest(BaseModel): skip_llm_chunk_filter: bool | None = None -class LLMOverride(BaseModel): - model_provider: str | None = None - model_version: str | None = None - temperature: float | None = None - - -class PromptOverride(BaseModel): - system_prompt: str | None = None - task_prompt: str | None = None - - """ Currently the different branches are generated by changing the search query @@ -116,6 +107,9 @@ class CreateChatMessageRequest(BaseModel): llm_override: LLMOverride | None = None prompt_override: PromptOverride | None = None + # used for seeded chats to kick off the generation of an AI answer + use_existing_user_message: bool = False + @root_validator def check_search_doc_ids_or_retrieval_options(cls: BaseModel, values: dict) -> dict: search_doc_ids, retrieval_options = values.get("search_doc_ids"), values.get( diff --git a/web/src/app/chat/Chat.tsx b/web/src/app/chat/Chat.tsx index cd2f1c4dac..7c174f18d5 100644 --- a/web/src/app/chat/Chat.tsx +++ b/web/src/app/chat/Chat.tsx @@ -133,7 +133,7 @@ export const Chat = ({ !submitOnLoadPerformed.current ) { submitOnLoadPerformed.current = true; - onSubmit(); + await onSubmit(); } return; @@ -162,6 +162,21 @@ export const Chat = ({ setChatSessionSharedStatus(chatSession.shared_status); setIsFetchingChatMessages(false); + + // if this is a seeded chat, then kick off the AI message generation + if (newMessageHistory.length === 1 && !submitOnLoadPerformed.current) { + submitOnLoadPerformed.current = true; + const seededMessage = newMessageHistory[0].message; + await onSubmit({ + isSeededChat: true, + messageOverride: seededMessage, + }); + // force re-name if the chat session doesn't have one + if (!chatSession.description) { + await nameChatSession(existingChatSessionId, seededMessage); + router.refresh(); // need to refresh to update name on sidebar + } + } } initialSessionFetch(); @@ -326,11 +341,13 @@ export const Chat = ({ messageOverride, queryOverride, forceSearch, + isSeededChat, }: { messageIdToResend?: number; messageOverride?: string; queryOverride?: string; forceSearch?: boolean; + isSeededChat?: boolean; } = {}) => { let currChatSessionId: number; let isNewSession = chatSessionId === null; @@ -419,6 +436,7 @@ export const Chat = ({ undefined, systemPromptOverride: searchParams.get(SEARCH_PARAM_NAMES.SYSTEM_PROMPT) || undefined, + useExistingUserMessage: isSeededChat, })) { for (const packet of packetBunch) { if (Object.hasOwn(packet, "answer_piece")) { diff --git a/web/src/app/chat/lib.tsx b/web/src/app/chat/lib.tsx index c5195abf83..29a90526cd 100644 --- a/web/src/app/chat/lib.tsx +++ b/web/src/app/chat/lib.tsx @@ -57,6 +57,7 @@ export async function* sendMessage({ modelVersion, temperature, systemPromptOverride, + 
useExistingUserMessage, }: { message: string; parentMessageId: number | null; @@ -71,6 +72,9 @@ export async function* sendMessage({ temperature?: number; // prompt overrides systemPromptOverride?: string; + // if specified, will use the existing latest user message + // and will ignore the specified `message` + useExistingUserMessage?: boolean; }) { const documentsAreSelected = selectedDocumentIds && selectedDocumentIds.length > 0; @@ -99,13 +103,19 @@ export async function* sendMessage({ } : null, query_override: queryOverride, - prompt_override: { - system_prompt: systemPromptOverride, - }, - llm_override: { - temperature, - model_version: modelVersion, - }, + prompt_override: systemPromptOverride + ? { + system_prompt: systemPromptOverride, + } + : null, + llm_override: + temperature || modelVersion + ? { + temperature, + model_version: modelVersion, + } + : null, + use_existing_user_message: useExistingUserMessage, }), }); if (!sendMessageResponse.ok) { From 447791b45500c615bd5862631907932b09309232 Mon Sep 17 00:00:00 2001 From: Chris Weaver <25087905+Weves@users.noreply.github.com> Date: Thu, 4 Apr 2024 20:43:24 -0700 Subject: [PATCH 49/58] Token budgets (#1302) --------- Co-authored-by: Nick Donohue --- backend/danswer/configs/constants.py | 4 + .../danswer/server/manage/administrative.py | 39 ++++ .../server/query_and_chat/query_backend.py | 2 + .../server/query_and_chat/token_budget.py | 69 +++++++ web/src/app/admin/keys/openai/page.tsx | 168 +++++++++++++++++- 5 files changed, 279 insertions(+), 3 deletions(-) create mode 100644 backend/danswer/server/query_and_chat/token_budget.py diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index 65b9f7945b..b961cdfb39 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -40,6 +40,10 @@ SESSION_KEY = "session" QUERY_EVENT_ID = "query_event_id" LLM_CHUNKS = "llm_chunks" +TOKEN_BUDGET = "token_budget" +TOKEN_BUDGET_TIME_PERIOD = "token_budget_time_period" +ENABLE_TOKEN_BUDGET = "enable_token_budget" +TOKEN_BUDGET_SETTINGS = "token_budget_settings" # For chunking/processing chunks TITLE_SEPARATOR = "\n\r\n" diff --git a/backend/danswer/server/manage/administrative.py b/backend/danswer/server/manage/administrative.py index d3a9c4d3b7..fb3f306f80 100644 --- a/backend/danswer/server/manage/administrative.py +++ b/backend/danswer/server/manage/administrative.py @@ -1,3 +1,4 @@ +import json from collections.abc import Callable from datetime import datetime from datetime import timedelta @@ -5,6 +6,7 @@ from typing import cast from fastapi import APIRouter +from fastapi import Body from fastapi import Depends from fastapi import HTTPException from sqlalchemy.orm import Session @@ -12,8 +14,12 @@ from danswer.auth.users import current_admin_user from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ from danswer.configs.constants import DocumentSource +from danswer.configs.constants import ENABLE_TOKEN_BUDGET from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY from danswer.configs.constants import GEN_AI_DETECTED_MODEL +from danswer.configs.constants import TOKEN_BUDGET +from danswer.configs.constants import TOKEN_BUDGET_SETTINGS +from danswer.configs.constants import TOKEN_BUDGET_TIME_PERIOD from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER from danswer.configs.model_configs import GEN_AI_MODEL_VERSION from danswer.db.connector_credential_pair import get_connector_credential_pair @@ -262,3 +268,36 @@ def 
create_deletion_attempt_for_connector_id( file_store = get_default_file_store(db_session) for file_name in connector.connector_specific_config["file_locations"]: file_store.delete_file(file_name) + + +@router.get("/admin/token-budget-settings") +def get_token_budget_settings(_: User = Depends(current_admin_user)) -> dict: + try: + settings_json = cast( + str, get_dynamic_config_store().load(TOKEN_BUDGET_SETTINGS) + ) + settings = json.loads(settings_json) + return settings + except ConfigNotFoundError: + raise HTTPException(status_code=404, detail="Token budget settings not found.") + + +@router.put("/admin/token-budget-settings") +def update_token_budget_settings( + _: User = Depends(current_admin_user), + enable_token_budget: bool = Body(..., embed=True), + token_budget: int = Body(..., ge=0, embed=True), # Ensure non-negative + token_budget_time_period: int = Body(..., ge=1, embed=True), # Ensure positive +) -> dict[str, str]: + # Prepare the settings as a JSON string + settings_json = json.dumps( + { + ENABLE_TOKEN_BUDGET: enable_token_budget, + TOKEN_BUDGET: token_budget, + TOKEN_BUDGET_TIME_PERIOD: token_budget_time_period, + } + ) + + # Store the settings in the dynamic config store + get_dynamic_config_store().store(TOKEN_BUDGET_SETTINGS, settings_json) + return {"message": "Token budget settings updated successfully."} diff --git a/backend/danswer/server/query_and_chat/query_backend.py b/backend/danswer/server/query_and_chat/query_backend.py index 5150eb9ce1..79971f381e 100644 --- a/backend/danswer/server/query_and_chat/query_backend.py +++ b/backend/danswer/server/query_and_chat/query_backend.py @@ -29,6 +29,7 @@ from danswer.server.query_and_chat.models import SimpleQueryRequest from danswer.server.query_and_chat.models import SourceTag from danswer.server.query_and_chat.models import TagResponse +from danswer.server.query_and_chat.token_budget import check_token_budget from danswer.utils.logger import setup_logger logger = setup_logger() @@ -148,6 +149,7 @@ def stream_query_validation( def get_answer_with_quote( query_request: DirectQARequest, user: User = Depends(current_user), + _: bool = Depends(check_token_budget), ) -> StreamingResponse: query = query_request.messages[0].message logger.info(f"Received query for one shot answer with quotes: {query}") diff --git a/backend/danswer/server/query_and_chat/token_budget.py b/backend/danswer/server/query_and_chat/token_budget.py new file mode 100644 index 0000000000..b35c8fece8 --- /dev/null +++ b/backend/danswer/server/query_and_chat/token_budget.py @@ -0,0 +1,69 @@ +import json +from datetime import datetime +from datetime import timedelta +from typing import cast + +from fastapi import HTTPException +from sqlalchemy import func +from sqlalchemy.orm import Session + +from danswer.configs.constants import ENABLE_TOKEN_BUDGET +from danswer.configs.constants import TOKEN_BUDGET +from danswer.configs.constants import TOKEN_BUDGET_SETTINGS +from danswer.configs.constants import TOKEN_BUDGET_TIME_PERIOD +from danswer.db.engine import get_session_context_manager +from danswer.db.models import ChatMessage +from danswer.dynamic_configs.factory import get_dynamic_config_store + +BUDGET_LIMIT_DEFAULT = -1 # Default to no limit +TIME_PERIOD_HOURS_DEFAULT = 12 + + +def is_under_token_budget(db_session: Session) -> bool: + settings_json = cast(str, get_dynamic_config_store().load(TOKEN_BUDGET_SETTINGS)) + settings = json.loads(settings_json) + + is_enabled = settings.get(ENABLE_TOKEN_BUDGET, False) + + if not is_enabled: + return True + + 
budget_limit = settings.get(TOKEN_BUDGET, -1) + + if budget_limit < 0: + return True + + period_hours = settings.get(TOKEN_BUDGET_TIME_PERIOD, TIME_PERIOD_HOURS_DEFAULT) + period_start_time = datetime.now() - timedelta(hours=period_hours) + + # Fetch the sum of all tokens used within the period + token_sum = ( + db_session.query(func.sum(ChatMessage.token_count)) + .filter(ChatMessage.time_sent >= period_start_time) + .scalar() + or 0 + ) + + print( + "token_sum:", + token_sum, + "budget_limit:", + budget_limit, + "period_hours:", + period_hours, + "period_start_time:", + period_start_time, + ) + + return token_sum < ( + budget_limit * 1000 + ) # Budget limit is expressed in thousands of tokens + + +def check_token_budget() -> None: + with get_session_context_manager() as db_session: + # Perform the token budget check here, possibly using `user` and `db_session` for database access if needed + if not is_under_token_budget(db_session): + raise HTTPException( + status_code=429, detail="Sorry, token budget exceeded. Try again later." + ) diff --git a/web/src/app/admin/keys/openai/page.tsx b/web/src/app/admin/keys/openai/page.tsx index 70497f7199..31d3f3b0cd 100644 --- a/web/src/app/admin/keys/openai/page.tsx +++ b/web/src/app/admin/keys/openai/page.tsx @@ -1,12 +1,20 @@ "use client"; +import { Form, Formik } from "formik"; +import { useEffect, useState } from "react"; import { LoadingAnimation } from "@/components/Loading"; import { AdminPageTitle } from "@/components/admin/Title"; -import { KeyIcon, TrashIcon } from "@/components/icons/icons"; +import { + BooleanFormField, + SectionHeader, + TextFormField, +} from "@/components/admin/connectors/Field"; +import { Popup } from "@/components/admin/connectors/Popup"; +import { TrashIcon } from "@/components/icons/icons"; import { ApiKeyForm } from "@/components/openai/ApiKeyForm"; import { GEN_AI_API_KEY_URL } from "@/components/openai/constants"; import { fetcher } from "@/lib/fetcher"; -import { Text, Title } from "@tremor/react"; +import { Button, Divider, Text, Title } from "@tremor/react"; import { FiCpu } from "react-icons/fi"; import useSWR, { mutate } from "swr"; @@ -49,14 +57,167 @@ const ExistingKeys = () => { ); }; +const LLMOptions = () => { + const [popup, setPopup] = useState<{ + message: string; + type: "success" | "error"; + } | null>(null); + + const [initialValues, setInitialValues] = useState({ + enable_token_budget: false, + token_budget: "", + token_budget_time_period: "", + }); + + const fetchConfig = async () => { + const response = await fetch("/api/manage/admin/token-budget-settings"); + if (response.ok) { + const config = await response.json(); + // Assuming the config object directly matches the structure needed for initialValues + setInitialValues({ + enable_token_budget: config.enable_token_budget || false, + token_budget: config.token_budget || "", + token_budget_time_period: config.token_budget_time_period || "", + }); + } else { + // Handle error or provide fallback values + setPopup({ + message: "Failed to load current LLM options.", + type: "error", + }); + } + }; + + // Fetch current config when the component mounts + useEffect(() => { + fetchConfig(); + }, []); + + return ( + <> + {popup && } + { + const response = await fetch( + "/api/manage/admin/token-budget-settings", + { + method: "PUT", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(values), + } + ); + if (response.ok) { + setPopup({ + message: "Updated LLM Options", + type: "success", + }); + await fetchConfig(); + } else 
{ + const body = await response.json(); + if (body.detail) { + setPopup({ message: body.detail, type: "error" }); + } else { + setPopup({ + message: "Unable to update LLM options.", + type: "error", + }); + } + setTimeout(() => { + setPopup(null); + }, 4000); + } + }} + > + {({ isSubmitting, values, setFieldValue, setValues }) => { + return ( +
+ + <> + Token Budget + + Set a maximum token use per time period. If the token budget + is exceeded, the persona will not be able to respond to + queries until the next time period. + +
+ { + setFieldValue("enable_token_budget", e.target.checked); + }} + /> + {values.enable_token_budget && ( + <> + + How many tokens (in thousands) can be used per time + period? If unspecified, no limit will be set. +
+ } + onChange={(e) => { + const value = e.target.value; + // Allow only integer values + if (value === "" || /^[0-9]+$/.test(value)) { + setFieldValue("token_budget", value); + } + }} + /> + + Specify the length of the time period, in hours, over + which the token budget will be applied. +
+ } + onChange={(e) => { + const value = e.target.value; + // Allow only integer values + if (value === "" || /^[0-9]+$/.test(value)) { + setFieldValue("token_budget_time_period", value); + } + }} + /> + + )} + +
+ +
+ + ); + }} + + + ); +}; + const Page = () => { return (
} /> + LLM Keys + Update Key @@ -72,6 +233,7 @@ const Page = () => { }} />
+
); }; From eb367de44d264a8c3a633919f504994a3a087e93 Mon Sep 17 00:00:00 2001 From: Weves Date: Thu, 4 Apr 2024 20:53:02 -0700 Subject: [PATCH 50/58] Small token budget tweaks --- backend/danswer/configs/app_configs.py | 3 +++ backend/danswer/server/manage/administrative.py | 6 ++++++ .../danswer/server/query_and_chat/token_budget.py | 4 ++++ deployment/docker_compose/docker-compose.dev.yml | 2 ++ web/src/app/admin/keys/openai/page.tsx | 15 +++++++++++---- 5 files changed, 26 insertions(+), 4 deletions(-) diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index 08ac2fc23d..7556f56460 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -247,3 +247,6 @@ DISABLE_TELEMETRY = os.environ.get("DISABLE_TELEMETRY", "").lower() == "true" # notset, debug, info, warning, error, or critical LOG_LEVEL = os.environ.get("LOG_LEVEL", "info") +TOKEN_BUDGET_GLOBALLY_ENABLED = ( + os.environ.get("TOKEN_BUDGET_GLOBALLY_ENABLED", "").lower() == "true" +) diff --git a/backend/danswer/server/manage/administrative.py b/backend/danswer/server/manage/administrative.py index fb3f306f80..02d980b04e 100644 --- a/backend/danswer/server/manage/administrative.py +++ b/backend/danswer/server/manage/administrative.py @@ -13,6 +13,7 @@ from danswer.auth.users import current_admin_user from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ +from danswer.configs.app_configs import TOKEN_BUDGET_GLOBALLY_ENABLED from danswer.configs.constants import DocumentSource from danswer.configs.constants import ENABLE_TOKEN_BUDGET from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY @@ -272,6 +273,11 @@ def create_deletion_attempt_for_connector_id( @router.get("/admin/token-budget-settings") def get_token_budget_settings(_: User = Depends(current_admin_user)) -> dict: + if not TOKEN_BUDGET_GLOBALLY_ENABLED: + raise HTTPException( + status_code=400, detail="Token budget is not enabled in the application." 
+ ) + try: settings_json = cast( str, get_dynamic_config_store().load(TOKEN_BUDGET_SETTINGS) diff --git a/backend/danswer/server/query_and_chat/token_budget.py b/backend/danswer/server/query_and_chat/token_budget.py index b35c8fece8..1d1238c527 100644 --- a/backend/danswer/server/query_and_chat/token_budget.py +++ b/backend/danswer/server/query_and_chat/token_budget.py @@ -7,6 +7,7 @@ from sqlalchemy import func from sqlalchemy.orm import Session +from danswer.configs.app_configs import TOKEN_BUDGET_GLOBALLY_ENABLED from danswer.configs.constants import ENABLE_TOKEN_BUDGET from danswer.configs.constants import TOKEN_BUDGET from danswer.configs.constants import TOKEN_BUDGET_SETTINGS @@ -61,6 +62,9 @@ def is_under_token_budget(db_session: Session) -> bool: def check_token_budget() -> None: + if not TOKEN_BUDGET_GLOBALLY_ENABLED: + return None + with get_session_context_manager() as db_session: # Perform the token budget check here, possibly using `user` and `db_session` for database access if needed if not is_under_token_budget(db_session): diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index a6d7ef2946..d33981c7cd 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ b/deployment/docker_compose/docker-compose.dev.yml @@ -44,6 +44,8 @@ services: - DISABLE_LLM_CHOOSE_SEARCH=${DISABLE_LLM_CHOOSE_SEARCH:-} - DISABLE_LLM_QUERY_REPHRASE=${DISABLE_LLM_QUERY_REPHRASE:-} - DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-} + # if set, allows for the use of the token budget system + - TOKEN_BUDGET_GLOBALLY_ENABLED=${TOKEN_BUDGET_GLOBALLY_ENABLED:-} # Enables the use of bedrock models - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID:-} - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY:-} diff --git a/web/src/app/admin/keys/openai/page.tsx b/web/src/app/admin/keys/openai/page.tsx index 31d3f3b0cd..0d122e80ff 100644 --- a/web/src/app/admin/keys/openai/page.tsx +++ b/web/src/app/admin/keys/openai/page.tsx @@ -63,6 +63,8 @@ const LLMOptions = () => { type: "success" | "error"; } | null>(null); + const [tokenBudgetGloballyEnabled, setTokenBudgetGloballyEnabled] = + useState(false); const [initialValues, setInitialValues] = useState({ enable_token_budget: false, token_budget: "", @@ -79,6 +81,7 @@ const LLMOptions = () => { token_budget: config.token_budget || "", token_budget_time_period: config.token_budget_time_period || "", }); + setTokenBudgetGloballyEnabled(true); } else { // Handle error or provide fallback values setPopup({ @@ -93,6 +96,10 @@ const LLMOptions = () => { fetchConfig(); }, []); + if (!tokenBudgetGloballyEnabled) { + return null; + } + return ( <> {popup && } @@ -132,7 +139,7 @@ const LLMOptions = () => { } }} > - {({ isSubmitting, values, setFieldValue, setValues }) => { + {({ isSubmitting, values, setFieldValue }) => { return (
@@ -140,14 +147,14 @@ const LLMOptions = () => { Token Budget Set a maximum token use per time period. If the token budget - is exceeded, the persona will not be able to respond to - queries until the next time period. + is exceeded, Danswer will not be able to respond to queries + until the next time period.
{ setFieldValue("enable_token_budget", e.target.checked); }} From 795243283d1d5ae639abb703d7a984d9bd0202f7 Mon Sep 17 00:00:00 2001 From: Chris Weaver <25087905+Weves@users.noreply.github.com> Date: Sun, 7 Apr 2024 14:30:26 -0700 Subject: [PATCH 51/58] Update README.md Remove 'Danswer is the ChatGPT for teams' --- README.md | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 3e70e7259c..edd8328c31 100644 --- a/README.md +++ b/README.md @@ -22,11 +22,12 @@

-[Danswer](https://www.danswer.ai/) is the ChatGPT for teams. Danswer provides a Chat interface and plugs into any LLM of -your choice. Danswer can be deployed anywhere and for any scale - on a laptop, on-premise, or to cloud. Since you own -the deployment, your user data and chats are fully in your own control. Danswer is MIT licensed and designed to be -modular and easily extensible. The system also comes fully ready for production usage with user authentication, role -management (admin/basic users), chat persistence, and a UI for configuring Personas (AI Assistants) and their Prompts. +[Danswer](https://www.danswer.ai/) is the AI Assistant connected to your company's docs, apps, and people. +Danswer provides a Chat interface and plugs into any LLM of your choice. Danswer can be deployed anywhere and for any +scale - on a laptop, on-premise, or to cloud. Since you own the deployment, your user data and chats are fully in your +own control. Danswer is MIT licensed and designed to be modular and easily extensible. The system also comes fully ready +for production usage with user authentication, role management (admin/basic users), chat persistence, and a UI for +configuring Personas (AI Assistants) and their Prompts. Danswer also serves as a Unified Search across all common workplace tools such as Slack, Google Drive, Confluence, etc. By combining LLMs and team specific knowledge, Danswer becomes a subject matter expert for the team. Imagine ChatGPT if From 2db906b7a28c3366852c3166f1e038adb43f1ff0 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Sun, 7 Apr 2024 21:25:06 -0700 Subject: [PATCH 52/58] Always Use Model Server (#1306) --- .github/workflows/pr-python-checks.yml | 2 + CONTRIBUTING.md | 15 +- backend/Dockerfile | 2 +- backend/Dockerfile.model_server | 5 +- .../danswer/background/indexing/job_client.py | 7 +- .../background/indexing/run_indexing.py | 7 +- backend/danswer/background/update.py | 45 ++- backend/danswer/configs/app_configs.py | 10 +- backend/danswer/configs/model_configs.py | 22 +- .../slack/handlers/handle_message.py | 2 +- backend/danswer/danswerbot/slack/listener.py | 10 +- backend/danswer/indexing/embedder.py | 42 ++- backend/danswer/main.py | 53 ++- backend/danswer/search/enums.py | 5 + backend/danswer/search/models.py | 2 +- .../search/preprocessing/preprocessing.py | 2 +- .../danswer/search/retrieval/search_runner.py | 2 +- backend/danswer/search/search_nlp_models.py | 322 ++++-------------- backend/danswer/utils/batching.py | 7 + backend/model_server/constants.py | 1 + backend/model_server/custom_models.py | 68 +++- backend/model_server/encoders.py | 91 +++-- backend/model_server/main.py | 51 ++- backend/model_server/utils.py | 41 +++ backend/requirements/default.txt | 7 - backend/requirements/model_server.txt | 8 +- .../__init__.py | 0 .../model_server_models.py | 2 + backend/shared_configs/nlp_model_configs.py | 26 ++ .../docker_compose/docker-compose.dev.yml | 104 ++++-- .../docker-compose.prod-no-letsencrypt.yml | 95 ++++-- .../docker_compose/docker-compose.prod.yml | 95 ++++-- deployment/kubernetes/env-configmap.yaml | 4 +- ...exing_model_server-service-deployment.yaml | 59 ++++ ...rence_model_server-service-deployment.yaml | 56 +++ 35 files changed, 722 insertions(+), 548 deletions(-) create mode 100644 backend/model_server/constants.py create mode 100644 backend/model_server/utils.py rename backend/{shared_models => shared_configs}/__init__.py (100%) rename backend/{shared_models => shared_configs}/model_server_models.py (79%) create mode 100644 
backend/shared_configs/nlp_model_configs.py create mode 100644 deployment/kubernetes/indexing_model_server-service-deployment.yaml create mode 100644 deployment/kubernetes/inference_model_server-service-deployment.yaml diff --git a/.github/workflows/pr-python-checks.yml b/.github/workflows/pr-python-checks.yml index 792fe4d46b..6c604e93d4 100644 --- a/.github/workflows/pr-python-checks.yml +++ b/.github/workflows/pr-python-checks.yml @@ -20,10 +20,12 @@ jobs: cache-dependency-path: | backend/requirements/default.txt backend/requirements/dev.txt + backend/requirements/model_server.txt - run: | python -m pip install --upgrade pip pip install -r backend/requirements/default.txt pip install -r backend/requirements/dev.txt + pip install -r backend/requirements/model_server.txt - name: Run MyPy run: | diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f32f4fff30..7e80baeb2d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -85,6 +85,7 @@ Install the required python dependencies: ```bash pip install -r danswer/backend/requirements/default.txt pip install -r danswer/backend/requirements/dev.txt +pip install -r danswer/backend/requirements/model_server.txt ``` Install [Node.js and npm](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm) for the frontend. @@ -117,7 +118,19 @@ To start the frontend, navigate to `danswer/web` and run: npm run dev ``` -The first time running Danswer, you will also need to run the DB migrations for Postgres. +Next, start the model server which runs the local NLP models. +Navigate to `danswer/backend` and run: +```bash +uvicorn model_server.main:app --reload --port 9000 +``` +_For Windows (for compatibility with both PowerShell and Command Prompt):_ +```bash +powershell -Command " + uvicorn model_server.main:app --reload --port 9000 +" +``` + +The first time running Danswer, you will need to run the DB migrations for Postgres. After the first time, this is no longer required unless the DB models change. 
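+The Alembic migration scripts themselves live under `backend/alembic`.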
Navigate to `danswer/backend` and with the venv active, run: diff --git a/backend/Dockerfile b/backend/Dockerfile index d18bd3ecdb..a0b50c53cb 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -40,7 +40,7 @@ RUN apt-get remove -y --allow-remove-essential perl-base xserver-common xvfb cma # Set up application files WORKDIR /app COPY ./danswer /app/danswer -COPY ./shared_models /app/shared_models +COPY ./shared_configs /app/shared_configs COPY ./alembic /app/alembic COPY ./alembic.ini /app/alembic.ini COPY supervisord.conf /usr/etc/supervisord.conf diff --git a/backend/Dockerfile.model_server b/backend/Dockerfile.model_server index 624bdd37fc..0eb455c513 100644 --- a/backend/Dockerfile.model_server +++ b/backend/Dockerfile.model_server @@ -25,11 +25,8 @@ COPY ./danswer/utils/telemetry.py /app/danswer/utils/telemetry.py # Place to fetch version information COPY ./danswer/__init__.py /app/danswer/__init__.py -# Shared implementations for running NLP models locally -COPY ./danswer/search/search_nlp_models.py /app/danswer/search/search_nlp_models.py - # Request/Response models -COPY ./shared_models /app/shared_models +COPY ./shared_configs /app/shared_configs # Model Server main code COPY ./model_server /app/model_server diff --git a/backend/danswer/background/indexing/job_client.py b/backend/danswer/background/indexing/job_client.py index d37690627f..6b1344b59f 100644 --- a/backend/danswer/background/indexing/job_client.py +++ b/backend/danswer/background/indexing/job_client.py @@ -6,18 +6,15 @@ https://github.com/celery/celery/issues/7007#issuecomment-1740139367""" from collections.abc import Callable from dataclasses import dataclass +from multiprocessing import Process from typing import Any from typing import Literal from typing import Optional -from typing import TYPE_CHECKING from danswer.utils.logger import setup_logger logger = setup_logger() -if TYPE_CHECKING: - from torch.multiprocessing import Process - JobStatusType = ( Literal["error"] | Literal["finished"] @@ -89,8 +86,6 @@ def _cleanup_completed_jobs(self) -> None: def submit(self, func: Callable, *args: Any, pure: bool = True) -> SimpleJob | None: """NOTE: `pure` arg is needed so this can be a drop in replacement for Dask""" - from torch.multiprocessing import Process - self._cleanup_completed_jobs() if len(self.jobs) >= self.n_workers: logger.debug("No available workers to run job") diff --git a/backend/danswer/background/indexing/run_indexing.py b/backend/danswer/background/indexing/run_indexing.py index 6241af6f56..9e8ee6b7fe 100644 --- a/backend/danswer/background/indexing/run_indexing.py +++ b/backend/danswer/background/indexing/run_indexing.py @@ -330,20 +330,15 @@ def _run_indexing( ) -def run_indexing_entrypoint(index_attempt_id: int, num_threads: int) -> None: +def run_indexing_entrypoint(index_attempt_id: int) -> None: """Entrypoint for indexing run when using dask distributed. 
Wraps the actual logic in a `try` block so that we can catch any exceptions and mark the attempt as failed.""" - import torch - try: # set the indexing attempt ID so that all log messages from this process # will have it added as a prefix IndexAttemptSingleton.set_index_attempt_id(index_attempt_id) - logger.info(f"Setting task to use {num_threads} threads") - torch.set_num_threads(num_threads) - with Session(get_sqlalchemy_engine()) as db_session: attempt = get_index_attempt( db_session=db_session, index_attempt_id=index_attempt_id diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py index a7a20fca30..8d8de8da4c 100755 --- a/backend/danswer/background/update.py +++ b/backend/danswer/background/update.py @@ -15,9 +15,10 @@ from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP +from danswer.configs.app_configs import INDEXING_MODEL_SERVER_HOST from danswer.configs.app_configs import LOG_LEVEL +from danswer.configs.app_configs import MODEL_SERVER_PORT from danswer.configs.app_configs import NUM_INDEXING_WORKERS -from danswer.configs.model_configs import MIN_THREADS_ML_MODELS from danswer.db.connector import fetch_connectors from danswer.db.connector_credential_pair import get_connector_credential_pairs from danswer.db.connector_credential_pair import mark_all_in_progress_cc_pairs_failed @@ -43,6 +44,7 @@ from danswer.db.models import IndexAttempt from danswer.db.models import IndexingStatus from danswer.db.models import IndexModelStatus +from danswer.search.search_nlp_models import warm_up_encoders from danswer.utils.logger import setup_logger logger = setup_logger() @@ -56,18 +58,6 @@ ) -"""Util funcs""" - - -def _get_num_threads() -> int: - """Get # of "threads" to use for ML models in an indexing job. By default uses - the torch implementation, which returns the # of physical cores on the machine. 
-    """
-    import torch
-
-    return max(MIN_THREADS_ML_MODELS, torch.get_num_threads())
-
-
 def _should_create_new_indexing(
     connector: Connector,
     last_index: IndexAttempt | None,
@@ -346,12 +336,10 @@ def kickoff_indexing_jobs(
 
         if use_secondary_index:
             run = secondary_client.submit(
-                run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False
+                run_indexing_entrypoint, attempt.id, pure=False
             )
         else:
-            run = client.submit(
-                run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False
-            )
+            run = client.submit(run_indexing_entrypoint, attempt.id, pure=False)
 
         if run:
             secondary_str = "(secondary index) " if use_secondary_index else ""
@@ -409,6 +397,20 @@ def check_index_swap(db_session: Session) -> None:
 
 
 def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
+    engine = get_sqlalchemy_engine()
+    with Session(engine) as db_session:
+        db_embedding_model = get_current_db_embedding_model(db_session)
+
+        # So that first-time users aren't surprised by the really slow speed of the
+        # first batch of documents indexed
+        logger.info("Running a first inference to warm up embedding model")
+        warm_up_encoders(
+            model_name=db_embedding_model.model_name,
+            normalize=db_embedding_model.normalize,
+            model_server_host=INDEXING_MODEL_SERVER_HOST,
+            model_server_port=MODEL_SERVER_PORT,
+        )
+
     client_primary: Client | SimpleJobClient
     client_secondary: Client | SimpleJobClient
     if DASK_JOB_CLIENT_ENABLED:
@@ -435,7 +437,6 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
         client_secondary = SimpleJobClient(n_workers=num_workers)
 
     existing_jobs: dict[int, Future | SimpleJob] = {}
-    engine = get_sqlalchemy_engine()
 
     with Session(engine) as db_session:
         # Previous version did not always clean up cc-pairs well leaving some connectors undeleteable
@@ -472,14 +473,6 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
 
 
 def update__main() -> None:
-    # needed for CUDA to work with multiprocessing
-    # NOTE: needs to be done on application startup
-    # before any other torch code has been run
-    import torch
-
-    if not DASK_JOB_CLIENT_ENABLED:
-        torch.multiprocessing.set_start_method("spawn")
-
     logger.info("Starting Indexing Loop")
     update_loop()
 
diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py
index 7556f56460..c6c697e89c 100644
--- a/backend/danswer/configs/app_configs.py
+++ b/backend/danswer/configs/app_configs.py
@@ -207,15 +207,11 @@
 #####
 # Model Server Configs
 #####
-# If MODEL_SERVER_HOST is set, the NLP models required for Danswer are offloaded to the server via
-# requests. Be sure to include the scheme in the MODEL_SERVER_HOST value.
-MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or None
+MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or "localhost"
 MODEL_SERVER_ALLOWED_HOST = os.environ.get("MODEL_SERVER_HOST") or "0.0.0.0"
 MODEL_SERVER_PORT = int(os.environ.get("MODEL_SERVER_PORT") or "9000")
-
-# specify this env variable directly to have a different model server for the background
-# indexing job vs the api server so that background indexing does not effect query-time
-# performance
+# The model server used for indexing should be a separate instance so that heavy
+# indexing load does not introduce delay for inference
 INDEXING_MODEL_SERVER_HOST = (
     os.environ.get("INDEXING_MODEL_SERVER_HOST") or MODEL_SERVER_HOST
 )
diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py
index ce79693731..e0d774c82b 100644
--- a/backend/danswer/configs/model_configs.py
+++ b/backend/danswer/configs/model_configs.py
@@ -37,33 +37,13 @@
 ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "passage: ")
 # Purely an optimization, memory limitation consideration
 BATCH_SIZE_ENCODE_CHUNKS = 8
-# This controls the minimum number of pytorch "threads" to allocate to the embedding
-# model. If torch finds more threads on its own, this value is not used.
-MIN_THREADS_ML_MODELS = int(os.environ.get("MIN_THREADS_ML_MODELS") or 1)
-
-# Cross Encoder Settings
-ENABLE_RERANKING_ASYNC_FLOW = (
-    os.environ.get("ENABLE_RERANKING_ASYNC_FLOW", "").lower() == "true"
-)
-ENABLE_RERANKING_REAL_TIME_FLOW = (
-    os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true"
-)
-# Only using one for now
-CROSS_ENCODER_MODEL_ENSEMBLE = ["mixedbread-ai/mxbai-rerank-xsmall-v1"]
-# For score normalizing purposes, only way is to know the expected ranges
+# For score display purposes, the only way is to know the expected ranges
 CROSS_ENCODER_RANGE_MAX = 12
 CROSS_ENCODER_RANGE_MIN = -12
-CROSS_EMBED_CONTEXT_SIZE = 512
 
 # Unused currently, can't be used with the current default encoder model due to its output range
 SEARCH_DISTANCE_CUTOFF = 0
 
-# Intent model max context size
-QUERY_MAX_CONTEXT_SIZE = 256
-
-# Danswer custom Deep Learning Models
-INTENT_MODEL_VERSION = "danswer/intent-model"
-
 
 #####
 # Generative AI Model Configs
diff --git a/backend/danswer/danswerbot/slack/handlers/handle_message.py b/backend/danswer/danswerbot/slack/handlers/handle_message.py
index 33d64d9eff..0886c0c175 100644
--- a/backend/danswer/danswerbot/slack/handlers/handle_message.py
+++ b/backend/danswer/danswerbot/slack/handlers/handle_message.py
@@ -22,7 +22,6 @@
 from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
 from danswer.configs.danswerbot_configs import DISABLE_DANSWER_BOT_FILTER_DETECT
 from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
-from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW
 from danswer.danswerbot.slack.blocks import build_documents_blocks
 from danswer.danswerbot.slack.blocks import build_follow_up_block
 from danswer.danswerbot.slack.blocks import build_qa_response_blocks
@@ -52,6 +51,7 @@
 from danswer.search.models import OptionalSearchSetting
 from danswer.search.models import RetrievalDetails
 from danswer.utils.logger import setup_logger
+from shared_configs.nlp_model_configs import ENABLE_RERANKING_ASYNC_FLOW
 
 logger_base = setup_logger()
 
diff --git a/backend/danswer/danswerbot/slack/listener.py b/backend/danswer/danswerbot/slack/listener.py
index fc7055577c..08aa584111 100644
--- a/backend/danswer/danswerbot/slack/listener.py
+++ b/backend/danswer/danswerbot/slack/listener.py
@@ -10,10 +10,11 @@
 from slack_sdk.socket_mode.response import SocketModeResponse
 from sqlalchemy.orm import Session
 
+from danswer.configs.app_configs import MODEL_SERVER_HOST
+from danswer.configs.app_configs import MODEL_SERVER_PORT
 from danswer.configs.constants import MessageType
 from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
 from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
-from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW
 from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
 from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
 from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
@@ -43,7 +44,7 @@
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.dynamic_configs.interface import ConfigNotFoundError
 from danswer.one_shot_answer.models import ThreadMessage
-from danswer.search.search_nlp_models import warm_up_models
+from danswer.search.search_nlp_models import warm_up_encoders
 from danswer.server.manage.models import SlackBotTokens
 from danswer.utils.logger import setup_logger
 
@@ -390,10 +391,11 @@ def _initialize_socket_client(socket_client: SocketModeClient) -> None:
     with Session(get_sqlalchemy_engine()) as db_session:
         embedding_model = get_current_db_embedding_model(db_session)
 
-        warm_up_models(
+        warm_up_encoders(
             model_name=embedding_model.model_name,
             normalize=embedding_model.normalize,
-            skip_cross_encoders=not ENABLE_RERANKING_ASYNC_FLOW,
+            model_server_host=MODEL_SERVER_HOST,
+            model_server_port=MODEL_SERVER_PORT,
         )
 
     slack_bot_tokens = latest_slack_bot_tokens
diff --git a/backend/danswer/indexing/embedder.py b/backend/danswer/indexing/embedder.py
index 3be10f5b41..2017265777 100644
--- a/backend/danswer/indexing/embedder.py
+++ b/backend/danswer/indexing/embedder.py
@@ -16,8 +16,9 @@
 from danswer.indexing.models import ChunkEmbedding
 from danswer.indexing.models import DocAwareChunk
 from danswer.indexing.models import IndexChunk
+from danswer.search.enums import EmbedTextType
 from danswer.search.search_nlp_models import EmbeddingModel
-from danswer.search.search_nlp_models import EmbedTextType
+from danswer.utils.batching import batch_list
 from danswer.utils.logger import setup_logger
 
 
@@ -73,6 +74,8 @@
         title_embed_dict: dict[str, list[float]] = {}
         embedded_chunks: list[IndexChunk] = []
 
+        # Create Mini Chunks for more precise matching of details
+        # Off by default with unedited settings
         chunk_texts = []
         chunk_mini_chunks_count = {}
         for chunk_ind, chunk in enumerate(chunks):
@@ -85,23 +88,41 @@
             chunk_texts.extend(mini_chunk_texts)
             chunk_mini_chunks_count[chunk_ind] = 1 + len(mini_chunk_texts)
 
-        text_batches = [
-            chunk_texts[i : i + batch_size]
-            for i in range(0, len(chunk_texts), batch_size)
-        ]
+        # Batching for embedding
+        text_batches = batch_list(chunk_texts, batch_size)
 
         embeddings: list[list[float]] = []
         len_text_batches = len(text_batches)
         for idx, text_batch in enumerate(text_batches, start=1):
-            logger.debug(f"Embedding text batch {idx} of {len_text_batches}")
-            # Normalize embeddings is only configured via model_configs.py, be sure to use right value for the set loss
+            logger.debug(f"Embedding Content Texts batch {idx} of {len_text_batches}")
+            # Whether to normalize embeddings is configured via model_configs.py; be sure
+            # to use the right value for the loss the model was trained with
             embeddings.extend(
                 self.embedding_model.encode(text_batch, 
text_type=EmbedTextType.PASSAGE) ) - # Replace line above with the line below for easy debugging of indexing flow, skipping the actual model + # Replace line above with the line below for easy debugging of indexing flow + # skipping the actual model # embeddings.extend([[0.0] * 384 for _ in range(len(text_batch))]) + chunk_titles = { + chunk.source_document.get_title_for_document_index() for chunk in chunks + } + chunk_titles.discard(None) + + # Embed Titles in batches + title_batches = batch_list(list(chunk_titles), batch_size) + len_title_batches = len(title_batches) + for ind_batch, title_batch in enumerate(title_batches, start=1): + logger.debug(f"Embedding Titles batch {ind_batch} of {len_title_batches}") + title_embeddings = self.embedding_model.encode( + title_batch, text_type=EmbedTextType.PASSAGE + ) + title_embed_dict.update( + {title: vector for title, vector in zip(title_batch, title_embeddings)} + ) + + # Mapping embeddings to chunks embedding_ind_start = 0 for chunk_ind, chunk in enumerate(chunks): num_embeddings = chunk_mini_chunks_count[chunk_ind] @@ -114,9 +135,12 @@ def embed_chunks( title_embedding = None if title: if title in title_embed_dict: - # Using cached value for speedup + # Using cached value to avoid recalculating for every chunk title_embedding = title_embed_dict[title] else: + logger.error( + "Title had to be embedded separately, this should not happen!" + ) title_embedding = self.embedding_model.encode( [title], text_type=EmbedTextType.PASSAGE )[0] diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 90abab7372..9ce32fe01b 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -1,10 +1,10 @@ +import time from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from typing import Any from typing import cast import nltk # type:ignore -import torch # Import here is fine, API server needs torch anyway and nothing imports main.py import uvicorn from fastapi import APIRouter from fastapi import FastAPI @@ -36,7 +36,6 @@ from danswer.configs.app_configs import WEB_DOMAIN from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION from danswer.configs.constants import AuthType -from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW from danswer.configs.model_configs import GEN_AI_API_ENDPOINT from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER from danswer.db.chat import delete_old_default_personas @@ -54,7 +53,7 @@ from danswer.dynamic_configs.port_configs import port_filesystem_to_postgres from danswer.llm.factory import get_default_llm from danswer.llm.utils import get_default_llm_version -from danswer.search.search_nlp_models import warm_up_models +from danswer.search.search_nlp_models import warm_up_encoders from danswer.server.danswer_api.ingestion import get_danswer_api_key from danswer.server.danswer_api.ingestion import router as danswer_api_router from danswer.server.documents.cc_pair import router as cc_pair_router @@ -82,6 +81,7 @@ from danswer.utils.telemetry import optional_telemetry from danswer.utils.telemetry import RecordType from danswer.utils.variable_functionality import fetch_versioned_implementation +from shared_configs.nlp_model_configs import ENABLE_RERANKING_REAL_TIME_FLOW logger = setup_logger() @@ -204,24 +204,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: if ENABLE_RERANKING_REAL_TIME_FLOW: logger.info("Reranking step of search flow is enabled.") - if MODEL_SERVER_HOST: - logger.info( - f"Using Model Server: 
http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}" - ) - else: - logger.info("Warming up local NLP models.") - warm_up_models( - model_name=db_embedding_model.model_name, - normalize=db_embedding_model.normalize, - skip_cross_encoders=not ENABLE_RERANKING_REAL_TIME_FLOW, - ) - - if torch.cuda.is_available(): - logger.info("GPU is available") - else: - logger.info("GPU is not available") - logger.info(f"Torch Threads: {torch.get_num_threads()}") - logger.info("Verifying query preprocessing (NLTK) data is downloaded") nltk.download("stopwords", quiet=True) nltk.download("wordnet", quiet=True) @@ -237,19 +219,34 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: load_chat_yamls() logger.info("Verifying Document Index(s) is/are available.") - document_index = get_default_document_index( primary_index_name=db_embedding_model.index_name, secondary_index_name=secondary_db_embedding_model.index_name if secondary_db_embedding_model else None, ) - document_index.ensure_indices_exist( - index_embedding_dim=db_embedding_model.model_dim, - secondary_index_embedding_dim=secondary_db_embedding_model.model_dim - if secondary_db_embedding_model - else None, - ) + # Vespa startup is a bit slow, so give it a few seconds + wait_time = 5 + for attempt in range(5): + try: + document_index.ensure_indices_exist( + index_embedding_dim=db_embedding_model.model_dim, + secondary_index_embedding_dim=secondary_db_embedding_model.model_dim + if secondary_db_embedding_model + else None, + ) + break + except Exception: + logger.info(f"Waiting on Vespa, retrying in {wait_time} seconds...") + time.sleep(wait_time) + + logger.info(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}") + warm_up_encoders( + model_name=db_embedding_model.model_name, + normalize=db_embedding_model.normalize, + model_server_host=MODEL_SERVER_HOST, + model_server_port=MODEL_SERVER_PORT, + ) optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__}) diff --git a/backend/danswer/search/enums.py b/backend/danswer/search/enums.py index 9ba44ada2c..3990833552 100644 --- a/backend/danswer/search/enums.py +++ b/backend/danswer/search/enums.py @@ -28,3 +28,8 @@ class SearchType(str, Enum): class QueryFlow(str, Enum): SEARCH = "search" QUESTION_ANSWER = "question-answer" + + +class EmbedTextType(str, Enum): + QUERY = "query" + PASSAGE = "passage" diff --git a/backend/danswer/search/models.py b/backend/danswer/search/models.py index c59dbf4dab..9d3eb39b0c 100644 --- a/backend/danswer/search/models.py +++ b/backend/danswer/search/models.py @@ -8,10 +8,10 @@ from danswer.configs.chat_configs import NUM_RERANKED_RESULTS from danswer.configs.chat_configs import NUM_RETURNED_HITS from danswer.configs.constants import DocumentSource -from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW from danswer.db.models import Persona from danswer.search.enums import OptionalSearchSetting from danswer.search.enums import SearchType +from shared_configs.nlp_model_configs import ENABLE_RERANKING_REAL_TIME_FLOW MAX_METRICS_CONTENT = ( diff --git a/backend/danswer/search/preprocessing/preprocessing.py b/backend/danswer/search/preprocessing/preprocessing.py index 4fb2665a83..7da6db4ceb 100644 --- a/backend/danswer/search/preprocessing/preprocessing.py +++ b/backend/danswer/search/preprocessing/preprocessing.py @@ -5,7 +5,6 @@ from danswer.configs.chat_configs import DISABLE_LLM_FILTER_EXTRACTION from danswer.configs.chat_configs import FAVOR_RECENT_DECAY_MULTIPLIER from danswer.configs.chat_configs import NUM_RETURNED_HITS 
-from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW from danswer.db.models import User from danswer.search.enums import QueryFlow from danswer.search.enums import RecencyBiasSetting @@ -22,6 +21,7 @@ from danswer.utils.threadpool_concurrency import FunctionCall from danswer.utils.threadpool_concurrency import run_functions_in_parallel from danswer.utils.timing import log_function_time +from shared_configs.nlp_model_configs import ENABLE_RERANKING_REAL_TIME_FLOW logger = setup_logger() diff --git a/backend/danswer/search/retrieval/search_runner.py b/backend/danswer/search/retrieval/search_runner.py index 41aa3a3c7e..bb17253921 100644 --- a/backend/danswer/search/retrieval/search_runner.py +++ b/backend/danswer/search/retrieval/search_runner.py @@ -14,6 +14,7 @@ from danswer.db.embedding_model import get_current_db_embedding_model from danswer.document_index.interfaces import DocumentIndex from danswer.indexing.models import InferenceChunk +from danswer.search.enums import EmbedTextType from danswer.search.models import ChunkMetric from danswer.search.models import IndexFilters from danswer.search.models import MAX_METRICS_CONTENT @@ -21,7 +22,6 @@ from danswer.search.models import SearchQuery from danswer.search.models import SearchType from danswer.search.search_nlp_models import EmbeddingModel -from danswer.search.search_nlp_models import EmbedTextType from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion from danswer.utils.logger import setup_logger from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel diff --git a/backend/danswer/search/search_nlp_models.py b/backend/danswer/search/search_nlp_models.py index 50decb92d2..95bd4d0f23 100644 --- a/backend/danswer/search/search_nlp_models.py +++ b/backend/danswer/search/search_nlp_models.py @@ -1,56 +1,38 @@ import gc import os -from enum import Enum +import time from typing import Optional from typing import TYPE_CHECKING -import numpy as np import requests from transformers import logging as transformer_logging # type:ignore from danswer.configs.app_configs import MODEL_SERVER_HOST from danswer.configs.app_configs import MODEL_SERVER_PORT -from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE -from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL -from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW -from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW -from danswer.configs.model_configs import INTENT_MODEL_VERSION -from danswer.configs.model_configs import QUERY_MAX_CONTEXT_SIZE +from danswer.search.enums import EmbedTextType from danswer.utils.logger import setup_logger -from shared_models.model_server_models import EmbedRequest -from shared_models.model_server_models import EmbedResponse -from shared_models.model_server_models import IntentRequest -from shared_models.model_server_models import IntentResponse -from shared_models.model_server_models import RerankRequest -from shared_models.model_server_models import RerankResponse +from shared_configs.model_server_models import EmbedRequest +from shared_configs.model_server_models import EmbedResponse +from shared_configs.model_server_models import IntentRequest +from shared_configs.model_server_models import IntentResponse +from shared_configs.model_server_models import RerankRequest +from 
shared_configs.model_server_models import RerankResponse +transformer_logging.set_verbosity_error() os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1" logger = setup_logger() -transformer_logging.set_verbosity_error() if TYPE_CHECKING: - from sentence_transformers import CrossEncoder # type: ignore - from sentence_transformers import SentenceTransformer # type: ignore from transformers import AutoTokenizer # type: ignore - from transformers import TFDistilBertForSequenceClassification # type: ignore _TOKENIZER: tuple[Optional["AutoTokenizer"], str | None] = (None, None) -_EMBED_MODEL: tuple[Optional["SentenceTransformer"], str | None] = (None, None) -_RERANK_MODELS: Optional[list["CrossEncoder"]] = None -_INTENT_TOKENIZER: Optional["AutoTokenizer"] = None -_INTENT_MODEL: Optional["TFDistilBertForSequenceClassification"] = None - - -class EmbedTextType(str, Enum): - QUERY = "query" - PASSAGE = "passage" def clean_model_name(model_str: str) -> str: @@ -84,89 +66,10 @@ def get_default_tokenizer(model_name: str | None = None) -> "AutoTokenizer": return _TOKENIZER[0] -def get_local_embedding_model( - model_name: str, - max_context_length: int = DOC_EMBEDDING_CONTEXT_SIZE, -) -> "SentenceTransformer": - # NOTE: doing a local import here to avoid reduce memory usage caused by - # processes importing this file despite not using any of this - from sentence_transformers import SentenceTransformer # type: ignore - - global _EMBED_MODEL - if ( - _EMBED_MODEL[0] is None - or max_context_length != _EMBED_MODEL[0].max_seq_length - or model_name != _EMBED_MODEL[1] - ): - if _EMBED_MODEL[0] is not None: - del _EMBED_MODEL - gc.collect() - - logger.info(f"Loading {model_name}") - _EMBED_MODEL = (SentenceTransformer(model_name), model_name) - _EMBED_MODEL[0].max_seq_length = max_context_length - return _EMBED_MODEL[0] - - -def get_local_reranking_model_ensemble( - model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE, - max_context_length: int = CROSS_EMBED_CONTEXT_SIZE, -) -> list["CrossEncoder"]: - # NOTE: doing a local import here to avoid reduce memory usage caused by - # processes importing this file despite not using any of this - from sentence_transformers import CrossEncoder - - global _RERANK_MODELS - if _RERANK_MODELS is None or max_context_length != _RERANK_MODELS[0].max_length: - del _RERANK_MODELS - gc.collect() - - _RERANK_MODELS = [] - for model_name in model_names: - logger.info(f"Loading {model_name}") - model = CrossEncoder(model_name) - model.max_length = max_context_length - _RERANK_MODELS.append(model) - return _RERANK_MODELS - - -def get_intent_model_tokenizer( - model_name: str = INTENT_MODEL_VERSION, -) -> "AutoTokenizer": - # NOTE: doing a local import here to avoid reduce memory usage caused by - # processes importing this file despite not using any of this - from transformers import AutoTokenizer # type: ignore - - global _INTENT_TOKENIZER - if _INTENT_TOKENIZER is None: - _INTENT_TOKENIZER = AutoTokenizer.from_pretrained(model_name) - return _INTENT_TOKENIZER - - -def get_local_intent_model( - model_name: str = INTENT_MODEL_VERSION, - max_context_length: int = QUERY_MAX_CONTEXT_SIZE, -) -> "TFDistilBertForSequenceClassification": - # NOTE: doing a local import here to avoid reduce memory usage caused by - # processes importing this file despite not using any of this - from transformers import TFDistilBertForSequenceClassification # type: ignore - - global _INTENT_MODEL - if _INTENT_MODEL is None or max_context_length != 
_INTENT_MODEL.max_seq_length: - _INTENT_MODEL = TFDistilBertForSequenceClassification.from_pretrained( - model_name - ) - _INTENT_MODEL.max_seq_length = max_context_length - return _INTENT_MODEL - - def build_model_server_url( - model_server_host: str | None, - model_server_port: int | None, -) -> str | None: - if not model_server_host or model_server_port is None: - return None - + model_server_host: str, + model_server_port: int, +) -> str: model_server_url = f"{model_server_host}:{model_server_port}" # use protocol if provided @@ -184,8 +87,8 @@ def __init__( query_prefix: str | None, passage_prefix: str | None, normalize: bool, - server_host: str | None, # Changes depending on indexing or inference - server_port: int | None, + server_host: str, # Changes depending on indexing or inference + server_port: int, # The following are globals are currently not configurable max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE, ) -> None: @@ -196,17 +99,7 @@ def __init__( self.normalize = normalize model_server_url = build_model_server_url(server_host, server_port) - self.embed_server_endpoint = ( - f"{model_server_url}/encoder/bi-encoder-embed" if model_server_url else None - ) - - def load_model(self) -> Optional["SentenceTransformer"]: - if self.embed_server_endpoint: - return None - - return get_local_embedding_model( - model_name=self.model_name, max_context_length=self.max_seq_length - ) + self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed" def encode(self, texts: list[str], text_type: EmbedTextType) -> list[list[float]]: if text_type == EmbedTextType.QUERY and self.query_prefix: @@ -216,166 +109,67 @@ def encode(self, texts: list[str], text_type: EmbedTextType) -> list[list[float] else: prefixed_texts = texts - if self.embed_server_endpoint: - embed_request = EmbedRequest( - texts=prefixed_texts, - model_name=self.model_name, - normalize_embeddings=self.normalize, - ) - - try: - response = requests.post( - self.embed_server_endpoint, json=embed_request.dict() - ) - response.raise_for_status() - - return EmbedResponse(**response.json()).embeddings - except requests.RequestException as e: - logger.exception(f"Failed to get Embedding: {e}") - raise - - local_model = self.load_model() + embed_request = EmbedRequest( + texts=prefixed_texts, + model_name=self.model_name, + max_context_length=self.max_seq_length, + normalize_embeddings=self.normalize, + ) - if local_model is None: - raise RuntimeError("Failed to load local Embedding Model") + response = requests.post(self.embed_server_endpoint, json=embed_request.dict()) + response.raise_for_status() - return local_model.encode( - prefixed_texts, normalize_embeddings=self.normalize - ).tolist() + return EmbedResponse(**response.json()).embeddings class CrossEncoderEnsembleModel: def __init__( self, - model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE, - max_seq_length: int = CROSS_EMBED_CONTEXT_SIZE, - model_server_host: str | None = MODEL_SERVER_HOST, + model_server_host: str = MODEL_SERVER_HOST, model_server_port: int = MODEL_SERVER_PORT, ) -> None: - self.model_names = model_names - self.max_seq_length = max_seq_length - model_server_url = build_model_server_url(model_server_host, model_server_port) - self.rerank_server_endpoint = ( - model_server_url + "/encoder/cross-encoder-scores" - if model_server_url - else None - ) - - def load_model(self) -> list["CrossEncoder"] | None: - if ( - ENABLE_RERANKING_REAL_TIME_FLOW is False - and ENABLE_RERANKING_ASYNC_FLOW is False - ): - logger.warning( - "Running rerankers but 
they are globally disabled." - "Was this specified explicitly via an API?" - ) - - if self.rerank_server_endpoint: - return None - - return get_local_reranking_model_ensemble( - model_names=self.model_names, max_context_length=self.max_seq_length - ) + self.rerank_server_endpoint = model_server_url + "/encoder/cross-encoder-scores" def predict(self, query: str, passages: list[str]) -> list[list[float]]: - if self.rerank_server_endpoint: - rerank_request = RerankRequest(query=query, documents=passages) - - try: - response = requests.post( - self.rerank_server_endpoint, json=rerank_request.dict() - ) - response.raise_for_status() + rerank_request = RerankRequest(query=query, documents=passages) - return RerankResponse(**response.json()).scores - except requests.RequestException as e: - logger.exception(f"Failed to get Reranking Scores: {e}") - raise - - local_models = self.load_model() - - if local_models is None: - raise RuntimeError("Failed to load local Reranking Model Ensemble") - - scores = [ - cross_encoder.predict([(query, passage) for passage in passages]).tolist() # type: ignore - for cross_encoder in local_models - ] + response = requests.post( + self.rerank_server_endpoint, json=rerank_request.dict() + ) + response.raise_for_status() - return scores + return RerankResponse(**response.json()).scores class IntentModel: def __init__( self, - model_name: str = INTENT_MODEL_VERSION, - max_seq_length: int = QUERY_MAX_CONTEXT_SIZE, - model_server_host: str | None = MODEL_SERVER_HOST, + model_server_host: str = MODEL_SERVER_HOST, model_server_port: int = MODEL_SERVER_PORT, ) -> None: - self.model_name = model_name - self.max_seq_length = max_seq_length - model_server_url = build_model_server_url(model_server_host, model_server_port) - self.intent_server_endpoint = ( - model_server_url + "/custom/intent-model" if model_server_url else None - ) - - def load_model(self) -> Optional["SentenceTransformer"]: - if self.intent_server_endpoint: - return None - - return get_local_intent_model( - model_name=self.model_name, max_context_length=self.max_seq_length - ) + self.intent_server_endpoint = model_server_url + "/custom/intent-model" def predict( self, query: str, ) -> list[float]: - # NOTE: doing a local import here to avoid reduce memory usage caused by - # processes importing this file despite not using any of this - import tensorflow as tf # type: ignore - - if self.intent_server_endpoint: - intent_request = IntentRequest(query=query) - - try: - response = requests.post( - self.intent_server_endpoint, json=intent_request.dict() - ) - response.raise_for_status() + intent_request = IntentRequest(query=query) - return IntentResponse(**response.json()).class_probs - except requests.RequestException as e: - logger.exception(f"Failed to get Embedding: {e}") - raise - - tokenizer = get_intent_model_tokenizer() - local_model = self.load_model() - - if local_model is None: - raise RuntimeError("Failed to load local Intent Model") - - intent_model = get_local_intent_model() - model_input = tokenizer( - query, return_tensors="tf", truncation=True, padding=True + response = requests.post( + self.intent_server_endpoint, json=intent_request.dict() ) + response.raise_for_status() - predictions = intent_model(model_input)[0] - probabilities = tf.nn.softmax(predictions, axis=-1) - class_percentages = np.round(probabilities.numpy() * 100, 2) + return IntentResponse(**response.json()).class_probs - return list(class_percentages.tolist()[0]) - -def warm_up_models( +def warm_up_encoders( model_name: str, 
normalize: bool,
-    skip_cross_encoders: bool = True,
-    indexer_only: bool = False,
+    model_server_host: str = MODEL_SERVER_HOST,
+    model_server_port: int = MODEL_SERVER_PORT,
 ) -> None:
     warm_up_str = (
         "Danswer is amazing! Check out our easy deployment guide at "
@@ -387,23 +181,23 @@
     embed_model = EmbeddingModel(
         model_name=model_name,
         normalize=normalize,
-        # These don't matter, if it's a remote model, this function shouldn't be called
+        # Not a big deal if prefix is incorrect
         query_prefix=None,
         passage_prefix=None,
-        server_host=None,
-        server_port=None,
+        server_host=model_server_host,
+        server_port=model_server_port,
     )
 
-    embed_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
-
-    if indexer_only:
-        return
-
-    if not skip_cross_encoders:
-        CrossEncoderEnsembleModel().predict(query=warm_up_str, passages=[warm_up_str])
-
-    intent_tokenizer = get_intent_model_tokenizer()
-    inputs = intent_tokenizer(
-        warm_up_str, return_tensors="tf", truncation=True, padding=True
-    )
-    get_local_intent_model()(inputs)
+    # The first time, the model downloads can make this take even longer; just in
+    # case, keep retrying the request until the model server is ready
+    wait_time = 5
+    for attempt in range(20):
+        try:
+            embed_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
+            return
+        except Exception:
+            logger.info(
+                f"Failed to run test embedding, retrying in {wait_time} seconds..."
+            )
+            time.sleep(wait_time)
+    raise Exception("Failed to run test embedding.")
diff --git a/backend/danswer/utils/batching.py b/backend/danswer/utils/batching.py
index 0200f72250..2ea436e117 100644
--- a/backend/danswer/utils/batching.py
+++ b/backend/danswer/utils/batching.py
@@ -21,3 +21,10 @@
     if pre_batch_yield:
         pre_batch_yield(batch)
     yield batch
+
+
+def batch_list(
+    lst: list[T],
+    batch_size: int,
+) -> list[list[T]]:
+    return [lst[i : i + batch_size] for i in range(0, len(lst), batch_size)]
diff --git a/backend/model_server/constants.py b/backend/model_server/constants.py
new file mode 100644
index 0000000000..bc842f5461
--- /dev/null
+++ b/backend/model_server/constants.py
@@ -0,0 +1 @@
+MODEL_WARM_UP_STRING = "hi " * 512
diff --git a/backend/model_server/custom_models.py b/backend/model_server/custom_models.py
index 9faea17ba3..9b8066e96c 100644
--- a/backend/model_server/custom_models.py
+++ b/backend/model_server/custom_models.py
@@ -1,19 +1,58 @@
+from typing import Optional
+
 import numpy as np
+import tensorflow as tf  # type: ignore
 from fastapi import APIRouter
+from transformers import AutoTokenizer  # type: ignore
+from transformers import TFDistilBertForSequenceClassification
+
+from model_server.constants import MODEL_WARM_UP_STRING
+from model_server.utils import simple_log_function_time
+from shared_configs.model_server_models import IntentRequest
+from shared_configs.model_server_models import IntentResponse
+from shared_configs.nlp_model_configs import INDEXING_ONLY
+from shared_configs.nlp_model_configs import INTENT_MODEL_CONTEXT_SIZE
+from shared_configs.nlp_model_configs import INTENT_MODEL_VERSION
 
-from danswer.search.search_nlp_models import get_intent_model_tokenizer
-from danswer.search.search_nlp_models import get_local_intent_model
-from danswer.utils.timing import log_function_time
-from shared_models.model_server_models import IntentRequest
-from shared_models.model_server_models import IntentResponse
 
 router = APIRouter(prefix="/custom")
 
+_INTENT_TOKENIZER: Optional[AutoTokenizer] = None
+_INTENT_MODEL: Optional[TFDistilBertForSequenceClassification] = None
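+# NOTE: the two globals above are lazy singletons; they are populated on first
+# use by the getter functions below and then reused across requests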
-@log_function_time(print_only=True) -def classify_intent(query: str) -> list[float]: - import tensorflow as tf # type:ignore +def get_intent_model_tokenizer( + model_name: str = INTENT_MODEL_VERSION, +) -> "AutoTokenizer": + global _INTENT_TOKENIZER + if _INTENT_TOKENIZER is None: + _INTENT_TOKENIZER = AutoTokenizer.from_pretrained(model_name) + return _INTENT_TOKENIZER + + +def get_local_intent_model( + model_name: str = INTENT_MODEL_VERSION, + max_context_length: int = INTENT_MODEL_CONTEXT_SIZE, +) -> TFDistilBertForSequenceClassification: + global _INTENT_MODEL + if _INTENT_MODEL is None or max_context_length != _INTENT_MODEL.max_seq_length: + _INTENT_MODEL = TFDistilBertForSequenceClassification.from_pretrained( + model_name + ) + _INTENT_MODEL.max_seq_length = max_context_length + return _INTENT_MODEL + + +def warm_up_intent_model() -> None: + intent_tokenizer = get_intent_model_tokenizer() + inputs = intent_tokenizer( + MODEL_WARM_UP_STRING, return_tensors="tf", truncation=True, padding=True + ) + get_local_intent_model()(inputs) + + +@simple_log_function_time() +def classify_intent(query: str) -> list[float]: tokenizer = get_intent_model_tokenizer() intent_model = get_local_intent_model() model_input = tokenizer(query, return_tensors="tf", truncation=True, padding=True) @@ -26,16 +65,11 @@ def classify_intent(query: str) -> list[float]: @router.post("/intent-model") -def process_intent_request( +async def process_intent_request( intent_request: IntentRequest, ) -> IntentResponse: + if INDEXING_ONLY: + raise RuntimeError("Indexing model server should not call intent endpoint") + class_percentages = classify_intent(intent_request.query) return IntentResponse(class_probs=class_percentages) - - -def warm_up_intent_model() -> None: - intent_tokenizer = get_intent_model_tokenizer() - inputs = intent_tokenizer( - "danswer", return_tensors="tf", truncation=True, padding=True - ) - get_local_intent_model()(inputs) diff --git a/backend/model_server/encoders.py b/backend/model_server/encoders.py index 1220736dea..f1f3fdf0cf 100644 --- a/backend/model_server/encoders.py +++ b/backend/model_server/encoders.py @@ -1,34 +1,33 @@ -from typing import TYPE_CHECKING +import gc +from typing import Optional from fastapi import APIRouter from fastapi import HTTPException +from sentence_transformers import CrossEncoder # type: ignore +from sentence_transformers import SentenceTransformer # type: ignore -from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE -from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE -from danswer.search.search_nlp_models import get_local_reranking_model_ensemble from danswer.utils.logger import setup_logger -from danswer.utils.timing import log_function_time -from shared_models.model_server_models import EmbedRequest -from shared_models.model_server_models import EmbedResponse -from shared_models.model_server_models import RerankRequest -from shared_models.model_server_models import RerankResponse - -if TYPE_CHECKING: - from sentence_transformers import SentenceTransformer # type: ignore - +from model_server.constants import MODEL_WARM_UP_STRING +from model_server.utils import simple_log_function_time +from shared_configs.model_server_models import EmbedRequest +from shared_configs.model_server_models import EmbedResponse +from shared_configs.model_server_models import RerankRequest +from shared_configs.model_server_models import RerankResponse +from shared_configs.nlp_model_configs import CROSS_EMBED_CONTEXT_SIZE +from 
shared_configs.nlp_model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
+from shared_configs.nlp_model_configs import INDEXING_ONLY
 
 logger = setup_logger()
 
-WARM_UP_STRING = "Danswer is amazing"
-
 router = APIRouter(prefix="/encoder")
 
 _GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
+_RERANK_MODELS: Optional[list["CrossEncoder"]] = None
 
 
 def get_embedding_model(
     model_name: str,
-    max_context_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
+    max_context_length: int,
 ) -> "SentenceTransformer":
     from sentence_transformers import SentenceTransformer  # type: ignore
 
@@ -48,11 +47,44 @@
     return _GLOBAL_MODELS_DICT[model_name]
 
 
-@log_function_time(print_only=True)
+def get_local_reranking_model_ensemble(
+    model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
+    max_context_length: int = CROSS_EMBED_CONTEXT_SIZE,
+) -> list[CrossEncoder]:
+    global _RERANK_MODELS
+    if _RERANK_MODELS is None or max_context_length != _RERANK_MODELS[0].max_length:
+        del _RERANK_MODELS
+        gc.collect()
+
+        _RERANK_MODELS = []
+        for model_name in model_names:
+            logger.info(f"Loading {model_name}")
+            model = CrossEncoder(model_name)
+            model.max_length = max_context_length
+            _RERANK_MODELS.append(model)
+    return _RERANK_MODELS
+
+
+def warm_up_cross_encoders() -> None:
+    logger.info(f"Warming up Cross-Encoders: {CROSS_ENCODER_MODEL_ENSEMBLE}")
+
+    cross_encoders = get_local_reranking_model_ensemble()
+    [
+        cross_encoder.predict((MODEL_WARM_UP_STRING, MODEL_WARM_UP_STRING))
+        for cross_encoder in cross_encoders
+    ]
+
+
+@simple_log_function_time()
 def embed_text(
-    texts: list[str], model_name: str, normalize_embeddings: bool
+    texts: list[str],
+    model_name: str,
+    max_context_length: int,
+    normalize_embeddings: bool,
 ) -> list[list[float]]:
-    model = get_embedding_model(model_name=model_name)
+    model = get_embedding_model(
+        model_name=model_name, max_context_length=max_context_length
+    )
     embeddings = model.encode(texts, normalize_embeddings=normalize_embeddings)
 
     if not isinstance(embeddings, list):
@@ -61,7 +93,7 @@
     return embeddings
 
 
-@log_function_time(print_only=True)
+@simple_log_function_time()
 def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
     cross_encoders = get_local_reranking_model_ensemble()
     sim_scores = [
@@ -72,13 +104,14 @@
 
 
 @router.post("/bi-encoder-embed")
-def process_embed_request(
+async def process_embed_request(
     embed_request: EmbedRequest,
 ) -> EmbedResponse:
     try:
         embeddings = embed_text(
             texts=embed_request.texts,
             model_name=embed_request.model_name,
+            max_context_length=embed_request.max_context_length,
             normalize_embeddings=embed_request.normalize_embeddings,
         )
         return EmbedResponse(embeddings=embeddings)
@@ -87,7 +120,11 @@
 
 
 @router.post("/cross-encoder-scores")
-def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
+async def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
+    """Cross encoders can be purely black box from the app perspective"""
+    if INDEXING_ONLY:
+        raise RuntimeError("Indexing model server should not call rerank endpoint")
+
     try:
         sim_scores = calc_sim_scores(
             query=embed_request.query, docs=embed_request.documents
@@ -95,13 +132,3 @@
         return RerankResponse(scores=sim_scores)
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
-
-
-def warm_up_cross_encoders() -> None:
-    
logger.info(f"Warming up Cross-Encoders: {CROSS_ENCODER_MODEL_ENSEMBLE}") - - cross_encoders = get_local_reranking_model_ensemble() - [ - cross_encoder.predict((WARM_UP_STRING, WARM_UP_STRING)) - for cross_encoder in cross_encoders - ] diff --git a/backend/model_server/main.py b/backend/model_server/main.py index dead931dcd..aaac1d0d17 100644 --- a/backend/model_server/main.py +++ b/backend/model_server/main.py @@ -1,39 +1,60 @@ +import os +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager + import torch import uvicorn from fastapi import FastAPI +from transformers import logging as transformer_logging # type:ignore from danswer import __version__ from danswer.configs.app_configs import MODEL_SERVER_ALLOWED_HOST from danswer.configs.app_configs import MODEL_SERVER_PORT -from danswer.configs.model_configs import MIN_THREADS_ML_MODELS from danswer.utils.logger import setup_logger from model_server.custom_models import router as custom_models_router from model_server.custom_models import warm_up_intent_model from model_server.encoders import router as encoders_router from model_server.encoders import warm_up_cross_encoders +from shared_configs.nlp_model_configs import ENABLE_RERANKING_ASYNC_FLOW +from shared_configs.nlp_model_configs import ENABLE_RERANKING_REAL_TIME_FLOW +from shared_configs.nlp_model_configs import INDEXING_ONLY +from shared_configs.nlp_model_configs import MIN_THREADS_ML_MODELS + +os.environ["TOKENIZERS_PARALLELISM"] = "false" +os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1" +transformer_logging.set_verbosity_error() logger = setup_logger() -def get_model_app() -> FastAPI: - application = FastAPI(title="Danswer Model Server", version=__version__) +@asynccontextmanager +async def lifespan(app: FastAPI) -> AsyncGenerator: + if torch.cuda.is_available(): + logger.info("GPU is available") + else: + logger.info("GPU is not available") - application.include_router(encoders_router) - application.include_router(custom_models_router) + torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads())) + logger.info(f"Torch Threads: {torch.get_num_threads()}") - @application.on_event("startup") - def startup_event() -> None: - if torch.cuda.is_available(): - logger.info("GPU is available") - else: - logger.info("GPU is not available") + if not INDEXING_ONLY: + warm_up_intent_model() + if ENABLE_RERANKING_REAL_TIME_FLOW or ENABLE_RERANKING_ASYNC_FLOW: + warm_up_cross_encoders() + else: + logger.info("This model server should only run document indexing.") - torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads())) - logger.info(f"Torch Threads: {torch.get_num_threads()}") + yield - warm_up_cross_encoders() - warm_up_intent_model() + +def get_model_app() -> FastAPI: + application = FastAPI( + title="Danswer Model Server", version=__version__, lifespan=lifespan + ) + + application.include_router(encoders_router) + application.include_router(custom_models_router) return application diff --git a/backend/model_server/utils.py b/backend/model_server/utils.py new file mode 100644 index 0000000000..3ebae26e5b --- /dev/null +++ b/backend/model_server/utils.py @@ -0,0 +1,41 @@ +import time +from collections.abc import Callable +from collections.abc import Generator +from collections.abc import Iterator +from functools import wraps +from typing import Any +from typing import cast +from typing import TypeVar + +from danswer.utils.logger import setup_logger + +logger = setup_logger() + +F = TypeVar("F", bound=Callable) +FG = TypeVar("FG", 
bound=Callable[..., Generator | Iterator]) + + +def simple_log_function_time( + func_name: str | None = None, + debug_only: bool = False, + include_args: bool = False, +) -> Callable[[F], F]: + def decorator(func: F) -> F: + @wraps(func) + def wrapped_func(*args: Any, **kwargs: Any) -> Any: + start_time = time.time() + result = func(*args, **kwargs) + elapsed_time_str = str(time.time() - start_time) + log_name = func_name or func.__name__ + args_str = f" args={args} kwargs={kwargs}" if include_args else "" + final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds" + if debug_only: + logger.debug(final_log) + else: + logger.info(final_log) + + return result + + return cast(F, wrapped_func) + + return decorator diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt index 1d92e8f25e..eab0f89357 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -54,19 +54,12 @@ requests-oauthlib==1.3.1 retry==0.9.2 # This pulls in py which is in CVE-2022-42969, must remove py from image rfc3986==1.5.0 rt==3.1.2 -# need to pin `safetensors` version, since the latest versions requires -# building from source using Rust -safetensors==0.4.2 -sentence-transformers==2.6.1 slack-sdk==3.20.2 SQLAlchemy[mypy]==2.0.15 starlette==0.36.3 supervisor==4.2.5 -tensorflow==2.15.0 tiktoken==0.4.0 timeago==1.0.16 -torch==2.0.1 -torchvision==0.15.2 transformers==4.39.2 uvicorn==0.21.1 zulip==0.8.2 diff --git a/backend/requirements/model_server.txt b/backend/requirements/model_server.txt index 666baabe4c..487e6338d4 100644 --- a/backend/requirements/model_server.txt +++ b/backend/requirements/model_server.txt @@ -1,8 +1,8 @@ -fastapi==0.109.1 +fastapi==0.109.2 pydantic==1.10.7 -safetensors==0.3.1 -sentence-transformers==2.2.2 +safetensors==0.4.2 +sentence-transformers==2.6.1 tensorflow==2.15.0 torch==2.0.1 -transformers==4.36.2 +transformers==4.39.2 uvicorn==0.21.1 diff --git a/backend/shared_models/__init__.py b/backend/shared_configs/__init__.py similarity index 100% rename from backend/shared_models/__init__.py rename to backend/shared_configs/__init__.py diff --git a/backend/shared_models/model_server_models.py b/backend/shared_configs/model_server_models.py similarity index 79% rename from backend/shared_models/model_server_models.py rename to backend/shared_configs/model_server_models.py index e3b04557d2..020a24a30b 100644 --- a/backend/shared_models/model_server_models.py +++ b/backend/shared_configs/model_server_models.py @@ -2,8 +2,10 @@ class EmbedRequest(BaseModel): + # This already includes any prefixes, the text is just passed directly to the model texts: list[str] model_name: str + max_context_length: int normalize_embeddings: bool diff --git a/backend/shared_configs/nlp_model_configs.py b/backend/shared_configs/nlp_model_configs.py new file mode 100644 index 0000000000..cc58a56b0d --- /dev/null +++ b/backend/shared_configs/nlp_model_configs.py @@ -0,0 +1,26 @@ +import os + + +# Danswer custom Deep Learning Models +INTENT_MODEL_VERSION = "danswer/intent-model" +INTENT_MODEL_CONTEXT_SIZE = 256 + +# Bi-Encoder, other details +DOC_EMBEDDING_CONTEXT_SIZE = 512 + +# Cross Encoder Settings +ENABLE_RERANKING_ASYNC_FLOW = ( + os.environ.get("ENABLE_RERANKING_ASYNC_FLOW", "").lower() == "true" +) +ENABLE_RERANKING_REAL_TIME_FLOW = ( + os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true" +) +# Only using one cross-encoder for now +CROSS_ENCODER_MODEL_ENSEMBLE = ["mixedbread-ai/mxbai-rerank-xsmall-v1"] +CROSS_EMBED_CONTEXT_SIZE = 512 + 
+# This controls the minimum number of pytorch "threads" to allocate to the embedding +# model. If torch finds more threads on its own, this value is not used. +MIN_THREADS_ML_MODELS = int(os.environ.get("MIN_THREADS_ML_MODELS") or 1) + +INDEXING_ONLY = os.environ.get("INDEXING_ONLY", "").lower() == "true" diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index d33981c7cd..9948b1a602 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ b/deployment/docker_compose/docker-compose.dev.yml @@ -67,7 +67,7 @@ services: - ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-} - ENABLE_RERANKING_REAL_TIME_FLOW=${ENABLE_RERANKING_REAL_TIME_FLOW:-} - ENABLE_RERANKING_ASYNC_FLOW=${ENABLE_RERANKING_ASYNC_FLOW:-} - - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-} + - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server} - MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-} # Leave this on pretty please? Nothing sensitive is collected! # https://docs.danswer.dev/more/telemetry @@ -80,9 +80,7 @@ services: volumes: - local_dynamic_storage:/home/storage - file_connector_tmp_storage:/home/file_connector_storage - - model_cache_torch:/root/.cache/torch/ - model_cache_nltk:/root/nltk_data/ - - model_cache_huggingface:/root/.cache/huggingface/ extra_hosts: - "host.docker.internal:host-gateway" logging: @@ -90,6 +88,8 @@ services: options: max-size: "50m" max-file: "6" + + background: image: danswer/danswer-backend:latest build: @@ -137,10 +137,9 @@ services: - NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-} - ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-} # Needed by DanswerBot - ASYM_PASSAGE_PREFIX=${ASYM_PASSAGE_PREFIX:-} - - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-} + - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server} - MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-} - - INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-} - - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-} + - INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server} # Indexing Configs - NUM_INDEXING_WORKERS=${NUM_INDEXING_WORKERS:-} - DISABLE_INDEX_UPDATE_ON_SWAP=${DISABLE_INDEX_UPDATE_ON_SWAP:-} @@ -174,9 +173,7 @@ services: volumes: - local_dynamic_storage:/home/storage - file_connector_tmp_storage:/home/file_connector_storage - - model_cache_torch:/root/.cache/torch/ - model_cache_nltk:/root/nltk_data/ - - model_cache_huggingface:/root/.cache/huggingface/ extra_hosts: - "host.docker.internal:host-gateway" logging: @@ -184,6 +181,8 @@ services: options: max-size: "50m" max-file: "6" + + web_server: image: danswer/danswer-web-server:latest build: @@ -198,6 +197,63 @@ services: environment: - INTERNAL_URL=http://api_server:8080 - WEB_DOMAIN=${WEB_DOMAIN:-} + + + inference_model_server: + image: danswer/danswer-model-server:latest + build: + context: ../../backend + dockerfile: Dockerfile.model_server + command: > + /bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then + echo 'Skipping service...'; + exit 0; + else + exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000; + fi" + restart: on-failure + environment: + - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-} + # Set to debug to get more fine-grained logs + - LOG_LEVEL=${LOG_LEVEL:-info} + volumes: + - model_cache_torch:/root/.cache/torch/ + - model_cache_huggingface:/root/.cache/huggingface/ + logging: + driver: json-file + options: + max-size: "50m" + max-file: "6" + + + indexing_model_server: + image: danswer/danswer-model-server:latest + build: + context: ../../backend + 
dockerfile: Dockerfile.model_server + command: > + /bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then + echo 'Skipping service...'; + exit 0; + else + exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000; + fi" + restart: on-failure + environment: + - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-} + - INDEXING_ONLY=True + # Set to debug to get more fine-grained logs + - LOG_LEVEL=${LOG_LEVEL:-info} + volumes: + - model_cache_torch:/root/.cache/torch/ + - model_cache_huggingface:/root/.cache/huggingface/ + logging: + driver: json-file + options: + max-size: "50m" + max-file: "6" + + relational_db: image: postgres:15.2-alpine restart: always @@ -208,6 +264,8 @@ services: - "5432:5432" volumes: - db_volume:/var/lib/postgresql/data + + # This container name cannot have an underscore in it due to Vespa expectations of the URL index: image: vespaengine/vespa:8.277.17 @@ -222,6 +280,8 @@ services: options: max-size: "50m" max-file: "6" + + nginx: image: nginx:1.23.4-alpine restart: always @@ -250,32 +310,8 @@ services: command: > /bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh && /etc/nginx/conf.d/run-nginx.sh app.conf.template.dev" - # Run with --profile model-server to bring up the danswer-model-server container - # Be sure to change MODEL_SERVER_HOST (see above) as well - # ie. MODEL_SERVER_HOST="model_server" docker compose -f docker-compose.dev.yml -p danswer-stack --profile model-server up -d --build - model_server: - image: danswer/danswer-model-server:latest - build: - context: ../../backend - dockerfile: Dockerfile.model_server - profiles: - - "model-server" - command: uvicorn model_server.main:app --host 0.0.0.0 --port 9000 - restart: always - environment: - - DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-} - - NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-} - - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-} - # Set to debug to get more fine-grained logs - - LOG_LEVEL=${LOG_LEVEL:-info} - volumes: - - model_cache_torch:/root/.cache/torch/ - - model_cache_huggingface:/root/.cache/huggingface/ - logging: - driver: json-file - options: - max-size: "50m" - max-file: "6" + + volumes: local_dynamic_storage: file_connector_tmp_storage: # used to store files uploaded by the user temporarily while we are indexing them diff --git a/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml b/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml index 84a912988e..6f671adb40 100644 --- a/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml +++ b/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml @@ -22,9 +22,7 @@ services: volumes: - local_dynamic_storage:/home/storage - file_connector_tmp_storage:/home/file_connector_storage - - model_cache_torch:/root/.cache/torch/ - model_cache_nltk:/root/nltk_data/ - - model_cache_huggingface:/root/.cache/huggingface/ extra_hosts: - "host.docker.internal:host-gateway" logging: @@ -32,6 +30,8 @@ services: options: max-size: "50m" max-file: "6" + + background: image: danswer/danswer-backend:latest build: @@ -51,9 +51,7 @@ services: volumes: - local_dynamic_storage:/home/storage - file_connector_tmp_storage:/home/file_connector_storage - - model_cache_torch:/root/.cache/torch/ - model_cache_nltk:/root/nltk_data/ - - model_cache_huggingface:/root/.cache/huggingface/ extra_hosts: - "host.docker.internal:host-gateway" logging: @@ -61,6 +59,8 @@ services: options: max-size: "50m" max-file: "6" + + web_server: image: danswer/danswer-web-server:latest build: @@ -81,6 +81,63 @@ services: 
options: max-size: "50m" max-file: "6" + + + inference_model_server: + image: danswer/danswer-model-server:latest + build: + context: ../../backend + dockerfile: Dockerfile.model_server + command: > + /bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then + echo 'Skipping service...'; + exit 0; + else + exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000; + fi" + restart: on-failure + environment: + - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-} + # Set to debug to get more fine-grained logs + - LOG_LEVEL=${LOG_LEVEL:-info} + volumes: + - model_cache_torch:/root/.cache/torch/ + - model_cache_huggingface:/root/.cache/huggingface/ + logging: + driver: json-file + options: + max-size: "50m" + max-file: "6" + + + indexing_model_server: + image: danswer/danswer-model-server:latest + build: + context: ../../backend + dockerfile: Dockerfile.model_server + command: > + /bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then + echo 'Skipping service...'; + exit 0; + else + exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000; + fi" + restart: on-failure + environment: + - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-} + - INDEXING_ONLY=True + # Set to debug to get more fine-grained logs + - LOG_LEVEL=${LOG_LEVEL:-info} + volumes: + - model_cache_torch:/root/.cache/torch/ + - model_cache_huggingface:/root/.cache/huggingface/ + logging: + driver: json-file + options: + max-size: "50m" + max-file: "6" + + relational_db: image: postgres:15.2-alpine restart: always @@ -94,6 +151,8 @@ services: options: max-size: "50m" max-file: "6" + + # This container name cannot have an underscore in it due to Vespa expectations of the URL index: image: vespaengine/vespa:8.277.17 @@ -108,6 +167,8 @@ services: options: max-size: "50m" max-file: "6" + + nginx: image: nginx:1.23.4-alpine restart: always @@ -137,30 +198,8 @@ services: && /etc/nginx/conf.d/run-nginx.sh app.conf.template.no-letsencrypt" env_file: - .env.nginx - # Run with --profile model-server to bring up the danswer-model-server container - model_server: - image: danswer/danswer-model-server:latest - build: - context: ../../backend - dockerfile: Dockerfile.model_server - profiles: - - "model-server" - command: uvicorn model_server.main:app --host 0.0.0.0 --port 9000 - restart: always - environment: - - DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-} - - NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-} - - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-} - # Set to debug to get more fine-grained logs - - LOG_LEVEL=${LOG_LEVEL:-info} - volumes: - - model_cache_torch:/root/.cache/torch/ - - model_cache_huggingface:/root/.cache/huggingface/ - logging: - driver: json-file - options: - max-size: "50m" - max-file: "6" + + volumes: local_dynamic_storage: file_connector_tmp_storage: # used to store files uploaded by the user temporarily while we are indexing them diff --git a/deployment/docker_compose/docker-compose.prod.yml b/deployment/docker_compose/docker-compose.prod.yml index 5ce30f666a..310ac2ddc5 100644 --- a/deployment/docker_compose/docker-compose.prod.yml +++ b/deployment/docker_compose/docker-compose.prod.yml @@ -22,9 +22,7 @@ services: volumes: - local_dynamic_storage:/home/storage - file_connector_tmp_storage:/home/file_connector_storage - - model_cache_torch:/root/.cache/torch/ - model_cache_nltk:/root/nltk_data/ - - model_cache_huggingface:/root/.cache/huggingface/ extra_hosts: - "host.docker.internal:host-gateway" logging: @@ -32,6 +30,8 @@ services: options: max-size: "50m" max-file: "6" + 
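+  # The torch/huggingface model caches are no longer mounted on the backend
+  # containers; they are now owned by the dedicated inference/indexing model
+  # server services defined further down in this file.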
+ background: image: danswer/danswer-backend:latest build: @@ -51,9 +51,7 @@ services: volumes: - local_dynamic_storage:/home/storage - file_connector_tmp_storage:/home/file_connector_storage - - model_cache_torch:/root/.cache/torch/ - model_cache_nltk:/root/nltk_data/ - - model_cache_huggingface:/root/.cache/huggingface/ extra_hosts: - "host.docker.internal:host-gateway" logging: @@ -61,6 +59,8 @@ services: options: max-size: "50m" max-file: "6" + + web_server: image: danswer/danswer-web-server:latest build: @@ -94,6 +94,63 @@ services: options: max-size: "50m" max-file: "6" + + + inference_model_server: + image: danswer/danswer-model-server:latest + build: + context: ../../backend + dockerfile: Dockerfile.model_server + command: > + /bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then + echo 'Skipping service...'; + exit 0; + else + exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000; + fi" + restart: on-failure + environment: + - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-} + # Set to debug to get more fine-grained logs + - LOG_LEVEL=${LOG_LEVEL:-info} + volumes: + - model_cache_torch:/root/.cache/torch/ + - model_cache_huggingface:/root/.cache/huggingface/ + logging: + driver: json-file + options: + max-size: "50m" + max-file: "6" + + + indexing_model_server: + image: danswer/danswer-model-server:latest + build: + context: ../../backend + dockerfile: Dockerfile.model_server + command: > + /bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then + echo 'Skipping service...'; + exit 0; + else + exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000; + fi" + restart: on-failure + environment: + - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-} + - INDEXING_ONLY=True + # Set to debug to get more fine-grained logs + - LOG_LEVEL=${LOG_LEVEL:-info} + volumes: + - model_cache_torch:/root/.cache/torch/ + - model_cache_huggingface:/root/.cache/huggingface/ + logging: + driver: json-file + options: + max-size: "50m" + max-file: "6" + + # This container name cannot have an underscore in it due to Vespa expectations of the URL index: image: vespaengine/vespa:8.277.17 @@ -108,6 +165,8 @@ services: options: max-size: "50m" max-file: "6" + + nginx: image: nginx:1.23.4-alpine restart: always @@ -141,6 +200,8 @@ services: && /etc/nginx/conf.d/run-nginx.sh app.conf.template" env_file: - .env.nginx + + # follows https://pentacent.medium.com/nginx-and-lets-encrypt-with-docker-in-less-than-5-minutes-b4b8a60d3a71 certbot: image: certbot/certbot @@ -154,30 +215,8 @@ services: max-size: "50m" max-file: "6" entrypoint: "/bin/sh -c 'trap exit TERM; while :; do certbot renew; sleep 12h & wait $${!}; done;'" - # Run with --profile model-server to bring up the danswer-model-server container - model_server: - image: danswer/danswer-model-server:latest - build: - context: ../../backend - dockerfile: Dockerfile.model_server - profiles: - - "model-server" - command: uvicorn model_server.main:app --host 0.0.0.0 --port 9000 - restart: always - environment: - - DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-} - - NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-} - - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-} - # Set to debug to get more fine-grained logs - - LOG_LEVEL=${LOG_LEVEL:-info} - volumes: - - model_cache_torch:/root/.cache/torch/ - - model_cache_huggingface:/root/.cache/huggingface/ - logging: - driver: json-file - options: - max-size: "50m" - max-file: "6" + + volumes: local_dynamic_storage: file_connector_tmp_storage: # used to store files uploaded by 
the user temporarily while we are indexing them diff --git a/deployment/kubernetes/env-configmap.yaml b/deployment/kubernetes/env-configmap.yaml index a10aad91e1..88ed9e0962 100644 --- a/deployment/kubernetes/env-configmap.yaml +++ b/deployment/kubernetes/env-configmap.yaml @@ -43,9 +43,9 @@ data: ASYM_PASSAGE_PREFIX: "" ENABLE_RERANKING_REAL_TIME_FLOW: "" ENABLE_RERANKING_ASYNC_FLOW: "" - MODEL_SERVER_HOST: "" + MODEL_SERVER_HOST: "inference-model-server-service" MODEL_SERVER_PORT: "" - INDEXING_MODEL_SERVER_HOST: "" + INDEXING_MODEL_SERVER_HOST: "indexing-model-server-service" MIN_THREADS_ML_MODELS: "" # Indexing Configs NUM_INDEXING_WORKERS: "" diff --git a/deployment/kubernetes/indexing_model_server-service-deployment.yaml b/deployment/kubernetes/indexing_model_server-service-deployment.yaml new file mode 100644 index 0000000000..d44b52e928 --- /dev/null +++ b/deployment/kubernetes/indexing_model_server-service-deployment.yaml @@ -0,0 +1,59 @@ +apiVersion: v1 +kind: Service +metadata: + name: indexing-model-server-service +spec: + selector: + app: indexing-model-server + ports: + - name: indexing-model-server-port + protocol: TCP + port: 9000 + targetPort: 9000 + type: ClusterIP +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: indexing-model-server-deployment +spec: + replicas: 1 + selector: + matchLabels: + app: indexing-model-server + template: + metadata: + labels: + app: indexing-model-server + spec: + containers: + - name: indexing-model-server + image: danswer/danswer-model-server:latest + imagePullPolicy: IfNotPresent + command: [ "uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "9000" ] + ports: + - containerPort: 9000 + envFrom: + - configMapRef: + name: env-configmap + env: + - name: INDEXING_ONLY + value: "True" + volumeMounts: + - name: indexing-model-storage + mountPath: /root/.cache + volumes: + - name: indexing-model-storage + persistentVolumeClaim: + claimName: indexing-model-pvc +--- +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: indexing-model-pvc +spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: 3Gi diff --git a/deployment/kubernetes/inference_model_server-service-deployment.yaml b/deployment/kubernetes/inference_model_server-service-deployment.yaml new file mode 100644 index 0000000000..790dc633db --- /dev/null +++ b/deployment/kubernetes/inference_model_server-service-deployment.yaml @@ -0,0 +1,56 @@ +apiVersion: v1 +kind: Service +metadata: + name: inference-model-server-service +spec: + selector: + app: inference-model-server + ports: + - name: inference-model-server-port + protocol: TCP + port: 9000 + targetPort: 9000 + type: ClusterIP +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: inference-model-server-deployment +spec: + replicas: 1 + selector: + matchLabels: + app: inference-model-server + template: + metadata: + labels: + app: inference-model-server + spec: + containers: + - name: inference-model-server + image: danswer/danswer-model-server:latest + imagePullPolicy: IfNotPresent + command: [ "uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "9000" ] + ports: + - containerPort: 9000 + envFrom: + - configMapRef: + name: env-configmap + volumeMounts: + - name: inference-model-storage + mountPath: /root/.cache + volumes: + - name: inference-model-storage + persistentVolumeClaim: + claimName: inference-model-pvc +--- +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: inference-model-pvc +spec: + accessModes: + - ReadWriteOnce + resources: + 
requests: + storage: 3Gi From b432d422058392f1514f5ddc809ab25f018497aa Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Mon, 8 Apr 2024 00:52:14 -0700 Subject: [PATCH 53/58] Mypy Fix (#1308) --- backend/danswer/indexing/embedder.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/backend/danswer/indexing/embedder.py b/backend/danswer/indexing/embedder.py index 2017265777..446122df1c 100644 --- a/backend/danswer/indexing/embedder.py +++ b/backend/danswer/indexing/embedder.py @@ -108,10 +108,12 @@ def embed_chunks( chunk_titles = { chunk.source_document.get_title_for_document_index() for chunk in chunks } - chunk_titles.discard(None) + + # Drop any None or empty strings + chunk_titles_list = [title for title in chunk_titles if title] # Embed Titles in batches - title_batches = batch_list(list(chunk_titles), batch_size) + title_batches = batch_list(chunk_titles_list, batch_size) len_title_batches = len(title_batches) for ind_batch, title_batch in enumerate(title_batches, start=1): logger.debug(f"Embedding Titles batch {ind_batch} of {len_title_batches}") From dac4be62e0fb223a4d13577cef7ff4ffc5277cd8 Mon Sep 17 00:00:00 2001 From: Weves Date: Mon, 8 Apr 2024 15:59:23 -0700 Subject: [PATCH 54/58] Fix prod compose files --- .../docker_compose/docker-compose.prod-no-letsencrypt.yml | 3 +++ deployment/docker_compose/docker-compose.prod.yml | 3 +++ 2 files changed, 6 insertions(+) diff --git a/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml b/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml index 6f671adb40..5c5cd5a466 100644 --- a/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml +++ b/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml @@ -19,6 +19,7 @@ services: - AUTH_TYPE=${AUTH_TYPE:-google_oauth} - POSTGRES_HOST=relational_db - VESPA_HOST=index + - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server} volumes: - local_dynamic_storage:/home/storage - file_connector_tmp_storage:/home/file_connector_storage @@ -48,6 +49,8 @@ services: - AUTH_TYPE=${AUTH_TYPE:-google_oauth} - POSTGRES_HOST=relational_db - VESPA_HOST=index + - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server} + - INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server} volumes: - local_dynamic_storage:/home/storage - file_connector_tmp_storage:/home/file_connector_storage diff --git a/deployment/docker_compose/docker-compose.prod.yml b/deployment/docker_compose/docker-compose.prod.yml index 310ac2ddc5..9c7202abd3 100644 --- a/deployment/docker_compose/docker-compose.prod.yml +++ b/deployment/docker_compose/docker-compose.prod.yml @@ -19,6 +19,7 @@ services: - AUTH_TYPE=${AUTH_TYPE:-google_oauth} - POSTGRES_HOST=relational_db - VESPA_HOST=index + - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server} volumes: - local_dynamic_storage:/home/storage - file_connector_tmp_storage:/home/file_connector_storage @@ -48,6 +49,8 @@ services: - AUTH_TYPE=${AUTH_TYPE:-google_oauth} - POSTGRES_HOST=relational_db - VESPA_HOST=index + - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server} + - INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server} volumes: - local_dynamic_storage:/home/storage - file_connector_tmp_storage:/home/file_connector_storage From 31bfbe5d1624074844e655ec435d9e095cda8f92 Mon Sep 17 00:00:00 2001 From: Weves Date: Tue, 9 Apr 2024 11:29:37 -0700 Subject: [PATCH 55/58] Fix chat sharing --- backend/danswer/server/query_and_chat/chat_backend.py | 7 ++++++- 1 
file changed, 6 insertions(+), 1 deletion(-) diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index 52d6414e22..52d879dfe6 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -108,7 +108,12 @@ def get_chat_session( db_session.commit() session_messages = get_chat_messages_by_session( - chat_session_id=session_id, user_id=user_id, db_session=db_session + chat_session_id=session_id, + user_id=user_id, + db_session=db_session, + # we already did a permission check above with the call to + # `get_chat_session_by_id`, so we can skip it here + skip_permission_check=True, ) return ChatSessionDetailResponse( From 714a3c867db278da20ba9413a5fcad0f6f520096 Mon Sep 17 00:00:00 2001 From: Weves Date: Tue, 9 Apr 2024 19:18:39 -0700 Subject: [PATCH 56/58] Add option to skip Jira tickets with a certain label --- backend/danswer/configs/app_configs.py | 5 +++++ .../connectors/danswer_jira/connector.py | 20 +++++++++++++++++++ .../docker_compose/docker-compose.dev.yml | 1 + 3 files changed, 26 insertions(+) diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index c6c697e89c..b8bcc97f69 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -157,6 +157,11 @@ ) if ignored_tag ] +JIRA_CONNECTOR_LABELS_TO_SKIP = [ + ignored_tag + for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",") + if ignored_tag +] GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME") diff --git a/backend/danswer/connectors/danswer_jira/connector.py b/backend/danswer/connectors/danswer_jira/connector.py index 5ef833e581..dfed7ebd16 100644 --- a/backend/danswer/connectors/danswer_jira/connector.py +++ b/backend/danswer/connectors/danswer_jira/connector.py @@ -8,6 +8,7 @@ from jira.resources import Issue from danswer.configs.app_configs import INDEX_BATCH_SIZE +from danswer.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP from danswer.configs.constants import DocumentSource from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc from danswer.connectors.interfaces import GenerateDocumentsOutput @@ -68,6 +69,7 @@ def fetch_jira_issues_batch( jira_client: JIRA, batch_size: int = INDEX_BATCH_SIZE, comment_email_blacklist: tuple[str, ...] = (), + labels_to_skip: set[str] | None = None, ) -> tuple[list[Document], int]: doc_batch = [] @@ -82,6 +84,15 @@ def fetch_jira_issues_batch( logger.warning(f"Found Jira object not of type Issue {jira}") continue + if labels_to_skip and any( + label in jira.fields.labels for label in labels_to_skip + ): + logger.info( + f"Skipping {jira.key} because it has a label to skip. Found " + f"labels: {jira.fields.labels}. Labels to skip: {labels_to_skip}." + ) + continue + comments = _get_comment_strs(jira, comment_email_blacklist) semantic_rep = f"{jira.fields.description}\n" + "\n".join( [f"Comment: {comment}" for comment in comments] @@ -143,12 +154,18 @@ def __init__( jira_project_url: str, comment_email_blacklist: list[str] | None = None, batch_size: int = INDEX_BATCH_SIZE, + # if a ticket has one of the labels specified in this list, we will just + # skip it. This is generally used to avoid indexing extra sensitive + # tickets. 
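+        # Example (hypothetical label names): with
+        # JIRA_CONNECTOR_LABELS_TO_SKIP="secret,internal-only" set in the
+        # environment, any issue carrying either label is skipped at
+        # indexing time.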
+ labels_to_skip: list[str] = JIRA_CONNECTOR_LABELS_TO_SKIP, ) -> None: self.batch_size = batch_size self.jira_base, self.jira_project = extract_jira_project(jira_project_url) self.jira_client: JIRA | None = None self._comment_email_blacklist = comment_email_blacklist or [] + self.labels_to_skip = set(labels_to_skip) + @property def comment_email_blacklist(self) -> tuple: return tuple(email.strip() for email in self._comment_email_blacklist) @@ -182,6 +199,8 @@ def load_from_state(self) -> GenerateDocumentsOutput: start_index=start_ind, jira_client=self.jira_client, batch_size=self.batch_size, + comment_email_blacklist=self.comment_email_blacklist, + labels_to_skip=self.labels_to_skip, ) if doc_batch: @@ -218,6 +237,7 @@ def poll_source( jira_client=self.jira_client, batch_size=self.batch_size, comment_email_blacklist=self.comment_email_blacklist, + labels_to_skip=self.labels_to_skip, ) if doc_batch: diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index 9948b1a602..9b5115f801 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ b/deployment/docker_compose/docker-compose.dev.yml @@ -147,6 +147,7 @@ services: - CONTINUE_ON_CONNECTOR_FAILURE=${CONTINUE_ON_CONNECTOR_FAILURE:-} - EXPERIMENTAL_CHECKPOINTING_ENABLED=${EXPERIMENTAL_CHECKPOINTING_ENABLED:-} - CONFLUENCE_CONNECTOR_LABELS_TO_SKIP=${CONFLUENCE_CONNECTOR_LABELS_TO_SKIP:-} + - JIRA_CONNECTOR_LABELS_TO_SKIP=${JIRA_CONNECTOR_LABELS_TO_SKIP:-} - JIRA_API_VERSION=${JIRA_API_VERSION:-} - GONG_CONNECTOR_START_TIME=${GONG_CONNECTOR_START_TIME:-} - NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP=${NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP:-} From f346c2fc869be6c44b2af1d21ad4ee1b35abdeb1 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Wed, 10 Apr 2024 09:44:12 -0700 Subject: [PATCH 57/58] Axero Link Fix (#1317) --- backend/danswer/connectors/axero/connector.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/backend/danswer/connectors/axero/connector.py b/backend/danswer/connectors/axero/connector.py index 9ee7b2d6f0..f82c6b4494 100644 --- a/backend/danswer/connectors/axero/connector.py +++ b/backend/danswer/connectors/axero/connector.py @@ -244,7 +244,7 @@ def _translate_content_to_doc(content: dict) -> Document: doc = Document( id="AXERO_" + str(content["ContentID"]), - sections=[Section(link=content["ContentVersionURL"], text=page_text)], + sections=[Section(link=content["ContentURL"], text=page_text)], source=DocumentSource.AXERO, semantic_identifier=content["ContentTitle"], doc_updated_at=time_str_to_utc(content["DateUpdated"]), @@ -304,7 +304,6 @@ def poll_source( iterable_space_ids = self.space_ids if self.space_ids else [None] for space_id in iterable_space_ids: - entity_types = [] for entity in entity_types: axero_obj = _get_entities( entity_type=entity, From b59912884bed7a32a0858d2e634d5402647c5ea6 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Wed, 10 Apr 2024 23:13:22 -0700 Subject: [PATCH 58/58] Fix Model Server (#1320) --- backend/Dockerfile.model_server | 8 +------- backend/danswer/background/update.py | 6 +++--- backend/danswer/configs/app_configs.py | 16 +--------------- .../danswerbot/slack/handlers/handle_message.py | 2 +- backend/danswer/danswerbot/slack/listener.py | 4 ++-- backend/danswer/indexing/embedder.py | 4 ++-- backend/danswer/llm/utils.py | 2 +- backend/danswer/main.py | 6 +++--- backend/danswer/search/models.py | 2 +- .../search/preprocessing/preprocessing.py | 2 +- .../danswer/search/retrieval/search_runner.py | 4 
++-- backend/danswer/search/search_nlp_models.py | 4 ++-- backend/danswer/utils/logger.py | 2 +- backend/model_server/custom_models.py | 6 +++--- backend/model_server/encoders.py | 6 +++--- backend/model_server/main.py | 12 ++++++------ backend/requirements/model_server.txt | 1 + .../{nlp_model_configs.py => configs.py} | 14 ++++++++++++++ 18 files changed, 48 insertions(+), 53 deletions(-) rename backend/shared_configs/{nlp_model_configs.py => configs.py} (57%) diff --git a/backend/Dockerfile.model_server b/backend/Dockerfile.model_server index 0eb455c513..cb7115c0bc 100644 --- a/backend/Dockerfile.model_server +++ b/backend/Dockerfile.model_server @@ -13,19 +13,13 @@ RUN apt-get remove -y --allow-remove-essential perl-base && \ WORKDIR /app -# Needed for model configs and defaults -COPY ./danswer/configs /app/danswer/configs -COPY ./danswer/dynamic_configs /app/danswer/dynamic_configs - # Utils used by model server COPY ./danswer/utils/logger.py /app/danswer/utils/logger.py -COPY ./danswer/utils/timing.py /app/danswer/utils/timing.py -COPY ./danswer/utils/telemetry.py /app/danswer/utils/telemetry.py # Place to fetch version information COPY ./danswer/__init__.py /app/danswer/__init__.py -# Request/Response models +# Shared between Danswer Backend and Model Server COPY ./shared_configs /app/shared_configs # Model Server main code diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py index 8d8de8da4c..6042e02b1c 100755 --- a/backend/danswer/background/update.py +++ b/backend/danswer/background/update.py @@ -15,9 +15,6 @@ from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP -from danswer.configs.app_configs import INDEXING_MODEL_SERVER_HOST -from danswer.configs.app_configs import LOG_LEVEL -from danswer.configs.app_configs import MODEL_SERVER_PORT from danswer.configs.app_configs import NUM_INDEXING_WORKERS from danswer.db.connector import fetch_connectors from danswer.db.connector_credential_pair import get_connector_credential_pairs @@ -46,6 +43,9 @@ from danswer.db.models import IndexModelStatus from danswer.search.search_nlp_models import warm_up_encoders from danswer.utils.logger import setup_logger +from shared_configs.configs import INDEXING_MODEL_SERVER_HOST +from shared_configs.configs import LOG_LEVEL +from shared_configs.configs import MODEL_SERVER_PORT logger = setup_logger() diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index b8bcc97f69..1e4809d071 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -209,19 +209,6 @@ ) -##### -# Model Server Configs -##### -MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or "localhost" -MODEL_SERVER_ALLOWED_HOST = os.environ.get("MODEL_SERVER_HOST") or "0.0.0.0" -MODEL_SERVER_PORT = int(os.environ.get("MODEL_SERVER_PORT") or "9000") -# Model server for indexing should use a separate one to not allow indexing to introduce delay -# for inference -INDEXING_MODEL_SERVER_HOST = ( - os.environ.get("INDEXING_MODEL_SERVER_HOST") or MODEL_SERVER_HOST -) - - ##### # Miscellaneous ##### @@ -246,8 +233,7 @@ ) # Anonymous usage telemetry DISABLE_TELEMETRY = os.environ.get("DISABLE_TELEMETRY", "").lower() == "true" -# notset, debug, info, warning, error, or critical -LOG_LEVEL = os.environ.get("LOG_LEVEL", "info") + TOKEN_BUDGET_GLOBALLY_ENABLED = ( 
os.environ.get("TOKEN_BUDGET_GLOBALLY_ENABLED", "").lower() == "true" ) diff --git a/backend/danswer/danswerbot/slack/handlers/handle_message.py b/backend/danswer/danswerbot/slack/handlers/handle_message.py index 0886c0c175..fc1c038aeb 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_message.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_message.py @@ -51,7 +51,7 @@ from danswer.search.models import OptionalSearchSetting from danswer.search.models import RetrievalDetails from danswer.utils.logger import setup_logger -from shared_configs.nlp_model_configs import ENABLE_RERANKING_ASYNC_FLOW +from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW logger_base = setup_logger() diff --git a/backend/danswer/danswerbot/slack/listener.py b/backend/danswer/danswerbot/slack/listener.py index 08aa584111..7f935c1f07 100644 --- a/backend/danswer/danswerbot/slack/listener.py +++ b/backend/danswer/danswerbot/slack/listener.py @@ -10,8 +10,6 @@ from slack_sdk.socket_mode.response import SocketModeResponse from sqlalchemy.orm import Session -from danswer.configs.app_configs import MODEL_SERVER_HOST -from danswer.configs.app_configs import MODEL_SERVER_PORT from danswer.configs.constants import MessageType from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER @@ -47,6 +45,8 @@ from danswer.search.search_nlp_models import warm_up_encoders from danswer.server.manage.models import SlackBotTokens from danswer.utils.logger import setup_logger +from shared_configs.configs import MODEL_SERVER_HOST +from shared_configs.configs import MODEL_SERVER_PORT logger = setup_logger() diff --git a/backend/danswer/indexing/embedder.py b/backend/danswer/indexing/embedder.py index 446122df1c..20a8690e36 100644 --- a/backend/danswer/indexing/embedder.py +++ b/backend/danswer/indexing/embedder.py @@ -4,8 +4,6 @@ from sqlalchemy.orm import Session from danswer.configs.app_configs import ENABLE_MINI_CHUNK -from danswer.configs.app_configs import INDEXING_MODEL_SERVER_HOST -from danswer.configs.app_configs import MODEL_SERVER_PORT from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE from danswer.db.embedding_model import get_current_db_embedding_model @@ -20,6 +18,8 @@ from danswer.search.search_nlp_models import EmbeddingModel from danswer.utils.batching import batch_list from danswer.utils.logger import setup_logger +from shared_configs.configs import INDEXING_MODEL_SERVER_HOST +from shared_configs.configs import MODEL_SERVER_PORT logger = setup_logger() diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py index b41c85b9ed..05b36f6ffc 100644 --- a/backend/danswer/llm/utils.py +++ b/backend/danswer/llm/utils.py @@ -20,7 +20,6 @@ from langchain.schema.messages import SystemMessage from tiktoken.core import Encoding -from danswer.configs.app_configs import LOG_LEVEL from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY from danswer.configs.constants import GEN_AI_DETECTED_MODEL from danswer.configs.constants import MessageType @@ -37,6 +36,7 @@ from danswer.indexing.models import InferenceChunk from danswer.llm.interfaces import LLM from danswer.utils.logger import setup_logger +from shared_configs.configs import LOG_LEVEL if TYPE_CHECKING: from danswer.llm.answering.models import PreviousMessage diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 9ce32fe01b..3fb9a11750 100644 
--- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -28,8 +28,6 @@ from danswer.configs.app_configs import AUTH_TYPE from danswer.configs.app_configs import DISABLE_GENERATIVE_AI from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP -from danswer.configs.app_configs import MODEL_SERVER_HOST -from danswer.configs.app_configs import MODEL_SERVER_PORT from danswer.configs.app_configs import OAUTH_CLIENT_ID from danswer.configs.app_configs import OAUTH_CLIENT_SECRET from danswer.configs.app_configs import SECRET @@ -81,7 +79,9 @@ from danswer.utils.telemetry import optional_telemetry from danswer.utils.telemetry import RecordType from danswer.utils.variable_functionality import fetch_versioned_implementation -from shared_configs.nlp_model_configs import ENABLE_RERANKING_REAL_TIME_FLOW +from shared_configs.configs import ENABLE_RERANKING_REAL_TIME_FLOW +from shared_configs.configs import MODEL_SERVER_HOST +from shared_configs.configs import MODEL_SERVER_PORT logger = setup_logger() diff --git a/backend/danswer/search/models.py b/backend/danswer/search/models.py index 9d3eb39b0c..7fc247fa4e 100644 --- a/backend/danswer/search/models.py +++ b/backend/danswer/search/models.py @@ -11,7 +11,7 @@ from danswer.db.models import Persona from danswer.search.enums import OptionalSearchSetting from danswer.search.enums import SearchType -from shared_configs.nlp_model_configs import ENABLE_RERANKING_REAL_TIME_FLOW +from shared_configs.configs import ENABLE_RERANKING_REAL_TIME_FLOW MAX_METRICS_CONTENT = ( diff --git a/backend/danswer/search/preprocessing/preprocessing.py b/backend/danswer/search/preprocessing/preprocessing.py index 7da6db4ceb..ec9fc2dae0 100644 --- a/backend/danswer/search/preprocessing/preprocessing.py +++ b/backend/danswer/search/preprocessing/preprocessing.py @@ -21,7 +21,7 @@ from danswer.utils.threadpool_concurrency import FunctionCall from danswer.utils.threadpool_concurrency import run_functions_in_parallel from danswer.utils.timing import log_function_time -from shared_configs.nlp_model_configs import ENABLE_RERANKING_REAL_TIME_FLOW +from shared_configs.configs import ENABLE_RERANKING_REAL_TIME_FLOW logger = setup_logger() diff --git a/backend/danswer/search/retrieval/search_runner.py b/backend/danswer/search/retrieval/search_runner.py index bb17253921..1189053dbd 100644 --- a/backend/danswer/search/retrieval/search_runner.py +++ b/backend/danswer/search/retrieval/search_runner.py @@ -7,8 +7,6 @@ from sqlalchemy.orm import Session from danswer.chat.models import LlmDoc -from danswer.configs.app_configs import MODEL_SERVER_HOST -from danswer.configs.app_configs import MODEL_SERVER_PORT from danswer.configs.chat_configs import HYBRID_ALPHA from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION from danswer.db.embedding_model import get_current_db_embedding_model @@ -26,6 +24,8 @@ from danswer.utils.logger import setup_logger from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel from danswer.utils.timing import log_function_time +from shared_configs.configs import MODEL_SERVER_HOST +from shared_configs.configs import MODEL_SERVER_PORT logger = setup_logger() diff --git a/backend/danswer/search/search_nlp_models.py b/backend/danswer/search/search_nlp_models.py index 95bd4d0f23..39d762238a 100644 --- a/backend/danswer/search/search_nlp_models.py +++ b/backend/danswer/search/search_nlp_models.py @@ -7,12 +7,12 @@ import requests from transformers import logging as transformer_logging # type:ignore -from 
danswer.configs.app_configs import MODEL_SERVER_HOST -from danswer.configs.app_configs import MODEL_SERVER_PORT from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL from danswer.search.enums import EmbedTextType from danswer.utils.logger import setup_logger +from shared_configs.configs import MODEL_SERVER_HOST +from shared_configs.configs import MODEL_SERVER_PORT from shared_configs.model_server_models import EmbedRequest from shared_configs.model_server_models import EmbedResponse from shared_configs.model_server_models import IntentRequest diff --git a/backend/danswer/utils/logger.py b/backend/danswer/utils/logger.py index c4dd59742b..38e24a3672 100644 --- a/backend/danswer/utils/logger.py +++ b/backend/danswer/utils/logger.py @@ -3,7 +3,7 @@ from collections.abc import MutableMapping from typing import Any -from danswer.configs.app_configs import LOG_LEVEL +from shared_configs.configs import LOG_LEVEL class IndexAttemptSingleton: diff --git a/backend/model_server/custom_models.py b/backend/model_server/custom_models.py index 9b8066e96c..ee97ded784 100644 --- a/backend/model_server/custom_models.py +++ b/backend/model_server/custom_models.py @@ -8,11 +8,11 @@ from model_server.constants import MODEL_WARM_UP_STRING from model_server.utils import simple_log_function_time +from shared_configs.configs import INDEXING_ONLY +from shared_configs.configs import INTENT_MODEL_CONTEXT_SIZE +from shared_configs.configs import INTENT_MODEL_VERSION from shared_configs.model_server_models import IntentRequest from shared_configs.model_server_models import IntentResponse -from shared_configs.nlp_model_configs import INDEXING_ONLY -from shared_configs.nlp_model_configs import INTENT_MODEL_CONTEXT_SIZE -from shared_configs.nlp_model_configs import INTENT_MODEL_VERSION router = APIRouter(prefix="/custom") diff --git a/backend/model_server/encoders.py b/backend/model_server/encoders.py index f1f3fdf0cf..705386a8c4 100644 --- a/backend/model_server/encoders.py +++ b/backend/model_server/encoders.py @@ -9,13 +9,13 @@ from danswer.utils.logger import setup_logger from model_server.constants import MODEL_WARM_UP_STRING from model_server.utils import simple_log_function_time +from shared_configs.configs import CROSS_EMBED_CONTEXT_SIZE +from shared_configs.configs import CROSS_ENCODER_MODEL_ENSEMBLE +from shared_configs.configs import INDEXING_ONLY from shared_configs.model_server_models import EmbedRequest from shared_configs.model_server_models import EmbedResponse from shared_configs.model_server_models import RerankRequest from shared_configs.model_server_models import RerankResponse -from shared_configs.nlp_model_configs import CROSS_EMBED_CONTEXT_SIZE -from shared_configs.nlp_model_configs import CROSS_ENCODER_MODEL_ENSEMBLE -from shared_configs.nlp_model_configs import INDEXING_ONLY logger = setup_logger() diff --git a/backend/model_server/main.py b/backend/model_server/main.py index aaac1d0d17..c7b2a2f931 100644 --- a/backend/model_server/main.py +++ b/backend/model_server/main.py @@ -8,17 +8,17 @@ from transformers import logging as transformer_logging # type:ignore from danswer import __version__ -from danswer.configs.app_configs import MODEL_SERVER_ALLOWED_HOST -from danswer.configs.app_configs import MODEL_SERVER_PORT from danswer.utils.logger import setup_logger from model_server.custom_models import router as custom_models_router from model_server.custom_models import warm_up_intent_model from model_server.encoders 
import router as encoders_router from model_server.encoders import warm_up_cross_encoders -from shared_configs.nlp_model_configs import ENABLE_RERANKING_ASYNC_FLOW -from shared_configs.nlp_model_configs import ENABLE_RERANKING_REAL_TIME_FLOW -from shared_configs.nlp_model_configs import INDEXING_ONLY -from shared_configs.nlp_model_configs import MIN_THREADS_ML_MODELS +from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW +from shared_configs.configs import ENABLE_RERANKING_REAL_TIME_FLOW +from shared_configs.configs import INDEXING_ONLY +from shared_configs.configs import MIN_THREADS_ML_MODELS +from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST +from shared_configs.configs import MODEL_SERVER_PORT os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1" diff --git a/backend/requirements/model_server.txt b/backend/requirements/model_server.txt index 487e6338d4..8f133657b5 100644 --- a/backend/requirements/model_server.txt +++ b/backend/requirements/model_server.txt @@ -1,4 +1,5 @@ fastapi==0.109.2 +h5py==3.9.0 pydantic==1.10.7 safetensors==0.4.2 sentence-transformers==2.6.1 diff --git a/backend/shared_configs/nlp_model_configs.py b/backend/shared_configs/configs.py similarity index 57% rename from backend/shared_configs/nlp_model_configs.py rename to backend/shared_configs/configs.py index cc58a56b0d..41b46723e4 100644 --- a/backend/shared_configs/nlp_model_configs.py +++ b/backend/shared_configs/configs.py @@ -1,6 +1,15 @@ import os +MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or "localhost" +MODEL_SERVER_ALLOWED_HOST = os.environ.get("MODEL_SERVER_HOST") or "0.0.0.0" +MODEL_SERVER_PORT = int(os.environ.get("MODEL_SERVER_PORT") or "9000") +# Model server for indexing should use a separate one to not allow indexing to introduce delay +# for inference +INDEXING_MODEL_SERVER_HOST = ( + os.environ.get("INDEXING_MODEL_SERVER_HOST") or MODEL_SERVER_HOST +) + # Danswer custom Deep Learning Models INTENT_MODEL_VERSION = "danswer/intent-model" INTENT_MODEL_CONTEXT_SIZE = 256 @@ -23,4 +32,9 @@ # model. If torch finds more threads on its own, this value is not used. MIN_THREADS_ML_MODELS = int(os.environ.get("MIN_THREADS_ML_MODELS") or 1) +# Model server that has indexing only set will throw exception if used for reranking +# or intent classification INDEXING_ONLY = os.environ.get("INDEXING_ONLY", "").lower() == "true" + +# notset, debug, info, warning, error, or critical +LOG_LEVEL = os.environ.get("LOG_LEVEL", "info")
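
For reference, a minimal sketch of how a backend-side caller could reach the
inference model server after this series, using only the shared config values
and request models introduced above. The endpoint path, the model name, and
the `embeddings` field on EmbedResponse are assumptions for illustration; the
real client lives in backend/danswer/search/search_nlp_models.py.

    import requests

    from shared_configs.configs import MODEL_SERVER_HOST, MODEL_SERVER_PORT
    from shared_configs.model_server_models import EmbedRequest, EmbedResponse


    def embed_texts(texts: list[str]) -> list[list[float]]:
        # Per the EmbedRequest comment above, any query/passage prefixes must
        # already be applied; texts are passed straight to the model.
        request = EmbedRequest(
            texts=texts,
            model_name="intfloat/e5-base-v2",  # hypothetical model choice
            max_context_length=512,  # matches DOC_EMBEDDING_CONTEXT_SIZE
            normalize_embeddings=True,
        )
        # The endpoint path below is an assumption for illustration.
        response = requests.post(
            f"http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}/encoder/bi-encoder-embed",
            json=request.dict(),
        )
        response.raise_for_status()
        # EmbedResponse is assumed to expose the vectors as `embeddings`.
        return EmbedResponse(**response.json()).embeddings

An indexing pipeline would point at INDEXING_MODEL_SERVER_HOST instead; note
that a model server started with INDEXING_ONLY=true throws an exception if
used for reranking or intent classification.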