diff --git a/.github/workflows/docker-build-push-backend-container-on-tag.yml b/.github/workflows/docker-build-push-backend-container-on-tag.yml index e95c143fb49..a7d46a09736 100644 --- a/.github/workflows/docker-build-push-backend-container-on-tag.yml +++ b/.github/workflows/docker-build-push-backend-container-on-tag.yml @@ -38,5 +38,7 @@ jobs: - name: Run Trivy vulnerability scanner uses: aquasecurity/trivy-action@master with: + # To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend image-ref: docker.io/danswer/danswer-backend:${{ github.ref_name }} severity: 'CRITICAL,HIGH' + trivyignores: ./backend/.trivyignore diff --git a/.github/workflows/pr-python-checks.yml b/.github/workflows/pr-python-checks.yml index 792fe4d46b3..6c604e93d43 100644 --- a/.github/workflows/pr-python-checks.yml +++ b/.github/workflows/pr-python-checks.yml @@ -20,10 +20,12 @@ jobs: cache-dependency-path: | backend/requirements/default.txt backend/requirements/dev.txt + backend/requirements/model_server.txt - run: | python -m pip install --upgrade pip pip install -r backend/requirements/default.txt pip install -r backend/requirements/dev.txt + pip install -r backend/requirements/model_server.txt - name: Run MyPy run: | diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4d88752da2b..7e80baeb2d7 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -85,6 +85,7 @@ Install the required python dependencies: ```bash pip install -r danswer/backend/requirements/default.txt pip install -r danswer/backend/requirements/dev.txt +pip install -r danswer/backend/requirements/model_server.txt ``` Install [Node.js and npm](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm) for the frontend. @@ -112,26 +113,24 @@ docker compose -f docker-compose.dev.yml -p danswer-stack up -d index relational (index refers to Vespa and relational_db refers to Postgres) #### Running Danswer - -Setup a folder to store config. Navigate to `danswer/backend` and run: -```bash -mkdir dynamic_config_storage -``` - To start the frontend, navigate to `danswer/web` and run: ```bash npm run dev ``` -Package the Vespa schema. This will only need to be done when the Vespa schema is updated locally. - -Navigate to `danswer/backend/danswer/document_index/vespa/app_config` and run: +Next, start the model server, which runs the local NLP models. +Navigate to `danswer/backend` and run: ```bash -zip -r ../vespa-app.zip . +uvicorn model_server.main:app --reload --port 9000 +``` +_For Windows (for compatibility with both PowerShell and Command Prompt):_ +```bash +powershell -Command " + uvicorn model_server.main:app --reload --port 9000 +" ``` -- Note: If you don't have the `zip` utility, you will need to install it prior to running the above -The first time running Danswer, you will also need to run the DB migrations for Postgres. +The first time you run Danswer, you will need to run the DB migrations for Postgres. +After the first time, this is no longer required unless the DB models change. 
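+Danswer uses Alembic to manage these Postgres migrations. If you're new to Alembic, applying all pending migrations is a single command (a sketch of the standard Alembic workflow; run from `danswer/backend` with the venv active): +```bash +alembic upgrade head +```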
Navigate to `danswer/backend` and with the venv active, run: @@ -149,17 +148,12 @@ python ./scripts/dev_run_background_jobs.py To run the backend API server, navigate back to `danswer/backend` and run: ```bash -AUTH_TYPE=disabled \ -DYNAMIC_CONFIG_DIR_PATH=./dynamic_config_storage \ -VESPA_DEPLOYMENT_ZIP=./danswer/document_index/vespa/vespa-app.zip \ -uvicorn danswer.main:app --reload --port 8080 +AUTH_TYPE=disabled uvicorn danswer.main:app --reload --port 8080 ``` _For Windows (for compatibility with both PowerShell and Command Prompt):_ ```bash powershell -Command " $env:AUTH_TYPE='disabled' - $env:DYNAMIC_CONFIG_DIR_PATH='./dynamic_config_storage' - $env:VESPA_DEPLOYMENT_ZIP='./danswer/document_index/vespa/vespa-app.zip' uvicorn danswer.main:app --reload --port 8080 " ``` @@ -178,20 +172,16 @@ pre-commit install Additionally, we use `mypy` for static type checking. Danswer is fully type-annotated, and we would like to keep it that way! -Right now, there is no automated type checking at the moment (coming soon), but we ask you to manually run it before -creating a pull requests with `python -m mypy .` from the `danswer/backend` directory. +To run the mypy checks manually, run `python -m mypy .` from the `danswer/backend` directory. #### Web We use `prettier` for formatting. The desired version (2.8.8) will be installed via a `npm i` from the `danswer/web` directory. To run the formatter, use `npx prettier --write .` from the `danswer/web` directory. -Like `mypy`, we have no automated formatting yet (coming soon), but we request that, for now, -you run this manually before creating a pull request. +Please double check that prettier passes before creating a pull request. ### Release Process Danswer follows the semver versioning standard. A set of Docker containers will be pushed automatically to DockerHub with every tag. You can see the containers [here](https://hub.docker.com/search?q=danswer%2F). - -As pre-1.0 software, even patch releases may contain breaking or non-backwards-compatible changes. diff --git a/README.md b/README.md index 3e70e7259c7..edd8328c31e 100644 --- a/README.md +++ b/README.md @@ -22,11 +22,12 @@
-[Danswer](https://www.danswer.ai/) is the ChatGPT for teams. Danswer provides a Chat interface and plugs into any LLM of -your choice. Danswer can be deployed anywhere and for any scale - on a laptop, on-premise, or to cloud. Since you own -the deployment, your user data and chats are fully in your own control. Danswer is MIT licensed and designed to be -modular and easily extensible. The system also comes fully ready for production usage with user authentication, role -management (admin/basic users), chat persistence, and a UI for configuring Personas (AI Assistants) and their Prompts. +[Danswer](https://www.danswer.ai/) is the AI Assistant connected to your company's docs, apps, and people. +Danswer provides a Chat interface and plugs into any LLM of your choice. Danswer can be deployed anywhere and for any +scale - on a laptop, on-premise, or to the cloud. Since you own the deployment, your user data and chats are fully in your +own control. Danswer is MIT licensed and designed to be modular and easily extensible. The system also comes fully ready +for production usage with user authentication, role management (admin/basic users), chat persistence, and a UI for +configuring Personas (AI Assistants) and their Prompts. Danswer also serves as a Unified Search across all common workplace tools such as Slack, Google Drive, Confluence, etc. By combining LLMs and team specific knowledge, Danswer becomes a subject matter expert for the team. Imagine ChatGPT if diff --git a/backend/.trivyignore b/backend/.trivyignore new file mode 100644 index 00000000000..e8351b40741 --- /dev/null +++ b/backend/.trivyignore @@ -0,0 +1,35 @@ +# https://github.com/madler/zlib/issues/868 +# Pulled in with base Debian image, it's part of the contrib folder but unused +# zlib1g is fine +# Will be gone with Debian image upgrade +# No impact in our settings +CVE-2023-45853 + +# krb5 related, worst case is denial of service by resource exhaustion +# Accept the risk +CVE-2024-26458 +CVE-2024-26461 +CVE-2024-26462 + +# Specific to Firefox which we do not use +# No impact in our settings +CVE-2024-0743 + +# bind9 related, worst case is denial of service by CPU resource exhaustion +# Accept the risk +CVE-2023-50387 +CVE-2023-50868 + +# libexpat1, XML parsing resource exhaustion +# We don't parse any user provided XMLs +# No impact in our settings +CVE-2023-52425 +CVE-2024-28757 + +# sqlite, only used by NLTK library to grab word lemmatizer and stopwords +# No impact in our settings +CVE-2023-7104 + +# libharfbuzz0b, O(n^2) growth, worst case is denial of service +# Accept the risk +CVE-2023-25193 diff --git a/backend/Dockerfile b/backend/Dockerfile index a9bc852a5a2..a0b50c53cbe 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -12,7 +12,9 @@ RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}" # zip for Vespa step further down # ca-certificates for HTTPS RUN apt-get update && \ - apt-get install -y cmake curl zip ca-certificates libgnutls30=3.7.9-2+deb12u2 && \ + apt-get install -y cmake curl zip ca-certificates libgnutls30=3.7.9-2+deb12u2 \ + libblkid1=2.38.1-5+deb12u1 libmount1=2.38.1-5+deb12u1 libsmartcols1=2.38.1-5+deb12u1 \ + libuuid1=2.38.1-5+deb12u1 && \ rm -rf /var/lib/apt/lists/* && \ apt-get clean @@ -29,7 +31,8 @@ RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt && \ # xserver-common and xvfb included by playwright 
installation but not needed after # perl-base is part of the base Python Debian image but not needed for Danswer functionality # perl-base could only be removed with --allow-remove-essential -RUN apt-get remove -y --allow-remove-essential perl-base xserver-common xvfb cmake libldap-2.5-0 libldap-2.5-0 && \ +RUN apt-get remove -y --allow-remove-essential perl-base xserver-common xvfb cmake \ + libldap-2.5-0 && \ apt-get autoremove -y && \ rm -rf /var/lib/apt/lists/* && \ rm /usr/local/lib/python3.11/site-packages/tornado/test/test.key @@ -37,7 +40,7 @@ RUN apt-get remove -y --allow-remove-essential perl-base xserver-common xvfb cma # Set up application files WORKDIR /app COPY ./danswer /app/danswer -COPY ./shared_models /app/shared_models +COPY ./shared_configs /app/shared_configs COPY ./alembic /app/alembic COPY ./alembic.ini /app/alembic.ini COPY supervisord.conf /usr/etc/supervisord.conf diff --git a/backend/Dockerfile.model_server b/backend/Dockerfile.model_server index 624bdd37fcd..cb7115c0bc7 100644 --- a/backend/Dockerfile.model_server +++ b/backend/Dockerfile.model_server @@ -13,23 +13,14 @@ RUN apt-get remove -y --allow-remove-essential perl-base && \ WORKDIR /app -# Needed for model configs and defaults -COPY ./danswer/configs /app/danswer/configs -COPY ./danswer/dynamic_configs /app/danswer/dynamic_configs - # Utils used by model server COPY ./danswer/utils/logger.py /app/danswer/utils/logger.py -COPY ./danswer/utils/timing.py /app/danswer/utils/timing.py -COPY ./danswer/utils/telemetry.py /app/danswer/utils/telemetry.py # Place to fetch version information COPY ./danswer/__init__.py /app/danswer/__init__.py -# Shared implementations for running NLP models locally -COPY ./danswer/search/search_nlp_models.py /app/danswer/search/search_nlp_models.py - -# Request/Response models -COPY ./shared_models /app/shared_models +# Shared between Danswer Backend and Model Server +COPY ./shared_configs /app/shared_configs # Model Server main code COPY ./model_server /app/model_server diff --git a/backend/alembic/versions/173cae5bba26_port_config_store.py b/backend/alembic/versions/173cae5bba26_port_config_store.py new file mode 100644 index 00000000000..4087086bf13 --- /dev/null +++ b/backend/alembic/versions/173cae5bba26_port_config_store.py @@ -0,0 +1,29 @@ +"""Port Config Store + +Revision ID: 173cae5bba26 +Revises: e50154680a5c +Create Date: 2024-03-19 15:30:44.425436 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "173cae5bba26" +down_revision = "e50154680a5c" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "key_value_store", + sa.Column("key", sa.String(), nullable=False), + sa.Column("value", postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.PrimaryKeyConstraint("key"), + ) + + +def downgrade() -> None: + op.drop_table("key_value_store") diff --git a/backend/alembic/versions/38eda64af7fe_add_chat_session_sharing.py b/backend/alembic/versions/38eda64af7fe_add_chat_session_sharing.py new file mode 100644 index 00000000000..e77ee186f42 --- /dev/null +++ b/backend/alembic/versions/38eda64af7fe_add_chat_session_sharing.py @@ -0,0 +1,41 @@ +"""Add chat session sharing + +Revision ID: 38eda64af7fe +Revises: 776b3bbe9092 +Create Date: 2024-03-27 19:41:29.073594 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. 
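+# Note: `native_enum=False` on the Enum below means the value is stored as a +# plain string column rather than a Postgres ENUM type.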
+revision = "38eda64af7fe" +down_revision = "776b3bbe9092" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "chat_session", + sa.Column( + "shared_status", + sa.Enum( + "PUBLIC", + "PRIVATE", + name="chatsessionsharedstatus", + native_enum=False, + ), + nullable=True, + ), + ) + op.execute("UPDATE chat_session SET shared_status='PRIVATE'") + op.alter_column( + "chat_session", + "shared_status", + nullable=False, + ) + + +def downgrade() -> None: + op.drop_column("chat_session", "shared_status") diff --git a/backend/alembic/versions/4738e4b3bae1_pg_file_store.py b/backend/alembic/versions/4738e4b3bae1_pg_file_store.py new file mode 100644 index 00000000000..a57102dbe93 --- /dev/null +++ b/backend/alembic/versions/4738e4b3bae1_pg_file_store.py @@ -0,0 +1,28 @@ +"""PG File Store + +Revision ID: 4738e4b3bae1 +Revises: e91df4e935ef +Create Date: 2024-03-20 18:53:32.461518 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = "4738e4b3bae1" +down_revision = "e91df4e935ef" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "file_store", + sa.Column("file_name", sa.String(), nullable=False), + sa.Column("lobj_oid", sa.Integer(), nullable=False), + sa.PrimaryKeyConstraint("file_name"), + ) + + +def downgrade() -> None: + op.drop_table("file_store") diff --git a/backend/alembic/versions/776b3bbe9092_remove_remaining_enums.py b/backend/alembic/versions/776b3bbe9092_remove_remaining_enums.py new file mode 100644 index 00000000000..1e2e7cd3c1b --- /dev/null +++ b/backend/alembic/versions/776b3bbe9092_remove_remaining_enums.py @@ -0,0 +1,71 @@ +"""Remove Remaining Enums + +Revision ID: 776b3bbe9092 +Revises: 4738e4b3bae1 +Create Date: 2024-03-22 21:34:27.629444 + +""" +from alembic import op +import sqlalchemy as sa + +from danswer.db.models import IndexModelStatus +from danswer.search.enums import RecencyBiasSetting +from danswer.search.models import SearchType + +# revision identifiers, used by Alembic. 
+revision = "776b3bbe9092" +down_revision = "4738e4b3bae1" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.alter_column( + "persona", + "search_type", + type_=sa.String, + existing_type=sa.Enum(SearchType, native_enum=False), + existing_nullable=False, + ) + op.alter_column( + "persona", + "recency_bias", + type_=sa.String, + existing_type=sa.Enum(RecencyBiasSetting, native_enum=False), + existing_nullable=False, + ) + + # Because the indexmodelstatus enum does not have a mapping to a string type + # we need this workaround instead of directly changing the type + op.add_column("embedding_model", sa.Column("temp_status", sa.String)) + op.execute("UPDATE embedding_model SET temp_status = status::text") + op.drop_column("embedding_model", "status") + op.alter_column("embedding_model", "temp_status", new_column_name="status") + + op.execute("DROP TYPE IF EXISTS searchtype") + op.execute("DROP TYPE IF EXISTS recencybiassetting") + op.execute("DROP TYPE IF EXISTS indexmodelstatus") + + +def downgrade() -> None: + op.alter_column( + "persona", + "search_type", + type_=sa.Enum(SearchType, native_enum=False), + existing_type=sa.String(length=50), + existing_nullable=False, + ) + op.alter_column( + "persona", + "recency_bias", + type_=sa.Enum(RecencyBiasSetting, native_enum=False), + existing_type=sa.String(length=50), + existing_nullable=False, + ) + op.alter_column( + "embedding_model", + "status", + type_=sa.Enum(IndexModelStatus, native_enum=False), + existing_type=sa.String(length=50), + existing_nullable=False, + ) diff --git a/backend/alembic/versions/91fd3b470d1a_remove_documentsource_from_tag.py b/backend/alembic/versions/91fd3b470d1a_remove_documentsource_from_tag.py new file mode 100644 index 00000000000..b8f1a729222 --- /dev/null +++ b/backend/alembic/versions/91fd3b470d1a_remove_documentsource_from_tag.py @@ -0,0 +1,36 @@ +"""Remove DocumentSource from Tag + +Revision ID: 91fd3b470d1a +Revises: 173cae5bba26 +Create Date: 2024-03-21 12:05:23.956734 + +""" +from alembic import op +import sqlalchemy as sa +from danswer.configs.constants import DocumentSource + +# revision identifiers, used by Alembic. +revision = "91fd3b470d1a" +down_revision = "173cae5bba26" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.alter_column( + "tag", + "source", + type_=sa.String(length=50), + existing_type=sa.Enum(DocumentSource, native_enum=False), + existing_nullable=False, + ) + + +def downgrade() -> None: + op.alter_column( + "tag", + "source", + type_=sa.Enum(DocumentSource, native_enum=False), + existing_type=sa.String(length=50), + existing_nullable=False, + ) diff --git a/backend/alembic/versions/e91df4e935ef_private_personas_documentsets.py b/backend/alembic/versions/e91df4e935ef_private_personas_documentsets.py new file mode 100644 index 00000000000..c18084563da --- /dev/null +++ b/backend/alembic/versions/e91df4e935ef_private_personas_documentsets.py @@ -0,0 +1,118 @@ +"""Private Personas DocumentSets + +Revision ID: e91df4e935ef +Revises: 91fd3b470d1a +Create Date: 2024-03-17 11:47:24.675881 + +""" +import fastapi_users_db_sqlalchemy +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. 
+revision = "e91df4e935ef" +down_revision = "91fd3b470d1a" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "document_set__user", + sa.Column("document_set_id", sa.Integer(), nullable=False), + sa.Column( + "user_id", + fastapi_users_db_sqlalchemy.generics.GUID(), + nullable=False, + ), + sa.ForeignKeyConstraint( + ["document_set_id"], + ["document_set.id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + ), + sa.PrimaryKeyConstraint("document_set_id", "user_id"), + ) + op.create_table( + "persona__user", + sa.Column("persona_id", sa.Integer(), nullable=False), + sa.Column( + "user_id", + fastapi_users_db_sqlalchemy.generics.GUID(), + nullable=False, + ), + sa.ForeignKeyConstraint( + ["persona_id"], + ["persona.id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + ), + sa.PrimaryKeyConstraint("persona_id", "user_id"), + ) + op.create_table( + "document_set__user_group", + sa.Column("document_set_id", sa.Integer(), nullable=False), + sa.Column( + "user_group_id", + sa.Integer(), + nullable=False, + ), + sa.ForeignKeyConstraint( + ["document_set_id"], + ["document_set.id"], + ), + sa.ForeignKeyConstraint( + ["user_group_id"], + ["user_group.id"], + ), + sa.PrimaryKeyConstraint("document_set_id", "user_group_id"), + ) + op.create_table( + "persona__user_group", + sa.Column("persona_id", sa.Integer(), nullable=False), + sa.Column( + "user_group_id", + sa.Integer(), + nullable=False, + ), + sa.ForeignKeyConstraint( + ["persona_id"], + ["persona.id"], + ), + sa.ForeignKeyConstraint( + ["user_group_id"], + ["user_group.id"], + ), + sa.PrimaryKeyConstraint("persona_id", "user_group_id"), + ) + + op.add_column( + "document_set", + sa.Column("is_public", sa.Boolean(), nullable=True), + ) + # fill in is_public for existing rows + op.execute("UPDATE document_set SET is_public = true WHERE is_public IS NULL") + op.alter_column("document_set", "is_public", nullable=False) + + op.add_column( + "persona", + sa.Column("is_public", sa.Boolean(), nullable=True), + ) + # fill in is_public for existing rows + op.execute("UPDATE persona SET is_public = true WHERE is_public IS NULL") + op.alter_column("persona", "is_public", nullable=False) + + +def downgrade() -> None: + op.drop_column("persona", "is_public") + + op.drop_column("document_set", "is_public") + + op.drop_table("persona__user") + op.drop_table("document_set__user") + op.drop_table("persona__user_group") + op.drop_table("document_set__user_group") diff --git a/backend/alembic/versions/ecab2b3f1a3b_add_overrides_to_the_chat_session.py b/backend/alembic/versions/ecab2b3f1a3b_add_overrides_to_the_chat_session.py new file mode 100644 index 00000000000..791d7e42e07 --- /dev/null +++ b/backend/alembic/versions/ecab2b3f1a3b_add_overrides_to_the_chat_session.py @@ -0,0 +1,40 @@ +"""Add overrides to the chat session + +Revision ID: ecab2b3f1a3b +Revises: 38eda64af7fe +Create Date: 2024-04-01 19:08:21.359102 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision = "ecab2b3f1a3b" +down_revision = "38eda64af7fe" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "chat_session", + sa.Column( + "llm_override", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + ), + ) + op.add_column( + "chat_session", + sa.Column( + "prompt_override", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + ), + ) + + +def downgrade() -> None: + op.drop_column("chat_session", "prompt_override") + op.drop_column("chat_session", "llm_override") diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index 31bdc41a208..975358b6cd0 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -279,13 +279,32 @@ async def logout( # take care of that in `double_check_user` ourself. This is needed, since # we want the /me endpoint to still return a user even if they are not # yet verified, so that the frontend knows they exist -optional_valid_user = fastapi_users.current_user(active=True, optional=True) +optional_fastapi_current_user = fastapi_users.current_user(active=True, optional=True) -async def double_check_user( +async def optional_user_( request: Request, user: User | None, db_session: Session, +) -> User | None: + """NOTE: `request` and `db_session` are not used here, but are included + for the EE version of this function.""" + return user + + +async def optional_user( + request: Request, + user: User | None = Depends(optional_fastapi_current_user), + db_session: Session = Depends(get_session), +) -> User | None: + versioned_fetch_user = fetch_versioned_implementation( + "danswer.auth.users", "optional_user_" + ) + return await versioned_fetch_user(request, user, db_session) + + +async def double_check_user( + user: User | None, optional: bool = DISABLE_AUTH, ) -> User | None: if optional: @@ -307,15 +326,9 @@ async def double_check_user( async def current_user( - request: Request, - user: User | None = Depends(optional_valid_user), - db_session: Session = Depends(get_session), + user: User | None = Depends(optional_user), ) -> User | None: - double_check_user = fetch_versioned_implementation( - "danswer.auth.users", "double_check_user" - ) - user = await double_check_user(request, user, db_session) - return user + return await double_check_user(user) async def current_admin_user(user: User | None = Depends(current_user)) -> User | None: diff --git a/backend/danswer/background/celery/celery.py b/backend/danswer/background/celery/celery.py index 80a8a2a1356..408f12f3a0f 100644 --- a/backend/danswer/background/celery/celery.py +++ b/backend/danswer/background/celery/celery.py @@ -182,7 +182,7 @@ def check_for_document_sets_sync_task() -> None: with Session(get_sqlalchemy_engine()) as db_session: # check if any document sets are not synced document_set_info = fetch_document_sets( - db_session=db_session, include_outdated=True + user_id=None, db_session=db_session, include_outdated=True ) for document_set, _ in document_set_info: if not document_set.is_up_to_date: @@ -226,8 +226,4 @@ def clean_old_temp_files_task( "task": "check_for_document_sets_sync_task", "schedule": timedelta(seconds=5), }, - "clean-old-temp-files": { - "task": "clean_old_temp_files_task", - "schedule": timedelta(minutes=30), - }, } diff --git a/backend/danswer/background/indexing/job_client.py b/backend/danswer/background/indexing/job_client.py index d37690627f5..6b1344b59f8 100644 --- a/backend/danswer/background/indexing/job_client.py +++ b/backend/danswer/background/indexing/job_client.py @@ 
-6,18 +6,15 @@ https://github.com/celery/celery/issues/7007#issuecomment-1740139367""" from collections.abc import Callable from dataclasses import dataclass +from multiprocessing import Process from typing import Any from typing import Literal from typing import Optional -from typing import TYPE_CHECKING from danswer.utils.logger import setup_logger logger = setup_logger() -if TYPE_CHECKING: - from torch.multiprocessing import Process - JobStatusType = ( Literal["error"] | Literal["finished"] @@ -89,8 +86,6 @@ def _cleanup_completed_jobs(self) -> None: def submit(self, func: Callable, *args: Any, pure: bool = True) -> SimpleJob | None: """NOTE: `pure` arg is needed so this can be a drop in replacement for Dask""" - from torch.multiprocessing import Process - self._cleanup_completed_jobs() if len(self.jobs) >= self.n_workers: logger.debug("No available workers to run job") diff --git a/backend/danswer/background/indexing/run_indexing.py b/backend/danswer/background/indexing/run_indexing.py index 6241af6f56b..9e8ee6b7fe5 100644 --- a/backend/danswer/background/indexing/run_indexing.py +++ b/backend/danswer/background/indexing/run_indexing.py @@ -330,20 +330,15 @@ def _run_indexing( ) -def run_indexing_entrypoint(index_attempt_id: int, num_threads: int) -> None: +def run_indexing_entrypoint(index_attempt_id: int) -> None: """Entrypoint for indexing run when using dask distributed. Wraps the actual logic in a `try` block so that we can catch any exceptions and mark the attempt as failed.""" - import torch - try: # set the indexing attempt ID so that all log messages from this process # will have it added as a prefix IndexAttemptSingleton.set_index_attempt_id(index_attempt_id) - logger.info(f"Setting task to use {num_threads} threads") - torch.set_num_threads(num_threads) - with Session(get_sqlalchemy_engine()) as db_session: attempt = get_index_attempt( db_session=db_session, index_attempt_id=index_attempt_id diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py index b77ddee859a..6042e02b1cd 100755 --- a/backend/danswer/background/update.py +++ b/backend/danswer/background/update.py @@ -15,9 +15,7 @@ from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP -from danswer.configs.app_configs import LOG_LEVEL from danswer.configs.app_configs import NUM_INDEXING_WORKERS -from danswer.configs.model_configs import MIN_THREADS_ML_MODELS from danswer.db.connector import fetch_connectors from danswer.db.connector_credential_pair import get_connector_credential_pairs from danswer.db.connector_credential_pair import mark_all_in_progress_cc_pairs_failed @@ -29,7 +27,9 @@ from danswer.db.engine import get_db_current_time from danswer.db.engine import get_sqlalchemy_engine from danswer.db.index_attempt import cancel_indexing_attempts_past_model -from danswer.db.index_attempt import count_unique_cc_pairs_with_index_attempts +from danswer.db.index_attempt import ( + count_unique_cc_pairs_with_successful_index_attempts, +) from danswer.db.index_attempt import create_index_attempt from danswer.db.index_attempt import get_index_attempt from danswer.db.index_attempt import get_inprogress_index_attempts @@ -41,7 +41,11 @@ from danswer.db.models import IndexAttempt from danswer.db.models import IndexingStatus from danswer.db.models import IndexModelStatus +from danswer.search.search_nlp_models import warm_up_encoders from 
danswer.utils.logger import setup_logger +from shared_configs.configs import INDEXING_MODEL_SERVER_HOST +from shared_configs.configs import LOG_LEVEL +from shared_configs.configs import MODEL_SERVER_PORT logger = setup_logger() @@ -54,18 +58,6 @@ ) -"""Util funcs""" - - -def _get_num_threads() -> int: - """Get # of "threads" to use for ML models in an indexing job. By default uses - the torch implementation, which returns the # of physical cores on the machine. - """ - import torch - - return max(MIN_THREADS_ML_MODELS, torch.get_num_threads()) - - def _should_create_new_indexing( connector: Connector, last_index: IndexAttempt | None, @@ -344,12 +336,10 @@ def kickoff_indexing_jobs( if use_secondary_index: run = secondary_client.submit( - run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False + run_indexing_entrypoint, attempt.id, pure=False ) else: - run = client.submit( - run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False - ) + run = client.submit(run_indexing_entrypoint, attempt.id, pure=False) if run: secondary_str = "(secondary index) " if use_secondary_index else "" @@ -365,9 +355,9 @@ def check_index_swap(db_session: Session) -> None: - """Get count of cc-pairs and count of index_attempts for the new model grouped by - connector + credential, if it's the same, then assume new index is done building. - This does not take into consideration if the attempt failed or not""" + """Get the count of cc-pairs and the count of successful index_attempts for the + new model, grouped by connector + credential; if the counts match, assume the + new index is done building. If so, swap the indices and expire the old one.""" # Default CC-pair created for Ingestion API unused here all_cc_pairs = get_connector_credential_pairs(db_session) cc_pair_count = len(all_cc_pairs) - 1 @@ -376,7 +366,7 @@ if not embedding_model: return - unique_cc_indexings = count_unique_cc_pairs_with_index_attempts( + unique_cc_indexings = count_unique_cc_pairs_with_successful_index_attempts( embedding_model_id=embedding_model.id, db_session=db_session ) @@ -407,6 +397,20 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None: + engine = get_sqlalchemy_engine() + with Session(engine) as db_session: + db_embedding_model = get_current_db_embedding_model(db_session) + + # So that first-time users aren't surprised by the slow speed of the + # first batch of documents indexed + logger.info("Running a first inference to warm up embedding model") + warm_up_encoders( + model_name=db_embedding_model.model_name, + normalize=db_embedding_model.normalize, + model_server_host=INDEXING_MODEL_SERVER_HOST, + model_server_port=MODEL_SERVER_PORT, + ) + client_primary: Client | SimpleJobClient client_secondary: Client | SimpleJobClient if DASK_JOB_CLIENT_ENABLED: @@ -433,7 +437,6 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non client_secondary = SimpleJobClient(n_workers=num_workers) existing_jobs: dict[int, Future | SimpleJob] = {} - engine = get_sqlalchemy_engine() with Session(engine) as db_session: # Previous version did not always clean up cc-pairs well leaving some connectors undeleteable @@ -470,14 +473,6 @@ def update__main() -> None: - # needed for CUDA to work with multiprocessing - # NOTE: needs to be done on application startup - # before 
any other torch code has been run - import torch - - if not DASK_JOB_CLIENT_ENABLED: - torch.multiprocessing.set_start_method("spawn") - logger.info("Starting Indexing Loop") update_loop() diff --git a/backend/danswer/chat/chat_utils.py b/backend/danswer/chat/chat_utils.py index fe97b0b3923..d7520955750 100644 --- a/backend/danswer/chat/chat_utils.py +++ b/backend/danswer/chat/chat_utils.py @@ -1,97 +1,29 @@ import re -from collections.abc import Callable -from collections.abc import Iterator from collections.abc import Sequence -from functools import lru_cache -from typing import cast -from langchain.schema.messages import BaseMessage -from langchain.schema.messages import HumanMessage -from langchain.schema.messages import SystemMessage from sqlalchemy.orm import Session -from tiktoken.core import Encoding from danswer.chat.models import CitationInfo -from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import LlmDoc -from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION -from danswer.configs.chat_configs import STOP_STREAM_PAT -from danswer.configs.constants import IGNORE_FOR_QA -from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE -from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS from danswer.db.chat import get_chat_messages_by_session -from danswer.db.chat import get_default_prompt from danswer.db.models import ChatMessage -from danswer.db.models import Persona -from danswer.db.models import Prompt from danswer.indexing.models import InferenceChunk -from danswer.llm.utils import check_number_of_tokens -from danswer.llm.utils import get_default_llm_tokenizer -from danswer.llm.utils import get_default_llm_version -from danswer.llm.utils import get_max_input_tokens -from danswer.llm.utils import tokenizer_trim_content -from danswer.prompts.chat_prompts import ADDITIONAL_INFO -from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT -from danswer.prompts.chat_prompts import CHAT_USER_PROMPT -from danswer.prompts.chat_prompts import NO_CITATION_STATEMENT -from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT -from danswer.prompts.constants import DEFAULT_IGNORE_STATEMENT -from danswer.prompts.constants import TRIPLE_BACKTICK -from danswer.prompts.prompt_utils import build_complete_context_str -from danswer.prompts.prompt_utils import build_task_prompt_reminders -from danswer.prompts.prompt_utils import get_current_llm_day_time -from danswer.prompts.token_counts import ADDITIONAL_INFO_TOKEN_CNT -from danswer.prompts.token_counts import ( - CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT, -) -from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT -from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT -from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT from danswer.utils.logger import setup_logger logger = setup_logger() -@lru_cache() -def build_chat_system_message( - prompt: Prompt, - context_exists: bool, - llm_tokenizer_encode_func: Callable, - citation_line: str = REQUIRE_CITATION_STATEMENT, - no_citation_line: str = NO_CITATION_STATEMENT, -) -> tuple[SystemMessage | None, int]: - system_prompt = prompt.system_prompt.strip() - if prompt.include_citations: - if context_exists: - system_prompt += citation_line - else: - system_prompt += no_citation_line - if prompt.datetime_aware: - if system_prompt: - system_prompt += ADDITIONAL_INFO.format( - datetime_info=get_current_llm_day_time() - ) - else: - system_prompt = 
get_current_llm_day_time() - - if not system_prompt: - return None, 0 - - token_count = len(llm_tokenizer_encode_func(system_prompt)) - system_msg = SystemMessage(content=system_prompt) - - return system_msg, token_count - - def llm_doc_from_inference_chunk(inf_chunk: InferenceChunk) -> LlmDoc: return LlmDoc( document_id=inf_chunk.document_id, content=inf_chunk.content, + blurb=inf_chunk.blurb, semantic_identifier=inf_chunk.semantic_identifier, source_type=inf_chunk.source_type, metadata=inf_chunk.metadata, updated_at=inf_chunk.updated_at, link=inf_chunk.source_links[0] if inf_chunk.source_links else None, + source_links=inf_chunk.source_links, ) @@ -108,170 +40,6 @@ def map_document_id_order( return order_mapping -def build_chat_user_message( - chat_message: ChatMessage, - prompt: Prompt, - context_docs: list[LlmDoc], - llm_tokenizer_encode_func: Callable, - all_doc_useful: bool, - user_prompt_template: str = CHAT_USER_PROMPT, - context_free_template: str = CHAT_USER_CONTEXT_FREE_PROMPT, - ignore_str: str = DEFAULT_IGNORE_STATEMENT, -) -> tuple[HumanMessage, int]: - user_query = chat_message.message - - if not context_docs: - # Simpler prompt for cases where there is no context - user_prompt = ( - context_free_template.format( - task_prompt=prompt.task_prompt, user_query=user_query - ) - if prompt.task_prompt - else user_query - ) - user_prompt = user_prompt.strip() - token_count = len(llm_tokenizer_encode_func(user_prompt)) - user_msg = HumanMessage(content=user_prompt) - return user_msg, token_count - - context_docs_str = build_complete_context_str( - cast(list[LlmDoc | InferenceChunk], context_docs) - ) - optional_ignore = "" if all_doc_useful else ignore_str - - task_prompt_with_reminder = build_task_prompt_reminders(prompt) - - user_prompt = user_prompt_template.format( - optional_ignore_statement=optional_ignore, - context_docs_str=context_docs_str, - task_prompt=task_prompt_with_reminder, - user_query=user_query, - ) - - user_prompt = user_prompt.strip() - token_count = len(llm_tokenizer_encode_func(user_prompt)) - user_msg = HumanMessage(content=user_prompt) - - return user_msg, token_count - - -def _get_usable_chunks( - chunks: list[InferenceChunk], token_limit: int -) -> list[InferenceChunk]: - total_token_count = 0 - usable_chunks = [] - for chunk in chunks: - chunk_token_count = check_number_of_tokens(chunk.content) - if total_token_count + chunk_token_count > token_limit: - break - - total_token_count += chunk_token_count - usable_chunks.append(chunk) - - # try and return at least one chunk if possible. This chunk will - # get truncated later on in the pipeline. 
This would only occur if - # the first chunk is larger than the token limit (usually due to character - # count -> token count mismatches caused by special characters / non-ascii - # languages) - if not usable_chunks and chunks: - usable_chunks = [chunks[0]] - - return usable_chunks - - -def get_usable_chunks( - chunks: list[InferenceChunk], - token_limit: int, - offset: int = 0, -) -> list[InferenceChunk]: - offset_into_chunks = 0 - usable_chunks: list[InferenceChunk] = [] - for _ in range(min(offset + 1, 1)): # go through this process at least once - if offset_into_chunks >= len(chunks) and offset_into_chunks > 0: - raise ValueError( - "Chunks offset too large, should not retry this many times" - ) - - usable_chunks = _get_usable_chunks( - chunks=chunks[offset_into_chunks:], token_limit=token_limit - ) - offset_into_chunks += len(usable_chunks) - - return usable_chunks - - -def get_chunks_for_qa( - chunks: list[InferenceChunk], - llm_chunk_selection: list[bool], - token_limit: int | None, - llm_tokenizer: Encoding | None = None, - batch_offset: int = 0, -) -> list[int]: - """ - Gives back indices of chunks to pass into the LLM for Q&A. - - Only selects chunks viable for Q&A, within the token limit, and prioritize those selected - by the LLM in a separate flow (this can be turned off) - - Note, the batch_offset calculation has to count the batches from the beginning each time as - there's no way to know which chunks were included in the prior batches without recounting atm, - this is somewhat slow as it requires tokenizing all the chunks again - """ - token_leeway = 50 - batch_index = 0 - latest_batch_indices: list[int] = [] - token_count = 0 - - # First iterate the LLM selected chunks, then iterate the rest if tokens remaining - for selection_target in [True, False]: - for ind, chunk in enumerate(chunks): - if llm_chunk_selection[ind] is not selection_target or chunk.metadata.get( - IGNORE_FOR_QA - ): - continue - - # We calculate it live in case the user uses a different LLM + tokenizer - chunk_token = check_number_of_tokens(chunk.content) - if chunk_token > DOC_EMBEDDING_CONTEXT_SIZE + token_leeway: - logger.warning( - "Found more tokens in chunk than expected, " - "likely mismatch between embedding and LLM tokenizers. Trimming content..." 
- ) - chunk.content = tokenizer_trim_content( - content=chunk.content, - desired_length=DOC_EMBEDDING_CONTEXT_SIZE, - tokenizer=llm_tokenizer or get_default_llm_tokenizer(), - ) - - # 50 for an approximate/slight overestimate for # tokens for metadata for the chunk - token_count += chunk_token + token_leeway - - # Always use at least 1 chunk - if ( - token_limit is None - or token_count <= token_limit - or not latest_batch_indices - ): - latest_batch_indices.append(ind) - current_chunk_unused = False - else: - current_chunk_unused = True - - if token_limit is not None and token_count >= token_limit: - if batch_index < batch_offset: - batch_index += 1 - if current_chunk_unused: - latest_batch_indices = [ind] - token_count = chunk_token - else: - latest_batch_indices = [] - token_count = 0 - else: - return latest_batch_indices - - return latest_batch_indices - - def create_chat_chain( chat_session_id: int, db_session: Session, @@ -287,7 +55,7 @@ def create_chat_chain( id_to_msg = {msg.id: msg for msg in all_chat_messages} if not all_chat_messages: - raise ValueError("No messages in Chat Session") + raise RuntimeError("No messages in Chat Session") root_message = all_chat_messages[0] if root_message.parent_message is not None: @@ -341,157 +109,6 @@ def combine_message_chain( return "\n\n".join(message_strs) -_PER_MESSAGE_TOKEN_BUFFER = 7 - - -def find_last_index(lst: list[int], max_prompt_tokens: int) -> int: - """From the back, find the index of the last element to include - before the list exceeds the maximum""" - running_sum = 0 - - last_ind = 0 - for i in range(len(lst) - 1, -1, -1): - running_sum += lst[i] + _PER_MESSAGE_TOKEN_BUFFER - if running_sum > max_prompt_tokens: - last_ind = i + 1 - break - if last_ind >= len(lst): - raise ValueError("Last message alone is too large!") - return last_ind - - -def drop_messages_history_overflow( - system_msg: BaseMessage | None, - system_token_count: int, - history_msgs: list[BaseMessage], - history_token_counts: list[int], - final_msg: BaseMessage, - final_msg_token_count: int, - max_allowed_tokens: int, -) -> list[BaseMessage]: - """As message history grows, messages need to be dropped starting from the furthest in the past. 
- The System message should be kept if at all possible and the latest user input which is inserted in the - prompt template must be included""" - if len(history_msgs) != len(history_token_counts): - # This should never happen - raise ValueError("Need exactly 1 token count per message for tracking overflow") - - prompt: list[BaseMessage] = [] - - # Start dropping from the history if necessary - all_tokens = history_token_counts + [system_token_count, final_msg_token_count] - ind_prev_msg_start = find_last_index( - all_tokens, max_prompt_tokens=max_allowed_tokens - ) - - if system_msg and ind_prev_msg_start <= len(history_msgs): - prompt.append(system_msg) - - prompt.extend(history_msgs[ind_prev_msg_start:]) - - prompt.append(final_msg) - - return prompt - - -def in_code_block(llm_text: str) -> bool: - count = llm_text.count(TRIPLE_BACKTICK) - return count % 2 != 0 - - -def extract_citations_from_stream( - tokens: Iterator[str], - context_docs: list[LlmDoc], - doc_id_to_rank_map: dict[str, int], - stop_stream: str | None = STOP_STREAM_PAT, -) -> Iterator[DanswerAnswerPiece | CitationInfo]: - llm_out = "" - max_citation_num = len(context_docs) - curr_segment = "" - prepend_bracket = False - cited_inds = set() - hold = "" - for raw_token in tokens: - if stop_stream: - next_hold = hold + raw_token - - if stop_stream in next_hold: - break - - if next_hold == stop_stream[: len(next_hold)]: - hold = next_hold - continue - - token = next_hold - hold = "" - else: - token = raw_token - - # Special case of [1][ where ][ is a single token - # This is where the model attempts to do consecutive citations like [1][2] - if prepend_bracket: - curr_segment += "[" + curr_segment - prepend_bracket = False - - curr_segment += token - llm_out += token - - possible_citation_pattern = r"(\[\d*$)" # [1, [, etc - possible_citation_found = re.search(possible_citation_pattern, curr_segment) - - citation_pattern = r"\[(\d+)\]" # [1], [2] etc - citation_found = re.search(citation_pattern, curr_segment) - - if citation_found and not in_code_block(llm_out): - numerical_value = int(citation_found.group(1)) - if 1 <= numerical_value <= max_citation_num: - context_llm_doc = context_docs[ - numerical_value - 1 - ] # remove 1 index offset - - link = context_llm_doc.link - target_citation_num = doc_id_to_rank_map[context_llm_doc.document_id] - - # Use the citation number for the document's rank in - # the search (or selected docs) results - curr_segment = re.sub( - rf"\[{numerical_value}\]", f"[{target_citation_num}]", curr_segment - ) - - if target_citation_num not in cited_inds: - cited_inds.add(target_citation_num) - yield CitationInfo( - citation_num=target_citation_num, - document_id=context_llm_doc.document_id, - ) - - if link: - curr_segment = re.sub(r"\[", "[[", curr_segment, count=1) - curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1) - - # In case there's another open bracket like [1][, don't want to match this - possible_citation_found = None - - # if we see "[", but haven't seen the right side, hold back - this may be a - # citation that needs to be replaced with a link - if possible_citation_found: - continue - - # Special case with back to back citations [1][2] - if curr_segment and curr_segment[-1] == "[": - curr_segment = curr_segment[:-1] - prepend_bracket = True - - yield DanswerAnswerPiece(answer_piece=curr_segment) - curr_segment = "" - - if curr_segment: - if prepend_bracket: - yield DanswerAnswerPiece(answer_piece="[" + curr_segment) - else: - yield 
DanswerAnswerPiece(answer_piece=curr_segment) - - def reorganize_citations( answer: str, citations: list[CitationInfo] ) -> tuple[str, list[CitationInfo]]: @@ -547,72 +164,3 @@ def slack_link_format(match: re.Match) -> str: new_citation_info[citation.citation_num] = citation return new_answer, list(new_citation_info.values()) - - -def get_prompt_tokens(prompt: Prompt) -> int: - # Note: currently custom prompts do not allow datetime aware, only default prompts - return ( - check_number_of_tokens(prompt.system_prompt) - + check_number_of_tokens(prompt.task_prompt) - + CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT - + CITATION_STATEMENT_TOKEN_CNT - + CITATION_REMINDER_TOKEN_CNT - + (LANGUAGE_HINT_TOKEN_CNT if bool(MULTILINGUAL_QUERY_EXPANSION) else 0) - + (ADDITIONAL_INFO_TOKEN_CNT if prompt.datetime_aware else 0) - ) - - -# buffer just to be safe so that we don't overflow the token limit due to -# a small miscalculation -_MISC_BUFFER = 40 - - -def compute_max_document_tokens( - persona: Persona, - actual_user_input: str | None = None, - max_llm_token_override: int | None = None, -) -> int: - """Estimates the number of tokens available for context documents. Formula is roughly: - - ( - model_context_window - reserved_output_tokens - prompt_tokens - - (actual_user_input OR reserved_user_message_tokens) - buffer (just to be safe) - ) - - The actual_user_input is used at query time. If we are calculating this before knowing the exact input (e.g. - if we're trying to determine if the user should be able to select another document) then we just set an - arbitrary "upper bound". - """ - llm_name = get_default_llm_version()[0] - if persona.llm_model_version_override: - llm_name = persona.llm_model_version_override - - # if we can't find a number of tokens, just assume some common default - max_input_tokens = ( - max_llm_token_override - if max_llm_token_override - else get_max_input_tokens(model_name=llm_name) - ) - if persona.prompts: - # TODO this may not always be the first prompt - prompt_tokens = get_prompt_tokens(persona.prompts[0]) - else: - prompt_tokens = get_prompt_tokens(get_default_prompt()) - - user_input_tokens = ( - check_number_of_tokens(actual_user_input) - if actual_user_input is not None - else GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS - ) - - return max_input_tokens - prompt_tokens - user_input_tokens - _MISC_BUFFER - - -def compute_max_llm_input_tokens(persona: Persona) -> int: - """Maximum tokens allows in the input to the LLM (of any type).""" - llm_name = get_default_llm_version()[0] - if persona.llm_model_version_override: - llm_name = persona.llm_model_version_override - - input_tokens = get_max_input_tokens(model_name=llm_name) - return input_tokens - _MISC_BUFFER diff --git a/backend/danswer/chat/load_yamls.py b/backend/danswer/chat/load_yamls.py index d85def58d0c..ccc75443749 100644 --- a/backend/danswer/chat/load_yamls.py +++ b/backend/danswer/chat/load_yamls.py @@ -13,7 +13,7 @@ from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import DocumentSet as DocumentSetDBModel from danswer.db.models import Prompt as PromptDBModel -from danswer.search.models import RecencyBiasSetting +from danswer.search.enums import RecencyBiasSetting def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None: @@ -97,6 +97,7 @@ def load_personas_from_yaml( document_sets=doc_sets, default_persona=True, shared=True, + is_public=True, db_session=db_session, ) diff --git a/backend/danswer/chat/models.py b/backend/danswer/chat/models.py index 
de3f7e4f017..d2dd9f31faf 100644 --- a/backend/danswer/chat/models.py +++ b/backend/danswer/chat/models.py @@ -5,10 +5,10 @@ from pydantic import BaseModel from danswer.configs.constants import DocumentSource -from danswer.search.models import QueryFlow +from danswer.search.enums import QueryFlow +from danswer.search.enums import SearchType from danswer.search.models import RetrievalDocs from danswer.search.models import SearchResponse -from danswer.search.models import SearchType class LlmDoc(BaseModel): @@ -16,11 +16,13 @@ class LlmDoc(BaseModel): document_id: str content: str + blurb: str semantic_identifier: str source_type: DocumentSource metadata: dict[str, str | list[str]] updated_at: datetime | None link: str | None + source_links: dict[int, str] | None # First chunk of info for streaming QA @@ -100,9 +102,12 @@ class QAResponse(SearchResponse, DanswerAnswer): error_msg: str | None = None -AnswerQuestionStreamReturn = Iterator[ - DanswerAnswerPiece | DanswerQuotes | DanswerContexts | StreamingError -] +AnswerQuestionPossibleReturn = ( + DanswerAnswerPiece | DanswerQuotes | CitationInfo | DanswerContexts | StreamingError +) + + +AnswerQuestionStreamReturn = Iterator[AnswerQuestionPossibleReturn] class LLMMetricsContainer(BaseModel): diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 5ebf8ab1584..f904f496382 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -5,16 +5,8 @@ from sqlalchemy.orm import Session -from danswer.chat.chat_utils import build_chat_system_message -from danswer.chat.chat_utils import build_chat_user_message -from danswer.chat.chat_utils import compute_max_document_tokens -from danswer.chat.chat_utils import compute_max_llm_input_tokens from danswer.chat.chat_utils import create_chat_chain -from danswer.chat.chat_utils import drop_messages_history_overflow -from danswer.chat.chat_utils import extract_citations_from_stream -from danswer.chat.chat_utils import get_chunks_for_qa from danswer.chat.chat_utils import llm_doc_from_inference_chunk -from danswer.chat.chat_utils import map_document_id_order from danswer.chat.models import CitationInfo from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import LlmDoc @@ -23,9 +15,7 @@ from danswer.chat.models import StreamingError from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT -from danswer.configs.constants import DISABLED_GEN_AI_MSG from danswer.configs.constants import MessageType -from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE from danswer.db.chat import create_db_search_doc from danswer.db.chat import create_new_chat_message from danswer.db.chat import get_chat_message @@ -36,27 +26,25 @@ from danswer.db.chat import translate_db_message_to_chat_message_detail from danswer.db.chat import translate_db_search_doc_to_server_search_doc from danswer.db.embedding_model import get_current_db_embedding_model -from danswer.db.models import ChatMessage -from danswer.db.models import Persona +from danswer.db.engine import get_session_context_manager from danswer.db.models import SearchDoc as DbSearchDoc from danswer.db.models import User from danswer.document_index.factory import get_default_document_index -from danswer.indexing.models import InferenceChunk +from danswer.llm.answering.answer import Answer +from danswer.llm.answering.models import AnswerStyleConfig +from danswer.llm.answering.models 
import CitationConfig +from danswer.llm.answering.models import DocumentPruningConfig +from danswer.llm.answering.models import LLMConfig +from danswer.llm.answering.models import PreviousMessage +from danswer.llm.answering.models import PromptConfig from danswer.llm.exceptions import GenAIDisabledException from danswer.llm.factory import get_default_llm -from danswer.llm.interfaces import LLM from danswer.llm.utils import get_default_llm_tokenizer -from danswer.llm.utils import get_default_llm_version -from danswer.llm.utils import get_max_input_tokens -from danswer.llm.utils import tokenizer_trim_content -from danswer.llm.utils import translate_history_to_basemessages -from danswer.prompts.prompt_utils import build_doc_context_str from danswer.search.models import OptionalSearchSetting -from danswer.search.models import RetrievalDetails -from danswer.search.request_preprocessing import retrieval_preprocessing -from danswer.search.search_runner import chunks_to_search_docs -from danswer.search.search_runner import full_chunk_search_generator -from danswer.search.search_runner import inference_documents_from_ids +from danswer.search.models import SearchRequest +from danswer.search.pipeline import SearchPipeline +from danswer.search.retrieval.search_runner import inference_documents_from_ids +from danswer.search.utils import chunks_to_search_docs from danswer.secondary_llm_flows.choose_search import check_if_need_search from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase from danswer.server.query_and_chat.models import ChatMessageDetail @@ -68,72 +56,6 @@ logger = setup_logger() -def generate_ai_chat_response( - query_message: ChatMessage, - history: list[ChatMessage], - persona: Persona, - context_docs: list[LlmDoc], - doc_id_to_rank_map: dict[str, int], - llm: LLM | None, - llm_tokenizer_encode_func: Callable, - all_doc_useful: bool, -) -> Iterator[DanswerAnswerPiece | CitationInfo | StreamingError]: - if llm is None: - try: - llm = get_default_llm() - except GenAIDisabledException: - # Not an error if it's a user configuration - yield DanswerAnswerPiece(answer_piece=DISABLED_GEN_AI_MSG) - return - - if query_message.prompt is None: - raise RuntimeError("No prompt received for generating Gen AI answer.") - - try: - context_exists = len(context_docs) > 0 - - system_message_or_none, system_tokens = build_chat_system_message( - prompt=query_message.prompt, - context_exists=context_exists, - llm_tokenizer_encode_func=llm_tokenizer_encode_func, - ) - - history_basemessages, history_token_counts = translate_history_to_basemessages( - history - ) - - # Be sure the context_docs passed to build_chat_user_message - # Is the same as passed in later for extracting citations - user_message, user_tokens = build_chat_user_message( - chat_message=query_message, - prompt=query_message.prompt, - context_docs=context_docs, - llm_tokenizer_encode_func=llm_tokenizer_encode_func, - all_doc_useful=all_doc_useful, - ) - - prompt = drop_messages_history_overflow( - system_msg=system_message_or_none, - system_token_count=system_tokens, - history_msgs=history_basemessages, - history_token_counts=history_token_counts, - final_msg=user_message, - final_msg_token_count=user_tokens, - max_allowed_tokens=compute_max_llm_input_tokens(persona), - ) - - # Good Debug/Breakpoint - tokens = llm.stream(prompt) - - yield from extract_citations_from_stream( - tokens, context_docs, doc_id_to_rank_map - ) - - except Exception as e: - logger.exception(f"LLM failed to produce valid chat message, 
error: {e}") - yield StreamingError(error=str(e)) - - def translate_citations( citations_list: list[CitationInfo], db_docs: list[DbSearchDoc] ) -> dict[int, int]: @@ -154,24 +76,30 @@ def translate_citations( return citation_to_saved_doc_id_map +ChatPacketStream = Iterator[ + StreamingError + | QADocsResponse + | LLMRelevanceFilterResponse + | ChatMessageDetail + | DanswerAnswerPiece + | CitationInfo +] + + def stream_chat_message_objects( new_msg_req: CreateChatMessageRequest, user: User | None, db_session: Session, # Needed to translate persona num_chunks to tokens to the LLM default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT, - default_chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE, # For flow with search, don't include as many chunks as possible since we need to leave space # for the chat history, for smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE, -) -> Iterator[ - StreamingError - | QADocsResponse - | LLMRelevanceFilterResponse - | ChatMessageDetail - | DanswerAnswerPiece - | CitationInfo -]: + # if specified, uses the last user message and does not create a new user message based + # on the `new_msg_req.message`. Currently, requires a state where the last message is a + # user message (e.g. this can only be used for the chat-seeding flow). + use_existing_user_message: bool = False, +) -> ChatPacketStream: """Streams in order: 1. [conditional] Retrieved documents if a search needs to be run 2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on @@ -237,33 +165,43 @@ def stream_chat_message_objects( else: parent_message = root_message - # Create new message at the right place in the tree and update the parent's child pointer - # Don't commit yet until we verify the chat message chain - new_user_message = create_new_chat_message( - chat_session_id=chat_session_id, - parent_message=parent_message, - prompt_id=prompt_id, - message=message_text, - token_count=len(llm_tokenizer_encode_func(message_text)), - message_type=MessageType.USER, - db_session=db_session, - commit=False, - ) - - # Create linear history of messages - final_msg, history_msgs = create_chat_chain( - chat_session_id=chat_session_id, db_session=db_session - ) - - if final_msg.id != new_user_message.id: - db_session.rollback() - raise RuntimeError( - "The new message was not on the mainline. " - "Be sure to update the chat pointers before calling this." + if not use_existing_user_message: + # Create new message at the right place in the tree and update the parent's child pointer + # Don't commit yet until we verify the chat message chain + user_message = create_new_chat_message( + chat_session_id=chat_session_id, + parent_message=parent_message, + prompt_id=prompt_id, + message=message_text, + token_count=len(llm_tokenizer_encode_func(message_text)), + message_type=MessageType.USER, + db_session=db_session, + commit=False, + ) + # re-create linear history of messages + final_msg, history_msgs = create_chat_chain( + chat_session_id=chat_session_id, db_session=db_session ) + if final_msg.id != user_message.id: + db_session.rollback() + raise RuntimeError( + "The new message was not on the mainline. " + "Be sure to update the chat pointers before calling this." 
+ ) - # Save now to save the latest chat message - db_session.commit() + # Save now to save the latest chat message + db_session.commit() + else: + # re-create linear history of messages + final_msg, history_msgs = create_chat_chain( + chat_session_id=chat_session_id, db_session=db_session + ) + if final_msg.message_type != MessageType.USER: + raise RuntimeError( + "The last message was not a user message. Cannot call " + "`stream_chat_message_objects` with `use_existing_user_message=True` " + "when the last message is not a user message." + ) run_search = False # Retrieval options are only None if reference_doc_ids are provided @@ -277,10 +215,6 @@ def stream_chat_message_objects( query_message=final_msg, history=history_msgs, llm=llm ) - max_document_tokens = compute_max_document_tokens( - persona=persona, actual_user_input=message_text - ) - rephrased_query = None if reference_doc_ids: identifier_tuples = get_doc_query_identifiers_from_model( @@ -296,64 +230,8 @@ def stream_chat_message_objects( doc_identifiers=identifier_tuples, document_index=document_index, ) - - # truncate the last document if it exceeds the token limit - tokens_per_doc = [ - len( - llm_tokenizer_encode_func( - build_doc_context_str( - semantic_identifier=llm_doc.semantic_identifier, - source_type=llm_doc.source_type, - content=llm_doc.content, - metadata_dict=llm_doc.metadata, - updated_at=llm_doc.updated_at, - ind=ind, - ) - ) - ) - for ind, llm_doc in enumerate(llm_docs) - ] - final_doc_ind = None - total_tokens = 0 - for ind, tokens in enumerate(tokens_per_doc): - total_tokens += tokens - if total_tokens > max_document_tokens: - final_doc_ind = ind - break - if final_doc_ind is not None: - # only allow the final document to get truncated - # if more than that, then the user message is too long - if final_doc_ind != len(tokens_per_doc) - 1: - yield StreamingError( - error="LLM context window exceeded. Please de-select some documents or shorten your query." - ) - return - - final_doc_desired_length = tokens_per_doc[final_doc_ind] - ( - total_tokens - max_document_tokens - ) - # 75 tokens is a reasonable over-estimate of the metadata and title - final_doc_content_length = final_doc_desired_length - 75 - # this could occur if we only have space for the title / metadata - # not ideal, but it's the most reasonable thing to do - # NOTE: the frontend prevents documents from being selected if - # less than 75 tokens are available to try and avoid this situation - # from occuring in the first place - if final_doc_content_length <= 0: - logger.error( - f"Final doc ({llm_docs[final_doc_ind].semantic_identifier}) content " - "length is less than 0. Removing this doc from the final prompt."
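
For manually selected documents, the inline truncation logic deleted here is now delegated to the new `DocumentPruningConfig` (see `is_manually_selected_docs=True` just below). For reference, a minimal sketch of the budget arithmetic the removed block implemented — names mirror the deleted code; this is an illustration, not the `DocumentPruningConfig` internals:

```python
def final_doc_content_budget(
    tokens_per_doc: list[int], max_document_tokens: int
) -> int | None:
    """Token budget for the content of the first document that overflows
    the overall limit; None means every document already fits."""
    total_tokens = 0
    for ind, tokens in enumerate(tokens_per_doc):
        total_tokens += tokens
        if total_tokens > max_document_tokens:
            # shrink only this document by the amount of overflow,
            # reserving ~75 tokens for its title and metadata
            desired_length = tokens - (total_tokens - max_document_tokens)
            return desired_length - 75
    return None
```

Note that the removed code only allowed the final selected document to be truncated; an overflow before that point was surfaced to the user as a `StreamingError`.
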
- ) - llm_docs.pop() - else: - llm_docs[final_doc_ind].content = tokenizer_trim_content( - content=llm_docs[final_doc_ind].content, - desired_length=final_doc_content_length, - tokenizer=llm_tokenizer, - ) - - doc_id_to_rank_map = map_document_id_order( - cast(list[InferenceChunk | LlmDoc], llm_docs) + document_pruning_config = DocumentPruningConfig( + is_manually_selected_docs=True ) # In case the search doc is deleted, just don't include it @@ -376,36 +254,21 @@ def stream_chat_message_objects( else query_override ) - ( - retrieval_request, - predicted_search_type, - predicted_flow, - ) = retrieval_preprocessing( - query=rephrased_query, - retrieval_details=cast(RetrievalDetails, retrieval_options), - persona=persona, + search_pipeline = SearchPipeline( + search_request=SearchRequest( + query=rephrased_query, + human_selected_filters=retrieval_options.filters + if retrieval_options + else None, + persona=persona, + offset=retrieval_options.offset if retrieval_options else None, + limit=retrieval_options.limit if retrieval_options else None, + ), user=user, db_session=db_session, ) - documents_generator = full_chunk_search_generator( - search_query=retrieval_request, - document_index=document_index, - db_session=db_session, - ) - time_cutoff = retrieval_request.filters.time_cutoff - recency_bias_multiplier = retrieval_request.recency_bias_multiplier - run_llm_chunk_filter = not retrieval_request.skip_llm_chunk_filter - - # First fetch and return the top chunks to the UI so the user can - # immediately see some results - top_chunks = cast(list[InferenceChunk], next(documents_generator)) - - # Get ranking of the documents for citation purposes later - doc_id_to_rank_map = map_document_id_order( - cast(list[InferenceChunk | LlmDoc], top_chunks) - ) - + top_chunks = search_pipeline.reranked_docs top_docs = chunks_to_search_docs(top_chunks) reference_db_search_docs = [ @@ -421,68 +284,41 @@ def stream_chat_message_objects( initial_response = QADocsResponse( rephrased_query=rephrased_query, top_documents=response_docs, - predicted_flow=predicted_flow, - predicted_search=predicted_search_type, - applied_source_filters=retrieval_request.filters.source_type, - applied_time_cutoff=time_cutoff, - recency_bias_multiplier=recency_bias_multiplier, + predicted_flow=search_pipeline.predicted_flow, + predicted_search=search_pipeline.predicted_search_type, + applied_source_filters=search_pipeline.search_query.filters.source_type, + applied_time_cutoff=search_pipeline.search_query.filters.time_cutoff, + recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier, ) yield initial_response - # Get the final ordering of chunks for the LLM call - llm_chunk_selection = cast(list[bool], next(documents_generator)) - # Yield the list of LLM selected chunks for showing the LLM selected icons in the UI llm_relevance_filtering_response = LLMRelevanceFilterResponse( - relevant_chunk_indices=[ - index for index, value in enumerate(llm_chunk_selection) if value - ] - if run_llm_chunk_filter - else [] + relevant_chunk_indices=search_pipeline.relevant_chunk_indicies ) yield llm_relevance_filtering_response - # Prep chunks to pass to LLM - num_llm_chunks = ( - persona.num_chunks - if persona.num_chunks is not None - else default_num_chunks + document_pruning_config = DocumentPruningConfig( + max_chunks=int( + persona.num_chunks + if persona.num_chunks is not None + else default_num_chunks + ), + max_window_percentage=max_document_percentage, ) - llm_name = get_default_llm_version()[0] - if 
persona.llm_model_version_override: - llm_name = persona.llm_model_version_override - - llm_max_input_tokens = get_max_input_tokens(model_name=llm_name) - - llm_token_based_chunk_lim = max_document_percentage * llm_max_input_tokens - - chunk_token_limit = int( - min( - num_llm_chunks * default_chunk_size, - max_document_tokens, - llm_token_based_chunk_lim, - ) - ) - llm_chunks_indices = get_chunks_for_qa( - chunks=top_chunks, - llm_chunk_selection=llm_chunk_selection, - token_limit=chunk_token_limit, - llm_tokenizer=llm_tokenizer, - ) - llm_chunks = [top_chunks[i] for i in llm_chunks_indices] - llm_docs = [llm_doc_from_inference_chunk(chunk) for chunk in llm_chunks] + llm_docs = [llm_doc_from_inference_chunk(chunk) for chunk in top_chunks] else: llm_docs = [] - doc_id_to_rank_map = {} reference_db_search_docs = None + document_pruning_config = DocumentPruningConfig() # Cannot determine these without the LLM step or breaking out early partial_response = partial( create_new_chat_message, chat_session_id=chat_session_id, - parent_message=new_user_message, + parent_message=final_msg, prompt_id=prompt_id, # message=, rephrased_query=rephrased_query, @@ -514,33 +350,32 @@ def stream_chat_message_objects( return # LLM prompt building, response capturing, etc. - response_packets = generate_ai_chat_response( - query_message=final_msg, - history=history_msgs, - persona=persona, - context_docs=llm_docs, - doc_id_to_rank_map=doc_id_to_rank_map, - llm=llm, - llm_tokenizer_encode_func=llm_tokenizer_encode_func, - all_doc_useful=reference_doc_ids is not None, + answer = Answer( + question=final_msg.message, + docs=llm_docs, + answer_style_config=AnswerStyleConfig( + citation_config=CitationConfig( + all_docs_useful=reference_db_search_docs is not None + ), + document_pruning_config=document_pruning_config, + ), + prompt_config=PromptConfig.from_model( + final_msg.prompt, + prompt_override=( + new_msg_req.prompt_override or chat_session.prompt_override + ), + ), + llm_config=LLMConfig.from_persona( + persona, + llm_override=(new_msg_req.llm_override or chat_session.llm_override), + ), + message_history=[ + PreviousMessage.from_chat_message(msg) for msg in history_msgs + ], ) + # generator will not include quotes, so we can cast + yield from cast(ChatPacketStream, answer.processed_streamed_output) - # Capture outputs and errors - llm_output = "" - error: str | None = None - citations: list[CitationInfo] = [] - for packet in response_packets: - if isinstance(packet, DanswerAnswerPiece): - token = packet.answer_piece - if token: - llm_output += token - elif isinstance(packet, StreamingError): - error = packet.error - elif isinstance(packet, CitationInfo): - citations.append(packet) - continue - - yield packet except Exception as e: logger.exception(e) @@ -554,16 +389,16 @@ def stream_chat_message_objects( db_citations = None if reference_db_search_docs: db_citations = translate_citations( - citations_list=citations, + citations_list=answer.citations, db_docs=reference_db_search_docs, ) # Saving Gen AI answer and responding with message info gen_ai_response_message = partial_response( - message=llm_output, - token_count=len(llm_tokenizer_encode_func(llm_output)), + message=answer.llm_answer, + token_count=len(llm_tokenizer_encode_func(answer.llm_answer)), citations=db_citations, - error=error, + error=None, ) msg_detail_response = translate_db_message_to_chat_message_detail( @@ -582,12 +417,14 @@ def stream_chat_message_objects( def stream_chat_message( new_msg_req: CreateChatMessageRequest, user: User | 
None, - db_session: Session, + use_existing_user_message: bool = False, ) -> Iterator[str]: - objects = stream_chat_message_objects( - new_msg_req=new_msg_req, - user=user, - db_session=db_session, - ) - for obj in objects: - yield get_json_line(obj.dict()) + with get_session_context_manager() as db_session: + objects = stream_chat_message_objects( + new_msg_req=new_msg_req, + user=user, + db_session=db_session, + use_existing_user_message=use_existing_user_message, + ) + for obj in objects: + yield get_json_line(obj.dict()) diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index cff1b8e5c75..1e4809d0716 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -157,6 +157,11 @@ ) if ignored_tag ] +JIRA_CONNECTOR_LABELS_TO_SKIP = [ + ignored_tag + for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",") + if ignored_tag +] GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME") @@ -204,28 +209,11 @@ ) -##### -# Model Server Configs -##### -# If MODEL_SERVER_HOST is set, the NLP models required for Danswer are offloaded to the server via -# requests. Be sure to include the scheme in the MODEL_SERVER_HOST value. -MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or None -MODEL_SERVER_ALLOWED_HOST = os.environ.get("MODEL_SERVER_HOST") or "0.0.0.0" -MODEL_SERVER_PORT = int(os.environ.get("MODEL_SERVER_PORT") or "9000") - -# specify this env variable directly to have a different model server for the background -# indexing job vs the api server so that background indexing does not effect query-time -# performance -INDEXING_MODEL_SERVER_HOST = ( - os.environ.get("INDEXING_MODEL_SERVER_HOST") or MODEL_SERVER_HOST -) - - ##### # Miscellaneous ##### -DYNAMIC_CONFIG_STORE = os.environ.get( - "DYNAMIC_CONFIG_STORE", "FileSystemBackedDynamicConfigStore" +DYNAMIC_CONFIG_STORE = ( + os.environ.get("DYNAMIC_CONFIG_STORE") or "PostgresBackedDynamicConfigStore" ) DYNAMIC_CONFIG_DIR_PATH = os.environ.get("DYNAMIC_CONFIG_DIR_PATH", "/home/storage") JOB_TIMEOUT = 60 * 60 * 6 # 6 hours default @@ -245,5 +233,7 @@ ) # Anonymous usage telemetry DISABLE_TELEMETRY = os.environ.get("DISABLE_TELEMETRY", "").lower() == "true" -# notset, debug, info, warning, error, or critical -LOG_LEVEL = os.environ.get("LOG_LEVEL", "info") + +TOKEN_BUDGET_GLOBALLY_ENABLED = ( + os.environ.get("TOKEN_BUDGET_GLOBALLY_ENABLED", "").lower() == "true" +) diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index 356fc2831f6..b961cdfb39e 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -40,6 +40,10 @@ SESSION_KEY = "session" QUERY_EVENT_ID = "query_event_id" LLM_CHUNKS = "llm_chunks" +TOKEN_BUDGET = "token_budget" +TOKEN_BUDGET_TIME_PERIOD = "token_budget_time_period" +ENABLE_TOKEN_BUDGET = "enable_token_budget" +TOKEN_BUDGET_SETTINGS = "token_budget_settings" # For chunking/processing chunks TITLE_SEPARATOR = "\n\r\n" @@ -87,6 +91,7 @@ class DocumentSource(str, Enum): ZENDESK = "zendesk" LOOPIO = "loopio" SHAREPOINT = "sharepoint" + AXERO = "axero" class DocumentIndexType(str, Enum): diff --git a/backend/danswer/configs/danswerbot_configs.py b/backend/danswer/configs/danswerbot_configs.py index 5935c9b999e..192a0594d13 100644 --- a/backend/danswer/configs/danswerbot_configs.py +++ b/backend/danswer/configs/danswerbot_configs.py @@ -21,6 +21,14 @@ DANSWER_REACT_EMOJI = os.environ.get("DANSWER_REACT_EMOJI") or "eyes" # When User 
needs more help, what should the emoji be DANSWER_FOLLOWUP_EMOJI = os.environ.get("DANSWER_FOLLOWUP_EMOJI") or "sos" +# What kind of message should be shown when someone gives an AI answer feedback to DanswerBot +# Defaults to Private if not provided or invalid +# Private: Only visible to user clicking the feedback +# Anonymous: Public but anonymous +# Public: Visible with the user name who submitted the feedback +DANSWER_BOT_FEEDBACK_VISIBILITY = ( + os.environ.get("DANSWER_BOT_FEEDBACK_VISIBILITY") or "private" +) # Should DanswerBot send an apology message if it's not able to find an answer # That way the user isn't confused as to why DanswerBot reacted but then said nothing # Off by default to be less intrusive (don't want to give a notif that just says we couldnt help) diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py index f6cd71f31db..e0d774c82b3 100644 --- a/backend/danswer/configs/model_configs.py +++ b/backend/danswer/configs/model_configs.py @@ -37,36 +37,13 @@ ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "passage: ") # Purely an optimization, memory limitation consideration BATCH_SIZE_ENCODE_CHUNKS = 8 -# This controls the minimum number of pytorch "threads" to allocate to the embedding -# model. If torch finds more threads on its own, this value is not used. -MIN_THREADS_ML_MODELS = int(os.environ.get("MIN_THREADS_ML_MODELS") or 1) - -# Cross Encoder Settings -ENABLE_RERANKING_ASYNC_FLOW = ( - os.environ.get("ENABLE_RERANKING_ASYNC_FLOW", "").lower() == "true" -) -ENABLE_RERANKING_REAL_TIME_FLOW = ( - os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true" -) -# https://www.sbert.net/docs/pretrained-models/ce-msmarco.html -CROSS_ENCODER_MODEL_ENSEMBLE = [ - "cross-encoder/ms-marco-MiniLM-L-4-v2", - "cross-encoder/ms-marco-TinyBERT-L-2-v2", -] -# For score normalizing purposes, only way is to know the expected ranges +# For score display purposes, only way is to know the expected ranges CROSS_ENCODER_RANGE_MAX = 12 CROSS_ENCODER_RANGE_MIN = -12 -CROSS_EMBED_CONTEXT_SIZE = 512 # Unused currently, can't be used with the current default encoder model due to its output range SEARCH_DISTANCE_CUTOFF = 0 -# Intent model max context size -QUERY_MAX_CONTEXT_SIZE = 256 - -# Danswer custom Deep Learning Models -INTENT_MODEL_VERSION = "danswer/intent-model" - ##### # Generative AI Model Configs diff --git a/backend/shared_models/__init__.py b/backend/danswer/connectors/axero/__init__.py similarity index 100% rename from backend/shared_models/__init__.py rename to backend/danswer/connectors/axero/__init__.py diff --git a/backend/danswer/connectors/axero/connector.py b/backend/danswer/connectors/axero/connector.py new file mode 100644 index 00000000000..f82c6b4494a --- /dev/null +++ b/backend/danswer/connectors/axero/connector.py @@ -0,0 +1,363 @@ +import time +from datetime import datetime +from datetime import timezone +from typing import Any + +import requests +from pydantic import BaseModel + +from danswer.configs.app_configs import INDEX_BATCH_SIZE +from danswer.configs.constants import DocumentSource +from danswer.connectors.cross_connector_utils.html_utils import parse_html_page_basic +from danswer.connectors.cross_connector_utils.miscellaneous_utils import ( + process_in_batches, +) +from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc +from danswer.connectors.cross_connector_utils.rate_limit_wrapper import ( + rate_limit_builder, +) +from 
danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder +from danswer.connectors.interfaces import GenerateDocumentsOutput +from danswer.connectors.interfaces import PollConnector +from danswer.connectors.interfaces import SecondsSinceUnixEpoch +from danswer.connectors.models import ConnectorMissingCredentialError +from danswer.connectors.models import Document +from danswer.connectors.models import Section +from danswer.utils.logger import setup_logger + + +logger = setup_logger() + + +ENTITY_NAME_MAP = {1: "Forum", 3: "Article", 4: "Blog", 9: "Wiki"} + + +def _get_auth_header(api_key: str) -> dict[str, str]: + return {"Rest-Api-Key": api_key} + + +@retry_builder() +@rate_limit_builder(max_calls=5, period=1) +def _rate_limited_request( + endpoint: str, headers: dict, params: dict | None = None +) -> Any: + # https://my.axerosolutions.com/spaces/5/communifire-documentation/wiki/view/370/rest-api + return requests.get(endpoint, headers=headers, params=params) + + +# https://my.axerosolutions.com/spaces/5/communifire-documentation/wiki/view/595/rest-api-get-content-list +def _get_entities( + entity_type: int, + api_key: str, + axero_base_url: str, + start: datetime, + end: datetime, + space_id: str | None = None, +) -> list[dict]: + endpoint = axero_base_url + "api/content/list" + page_num = 1 + pages_fetched = 0 + pages_to_return = [] + break_out = False + while True: + params = { + "EntityType": str(entity_type), + "SortColumn": "DateUpdated", + "SortOrder": "1", # descending + "StartPage": str(page_num), + } + + if space_id is not None: + params["SpaceID"] = space_id + + res = _rate_limited_request( + endpoint, headers=_get_auth_header(api_key), params=params + ) + res.raise_for_status() + + # Axero limitations: + # No next page token, can paginate but things may have changed + # for example, a doc that hasn't been read in by Danswer is updated and is now front of the list + # due to this limitation and the fact that Axero has no rate limiting but API calls can cause + # increased latency for the team, we have to just fetch all the pages quickly to reduce the + # chance of missing a document due to an update (it will still get updated next pass) + # Assumes the volume of data isn't too big to store in memory (probably fine) + data = res.json() + total_records = data["TotalRecords"] + contents = data["ResponseData"] + pages_fetched += len(contents) + logger.debug(f"Fetched {pages_fetched} {ENTITY_NAME_MAP[entity_type]}") + + for page in contents: + update_time = time_str_to_utc(page["DateUpdated"]) + + if update_time > end: + continue + + if update_time < start: + break_out = True + break + + pages_to_return.append(page) + + if pages_fetched >= total_records: + break + + page_num += 1 + + if break_out: + break + + return pages_to_return + + +def _get_obj_by_id(obj_id: int, api_key: str, axero_base_url: str) -> dict: + endpoint = axero_base_url + f"api/content/{obj_id}" + res = _rate_limited_request(endpoint, headers=_get_auth_header(api_key)) + res.raise_for_status() + + return res.json() + + +class AxeroForum(BaseModel): + doc_id: str + title: str + link: str + initial_content: str + responses: list[str] + last_update: datetime + + +def _map_post_to_parent( + posts: dict, + api_key: str, + axero_base_url: str, +) -> list[AxeroForum]: + """Cannot handle in batches since the posts aren't ordered or structured in any way + may need to map any number of them to the initial post""" + epoch_str = "1970-01-01T00:00:00.000" + post_map: dict[int, AxeroForum] = {} + + for ind, 
post in enumerate(posts): + if (ind + 1) % 25 == 0: + logger.debug(f"Processed {ind + 1} posts or responses") + + post_time = time_str_to_utc( + post.get("DateUpdated") or post.get("DateCreated") or epoch_str + ) + p_id = post.get("ParentContentID") + if p_id in post_map: + axero_forum = post_map[p_id] + axero_forum.responses.insert(0, post.get("ContentSummary")) + axero_forum.last_update = max(axero_forum.last_update, post_time) + else: + initial_post_d = _get_obj_by_id(p_id, api_key, axero_base_url)[ + "ResponseData" + ] + initial_post_time = time_str_to_utc( + initial_post_d.get("DateUpdated") + or initial_post_d.get("DateCreated") + or epoch_str + ) + post_map[p_id] = AxeroForum( + doc_id="AXERO_" + str(initial_post_d.get("ContentID")), + title=initial_post_d.get("ContentTitle"), + link=initial_post_d.get("ContentURL"), + initial_content=initial_post_d.get("ContentSummary"), + responses=[post.get("ContentSummary")], + last_update=max(post_time, initial_post_time), + ) + + return list(post_map.values()) + + +def _get_forums( + api_key: str, + axero_base_url: str, + space_id: str | None = None, +) -> list[dict]: + endpoint = axero_base_url + "api/content/list" + page_num = 1 + pages_fetched = 0 + pages_to_return = [] + break_out = False + + while True: + params = { + "EntityType": "54", + "SortColumn": "DateUpdated", + "SortOrder": "1", # descending + "StartPage": str(page_num), + } + + if space_id is not None: + params["SpaceID"] = space_id + + res = _rate_limited_request( + endpoint, headers=_get_auth_header(api_key), params=params + ) + res.raise_for_status() + + data = res.json() + total_records = data["TotalRecords"] + contents = data["ResponseData"] + pages_fetched += len(contents) + logger.debug(f"Fetched {pages_fetched} forums") + + for page in contents: + pages_to_return.append(page) + + if pages_fetched >= total_records: + break + + page_num += 1 + + if break_out: + break + + return pages_to_return + + +def _translate_forum_to_doc(af: AxeroForum) -> Document: + doc = Document( + id=af.doc_id, + sections=[Section(link=af.link, text=reply) for reply in af.responses], + source=DocumentSource.AXERO, + semantic_identifier=af.title, + doc_updated_at=af.last_update, + metadata={}, + ) + + return doc + + +def _translate_content_to_doc(content: dict) -> Document: + page_text = "" + summary = content.get("ContentSummary") + body = content.get("ContentBody") + if summary: + page_text += f"{summary}\n" + + if body: + content_parsed = parse_html_page_basic(body) + page_text += content_parsed + + doc = Document( + id="AXERO_" + str(content["ContentID"]), + sections=[Section(link=content["ContentURL"], text=page_text)], + source=DocumentSource.AXERO, + semantic_identifier=content["ContentTitle"], + doc_updated_at=time_str_to_utc(content["DateUpdated"]), + metadata={"space": content["SpaceName"]}, + ) + + return doc + + +class AxeroConnector(PollConnector): + def __init__( + self, + # Strings of the integer ids of the spaces + spaces: list[str] | None = None, + include_article: bool = True, + include_blog: bool = True, + include_wiki: bool = True, + include_forum: bool = True, + batch_size: int = INDEX_BATCH_SIZE, + ) -> None: + self.include_article = include_article + self.include_blog = include_blog + self.include_wiki = include_wiki + self.include_forum = include_forum + self.batch_size = batch_size + self.space_ids = spaces + self.axero_key = None + self.base_url = None + + def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: + self.axero_key = 
credentials["axero_api_token"] + # As the API key specifically applies to a particular deployment, this is + # included as part of the credential + base_url = credentials["base_url"] + if not base_url.endswith("/"): + base_url += "/" + self.base_url = base_url + return None + + def poll_source( + self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch + ) -> GenerateDocumentsOutput: + if not self.axero_key or not self.base_url: + raise ConnectorMissingCredentialError("Axero") + + start_datetime = datetime.utcfromtimestamp(start).replace(tzinfo=timezone.utc) + end_datetime = datetime.utcfromtimestamp(end).replace(tzinfo=timezone.utc) + + entity_types = [] + if self.include_article: + entity_types.append(3) + if self.include_blog: + entity_types.append(4) + if self.include_wiki: + entity_types.append(9) + + iterable_space_ids = self.space_ids if self.space_ids else [None] + + for space_id in iterable_space_ids: + for entity in entity_types: + axero_obj = _get_entities( + entity_type=entity, + api_key=self.axero_key, + axero_base_url=self.base_url, + start=start_datetime, + end=end_datetime, + space_id=space_id, + ) + yield from process_in_batches( + objects=axero_obj, + process_function=_translate_content_to_doc, + batch_size=self.batch_size, + ) + + if self.include_forum: + forums_posts = _get_forums( + api_key=self.axero_key, + axero_base_url=self.base_url, + space_id=space_id, + ) + + all_axero_forums = _map_post_to_parent( + posts=forums_posts, + api_key=self.axero_key, + axero_base_url=self.base_url, + ) + + filtered_forums = [ + f + for f in all_axero_forums + if f.last_update >= start_datetime and f.last_update <= end_datetime + ] + + yield from process_in_batches( + objects=filtered_forums, + process_function=_translate_forum_to_doc, + batch_size=self.batch_size, + ) + + +if __name__ == "__main__": + import os + + connector = AxeroConnector() + connector.load_credentials( + { + "axero_api_token": os.environ["AXERO_API_TOKEN"], + "base_url": os.environ["AXERO_BASE_URL"], + } + ) + current = time.time() + + one_year_ago = current - 24 * 60 * 60 * 360 + latest_docs = connector.poll_source(one_year_ago, current) + + print(next(latest_docs)) diff --git a/backend/danswer/connectors/confluence/connector.py b/backend/danswer/connectors/confluence/connector.py index 9b25524d6ad..f9f5e7c3bb2 100644 --- a/backend/danswer/connectors/confluence/connector.py +++ b/backend/danswer/connectors/confluence/connector.py @@ -75,7 +75,10 @@ def _extract_confluence_keys_from_datacenter_url(wiki_url: str) -> tuple[str, st def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, bool]: - is_confluence_cloud = ".atlassian.net/wiki/spaces/" in wiki_url + is_confluence_cloud = ( + ".atlassian.net/wiki/spaces/" in wiki_url + or ".jira.com/wiki/spaces/" in wiki_url + ) try: if is_confluence_cloud: diff --git a/backend/danswer/connectors/cross_connector_utils/file_utils.py b/backend/danswer/connectors/cross_connector_utils/file_utils.py index b0a9c723fea..c7f662d9af4 100644 --- a/backend/danswer/connectors/cross_connector_utils/file_utils.py +++ b/backend/danswer/connectors/cross_connector_utils/file_utils.py @@ -2,8 +2,7 @@ import os import re import zipfile -from collections.abc import Generator -from pathlib import Path +from collections.abc import Iterator from typing import Any from typing import IO @@ -78,11 +77,11 @@ def is_macos_resource_fork_file(file_name: str) -> bool: # to the zip file. 
This file should contain a list of objects with the following format: # [{ "filename": "file1.txt", "link": "https://example.com/file1.txt" }] def load_files_from_zip( - zip_location: str | Path, + zip_file_io: IO, ignore_macos_resource_fork_files: bool = True, ignore_dirs: bool = True, -) -> Generator[tuple[zipfile.ZipInfo, IO[Any], dict[str, Any]], None, None]: - with zipfile.ZipFile(zip_location, "r") as zip_file: +) -> Iterator[tuple[zipfile.ZipInfo, IO[Any], dict[str, Any]]]: + with zipfile.ZipFile(zip_file_io, "r") as zip_file: zip_metadata = {} try: metadata_file_info = zip_file.getinfo(".danswer_metadata.json") @@ -109,18 +108,19 @@ def load_files_from_zip( yield file_info, file, zip_metadata.get(file_info.filename, {}) -def detect_encoding(file_path: str | Path) -> str: - with open(file_path, "rb") as file: - raw_data = file.read(50000) # Read a portion of the file to guess encoding - return chardet.detect(raw_data)["encoding"] or "utf-8" +def detect_encoding(file: IO[bytes]) -> str: + raw_data = file.read(50000) + encoding = chardet.detect(raw_data)["encoding"] or "utf-8" + file.seek(0) + return encoding def read_file( - file_reader: IO[Any], encoding: str = "utf-8", errors: str = "replace" + file: IO, encoding: str = "utf-8", errors: str = "replace" ) -> tuple[str, dict]: metadata = {} file_content_raw = "" - for ind, line in enumerate(file_reader): + for ind, line in enumerate(file): try: line = line.decode(encoding) if isinstance(line, bytes) else line except UnicodeDecodeError: diff --git a/backend/danswer/connectors/cross_connector_utils/miscellaneous_utils.py b/backend/danswer/connectors/cross_connector_utils/miscellaneous_utils.py index 10c8315601b..8faf6bfadaf 100644 --- a/backend/danswer/connectors/cross_connector_utils/miscellaneous_utils.py +++ b/backend/danswer/connectors/cross_connector_utils/miscellaneous_utils.py @@ -1,5 +1,8 @@ +from collections.abc import Callable +from collections.abc import Iterator from datetime import datetime from datetime import timezone +from typing import TypeVar from dateutil.parser import parse @@ -43,3 +46,14 @@ def get_experts_stores_representations( reps = [basic_expert_info_representation(owner) for owner in experts] return [owner for owner in reps if owner is not None] + + +T = TypeVar("T") +U = TypeVar("U") + + +def process_in_batches( + objects: list[T], process_function: Callable[[T], U], batch_size: int +) -> Iterator[list[U]]: + for i in range(0, len(objects), batch_size): + yield [process_function(obj) for obj in objects[i : i + batch_size]] diff --git a/backend/danswer/connectors/danswer_jira/connector.py b/backend/danswer/connectors/danswer_jira/connector.py index 5ef833e581d..dfed7ebd16c 100644 --- a/backend/danswer/connectors/danswer_jira/connector.py +++ b/backend/danswer/connectors/danswer_jira/connector.py @@ -8,6 +8,7 @@ from jira.resources import Issue from danswer.configs.app_configs import INDEX_BATCH_SIZE +from danswer.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP from danswer.configs.constants import DocumentSource from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc from danswer.connectors.interfaces import GenerateDocumentsOutput @@ -68,6 +69,7 @@ def fetch_jira_issues_batch( jira_client: JIRA, batch_size: int = INDEX_BATCH_SIZE, comment_email_blacklist: tuple[str, ...] 
= (), + labels_to_skip: set[str] | None = None, ) -> tuple[list[Document], int]: doc_batch = [] @@ -82,6 +84,15 @@ def fetch_jira_issues_batch( logger.warning(f"Found Jira object not of type Issue {jira}") continue + if labels_to_skip and any( + label in jira.fields.labels for label in labels_to_skip + ): + logger.info( + f"Skipping {jira.key} because it has a label to skip. Found " + f"labels: {jira.fields.labels}. Labels to skip: {labels_to_skip}." + ) + continue + comments = _get_comment_strs(jira, comment_email_blacklist) semantic_rep = f"{jira.fields.description}\n" + "\n".join( [f"Comment: {comment}" for comment in comments] @@ -143,12 +154,18 @@ def __init__( jira_project_url: str, comment_email_blacklist: list[str] | None = None, batch_size: int = INDEX_BATCH_SIZE, + # if a ticket has one of the labels specified in this list, we will just + # skip it. This is generally used to avoid indexing extra sensitive + # tickets. + labels_to_skip: list[str] = JIRA_CONNECTOR_LABELS_TO_SKIP, ) -> None: self.batch_size = batch_size self.jira_base, self.jira_project = extract_jira_project(jira_project_url) self.jira_client: JIRA | None = None self._comment_email_blacklist = comment_email_blacklist or [] + self.labels_to_skip = set(labels_to_skip) + @property def comment_email_blacklist(self) -> tuple: return tuple(email.strip() for email in self._comment_email_blacklist) @@ -182,6 +199,8 @@ def load_from_state(self) -> GenerateDocumentsOutput: start_index=start_ind, jira_client=self.jira_client, batch_size=self.batch_size, + comment_email_blacklist=self.comment_email_blacklist, + labels_to_skip=self.labels_to_skip, ) if doc_batch: @@ -218,6 +237,7 @@ def poll_source( jira_client=self.jira_client, batch_size=self.batch_size, comment_email_blacklist=self.comment_email_blacklist, + labels_to_skip=self.labels_to_skip, ) if doc_batch: diff --git a/backend/danswer/connectors/factory.py b/backend/danswer/connectors/factory.py index f4a9ee29083..5e6438088b3 100644 --- a/backend/danswer/connectors/factory.py +++ b/backend/danswer/connectors/factory.py @@ -2,6 +2,7 @@ from typing import Type from danswer.configs.constants import DocumentSource +from danswer.connectors.axero.connector import AxeroConnector from danswer.connectors.bookstack.connector import BookstackConnector from danswer.connectors.confluence.connector import ConfluenceConnector from danswer.connectors.danswer_jira.connector import JiraConnector @@ -70,6 +71,7 @@ def identify_connector_class( DocumentSource.ZENDESK: ZendeskConnector, DocumentSource.LOOPIO: LoopioConnector, DocumentSource.SHAREPOINT: SharepointConnector, + DocumentSource.AXERO: AxeroConnector, } connector_by_source = connector_map.get(source, {}) diff --git a/backend/danswer/connectors/file/connector.py b/backend/danswer/connectors/file/connector.py index f6aeef649e5..fa290a49693 100644 --- a/backend/danswer/connectors/file/connector.py +++ b/backend/danswer/connectors/file/connector.py @@ -1,11 +1,13 @@ import os -from collections.abc import Generator +from collections.abc import Iterator from datetime import datetime from datetime import timezone from pathlib import Path from typing import Any from typing import IO +from sqlalchemy.orm import Session + from danswer.configs.app_configs import INDEX_BATCH_SIZE from danswer.configs.constants import DocumentSource from danswer.connectors.cross_connector_utils.file_utils import detect_encoding @@ -20,37 +22,40 @@ from danswer.connectors.models import BasicExpertInfo from danswer.connectors.models import Document from 
danswer.connectors.models import Section +from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.file_store import get_default_file_store from danswer.utils.logger import setup_logger logger = setup_logger() -def _open_files_at_location( - file_path: str | Path, -) -> Generator[tuple[str, IO[Any], dict[str, Any]], Any, None]: - extension = get_file_ext(file_path) +def _read_files_and_metadata( + file_name: str, + db_session: Session, +) -> Iterator[tuple[str, IO, dict[str, Any]]]: + """Reads the file into IO, in the case of a zip file, yields each individual + file contained within, also includes the metadata dict if packaged in the zip""" + extension = get_file_ext(file_name) metadata: dict[str, Any] = {} + directory_path = os.path.dirname(file_name) + + file_content = get_default_file_store(db_session).read_file(file_name, mode="b") if extension == ".zip": for file_info, file, metadata in load_files_from_zip( - file_path, ignore_dirs=True + file_content, ignore_dirs=True ): - yield file_info.filename, file, metadata - elif extension in [".txt", ".md", ".mdx"]: - encoding = detect_encoding(file_path) - with open(file_path, "r", encoding=encoding, errors="replace") as file: - yield os.path.basename(file_path), file, metadata - elif extension == ".pdf": - with open(file_path, "rb") as file: - yield os.path.basename(file_path), file, metadata + yield os.path.join(directory_path, file_info.filename), file, metadata + elif extension in [".txt", ".md", ".mdx", ".pdf"]: + yield file_name, file_content, metadata else: - logger.warning(f"Skipping file '{file_path}' with extension '{extension}'") + logger.warning(f"Skipping file '{file_name}' with extension '{extension}'") def _process_file( file_name: str, file: IO[Any], - metadata: dict[str, Any] = {}, + metadata: dict[str, Any] | None = None, pdf_pass: str | None = None, ) -> list[Document]: extension = get_file_ext(file_name) @@ -65,8 +70,9 @@ def _process_file( file=file, file_name=file_name, pdf_pass=pdf_pass ) else: - file_content_raw, file_metadata = read_file(file) - all_metadata = {**metadata, **file_metadata} + encoding = detect_encoding(file) + file_content_raw, file_metadata = read_file(file, encoding=encoding) + all_metadata = {**metadata, **file_metadata} if metadata else file_metadata # If this is set, we will show this in the UI as the "name" of the file file_display_name_override = all_metadata.get("file_display_name") @@ -114,7 +120,8 @@ def _process_file( Section(link=all_metadata.get("link"), text=file_content_raw.strip()) ], source=DocumentSource.FILE, - semantic_identifier=file_display_name_override or file_name, + semantic_identifier=file_display_name_override + or os.path.basename(file_name), doc_updated_at=final_time_updated, primary_owners=p_owners, secondary_owners=s_owners, @@ -140,24 +147,27 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None def load_from_state(self) -> GenerateDocumentsOutput: documents: list[Document] = [] - for file_location in self.file_locations: - current_datetime = datetime.now(timezone.utc) - files = _open_files_at_location(file_location) - - for file_name, file, metadata in files: - metadata["time_updated"] = metadata.get( - "time_updated", current_datetime - ) - documents.extend( - _process_file(file_name, file, metadata, self.pdf_pass) + with Session(get_sqlalchemy_engine()) as db_session: + for file_path in self.file_locations: + current_datetime = datetime.now(timezone.utc) + files = _read_files_and_metadata( + file_name=str(file_path), 
db_session=db_session ) - if len(documents) >= self.batch_size: - yield documents - documents = [] + for file_name, file, metadata in files: + metadata["time_updated"] = metadata.get( + "time_updated", current_datetime + ) + documents.extend( + _process_file(file_name, file, metadata, self.pdf_pass) + ) + + if len(documents) >= self.batch_size: + yield documents + documents = [] - if documents: - yield documents + if documents: + yield documents if __name__ == "__main__": diff --git a/backend/danswer/connectors/gmail/connector_auth.py b/backend/danswer/connectors/gmail/connector_auth.py index f6cfa5a7489..39dd9aacf80 100644 --- a/backend/danswer/connectors/gmail/connector_auth.py +++ b/backend/danswer/connectors/gmail/connector_auth.py @@ -24,7 +24,7 @@ from danswer.connectors.gmail.constants import SCOPES from danswer.db.credentials import update_credential_json from danswer.db.models import User -from danswer.dynamic_configs import get_dynamic_config_store +from danswer.dynamic_configs.factory import get_dynamic_config_store from danswer.server.documents.models import CredentialBase from danswer.server.documents.models import GoogleAppCredentials from danswer.server.documents.models import GoogleServiceAccountKey diff --git a/backend/danswer/connectors/google_drive/connector.py b/backend/danswer/connectors/google_drive/connector.py index 15c9894a653..ea7ef60db70 100644 --- a/backend/danswer/connectors/google_drive/connector.py +++ b/backend/danswer/connectors/google_drive/connector.py @@ -388,7 +388,7 @@ def _process_folder_paths( def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None: """Checks for two different types of credentials. - (1) A credential which holds a token acquired via a user going thorugh + (1) A credential which holds a token acquired via a user going through the Google OAuth flow. (2) A credential which holds a service account key JSON file, which can then be used to impersonate any user in the workspace.
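
The Axero connector earlier in this diff feeds its fetched objects through the new `process_in_batches` helper from `danswer.connectors.cross_connector_utils.miscellaneous_utils`, which lazily maps a function over fixed-size slices of a list. A small usage sketch (the input data here is made up):

```python
from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
    process_in_batches,
)

raw_items = [{"id": 1}, {"id": 2}, {"id": 3}]

# Each yielded batch holds the processed results for one slice of the input.
for batch in process_in_batches(
    objects=raw_items,
    process_function=lambda item: item["id"],
    batch_size=2,
):
    print(batch)  # prints [1, 2], then [3]
```
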
diff --git a/backend/danswer/connectors/google_drive/connector_auth.py b/backend/danswer/connectors/google_drive/connector_auth.py index f65e177724b..65c34393c72 100644 --- a/backend/danswer/connectors/google_drive/connector_auth.py +++ b/backend/danswer/connectors/google_drive/connector_auth.py @@ -24,7 +24,7 @@ from danswer.connectors.google_drive.constants import SCOPES from danswer.db.credentials import update_credential_json from danswer.db.models import User -from danswer.dynamic_configs import get_dynamic_config_store +from danswer.dynamic_configs.factory import get_dynamic_config_store from danswer.server.documents.models import CredentialBase from danswer.server.documents.models import GoogleAppCredentials from danswer.server.documents.models import GoogleServiceAccountKey diff --git a/backend/danswer/connectors/google_site/connector.py b/backend/danswer/connectors/google_site/connector.py index 2a2be5ebe34..38d6e0b143a 100644 --- a/backend/danswer/connectors/google_site/connector.py +++ b/backend/danswer/connectors/google_site/connector.py @@ -5,6 +5,7 @@ from bs4 import BeautifulSoup from bs4 import Tag +from sqlalchemy.orm import Session from danswer.configs.app_configs import INDEX_BATCH_SIZE from danswer.configs.constants import DocumentSource @@ -15,6 +16,8 @@ from danswer.connectors.interfaces import LoadConnector from danswer.connectors.models import Document from danswer.connectors.models import Section +from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.file_store import get_default_file_store from danswer.utils.logger import setup_logger logger = setup_logger() @@ -66,8 +69,13 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None def load_from_state(self) -> GenerateDocumentsOutput: documents: list[Document] = [] + with Session(get_sqlalchemy_engine()) as db_session: + file_content_io = get_default_file_store(db_session).read_file( + self.zip_path, mode="b" + ) + # load the HTML files - files = load_files_from_zip(self.zip_path) + files = load_files_from_zip(file_content_io) count = 0 for file_info, file_io, _metadata in files: # skip non-published files diff --git a/backend/danswer/connectors/notion/connector.py b/backend/danswer/connectors/notion/connector.py index 28fb47a44d5..e0e307fc56a 100644 --- a/backend/danswer/connectors/notion/connector.py +++ b/backend/danswer/connectors/notion/connector.py @@ -93,7 +93,9 @@ def __init__( self.recursive_index_enabled = recursive_index_enabled or self.root_page_id @retry(tries=3, delay=1, backoff=2) - def _fetch_blocks(self, block_id: str, cursor: str | None = None) -> dict[str, Any]: + def _fetch_child_blocks( + self, block_id: str, cursor: str | None = None + ) -> dict[str, Any] | None: """Fetch all child blocks via the Notion API.""" logger.debug(f"Fetching children of block with ID '{block_id}'") block_url = f"https://api.notion.com/v1/blocks/{block_id}/children" @@ -107,6 +109,15 @@ def _fetch_blocks(self, block_id: str, cursor: str | None = None) -> dict[str, A try: res.raise_for_status() except Exception as e: + if res.status_code == 404: + # this happens when a page is not shared with the integration + # in this case, we should just ignore the page + logger.error( + f"Unable to access block with ID '{block_id}'. " + f"This is likely due to the block not being shared " + f"with the Danswer integration. 
Exact exception:\n\n{e}" + ) + return None logger.exception(f"Error fetching blocks - {res.json()}") raise e return res.json() @@ -187,24 +198,30 @@ def _read_pages_from_database(self, database_id: str) -> list[str]: return result_pages def _read_blocks( - self, page_block_id: str + self, base_block_id: str ) -> tuple[list[tuple[str, str]], list[str]]: - """Reads blocks for a page""" + """Reads all child blocks for the specified block""" result_lines: list[tuple[str, str]] = [] child_pages: list[str] = [] cursor = None while True: - data = self._fetch_blocks(page_block_id, cursor) + data = self._fetch_child_blocks(base_block_id, cursor) + + # this happens when a block is not shared with the integration + if data is None: + return result_lines, child_pages for result in data["results"]: - logger.debug(f"Found block for page '{page_block_id}': {result}") + logger.debug( + f"Found child block for block with ID '{base_block_id}': {result}" + ) result_block_id = result["id"] result_type = result["type"] result_obj = result[result_type] if result_type == "ai_block": logger.warning( - f"Skipping 'ai_block' ('{result_block_id}') for page '{page_block_id}': " + f"Skipping 'ai_block' ('{result_block_id}') for base block '{base_block_id}': " f"Notion API does not currently support reading AI blocks (as of 24/02/09) " f"(discussion: https://github.com/danswer-ai/danswer/issues/1053)" ) diff --git a/backend/danswer/connectors/web/connector.py b/backend/danswer/connectors/web/connector.py index 8acfaca4259..37b65f8da77 100644 --- a/backend/danswer/connectors/web/connector.py +++ b/backend/danswer/connectors/web/connector.py @@ -42,6 +42,14 @@ class WEB_CONNECTOR_VALID_SETTINGS(str, Enum): UPLOAD = "upload" +def check_internet_connection(url: str) -> None: + try: + response = requests.get(url, timeout=3) + response.raise_for_status() + except (requests.RequestException, ValueError): + raise Exception(f"Unable to reach {url} - check your internet connection") + + def is_valid_url(url: str) -> bool: try: result = urlparse(url) @@ -149,6 +157,10 @@ def __init__( self.to_visit_list = extract_urls_from_sitemap(_ensure_valid_url(base_url)) elif web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.UPLOAD: + logger.warning( + "This is not a UI supported Web Connector flow, " + "are you sure you want to do this?" 
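
Stepping back from this hunk: both the file connector and the Google Sites connector earlier in this diff now read uploads through the database-backed file store instead of the local filesystem. A minimal sketch of that read path, combining the calls shown in the diff (the file name here is invented):

```python
from sqlalchemy.orm import Session

from danswer.connectors.cross_connector_utils.file_utils import load_files_from_zip
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.file_store import get_default_file_store

with Session(get_sqlalchemy_engine()) as db_session:
    # read_file returns an IO handle for the stored upload
    zip_io = get_default_file_store(db_session).read_file(
        "my_site_export.zip", mode="b"
    )
    for file_info, file_io, metadata in load_files_from_zip(zip_io):
        print(file_info.filename, metadata)
```
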
+ ) self.to_visit_list = _read_urls_file(base_url) else: @@ -180,6 +192,7 @@ def load_from_state(self) -> GenerateDocumentsOutput: logger.info(f"Visiting {current_url}") try: + check_internet_connection(current_url) if restart_playwright: playwright, context = start_playwright() restart_playwright = False diff --git a/backend/danswer/danswerbot/slack/constants.py b/backend/danswer/danswerbot/slack/constants.py index a4930b593c3..1e524025fc7 100644 --- a/backend/danswer/danswerbot/slack/constants.py +++ b/backend/danswer/danswerbot/slack/constants.py @@ -1,3 +1,5 @@ +from enum import Enum + LIKE_BLOCK_ACTION_ID = "feedback-like" DISLIKE_BLOCK_ACTION_ID = "feedback-dislike" FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID = "feedback-doc-button" @@ -6,3 +8,9 @@ FOLLOWUP_BUTTON_RESOLVED_ACTION_ID = "followup-resolved-button" SLACK_CHANNEL_ID = "channel_id" VIEW_DOC_FEEDBACK_ID = "view-doc-feedback" + + +class FeedbackVisibility(str, Enum): + PRIVATE = "private" + ANONYMOUS = "anonymous" + PUBLIC = "public" diff --git a/backend/danswer/danswerbot/slack/handlers/handle_buttons.py b/backend/danswer/danswerbot/slack/handlers/handle_buttons.py index 0ca030612f3..bec1959e3cc 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_buttons.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_buttons.py @@ -15,6 +15,7 @@ from danswer.danswerbot.slack.blocks import get_document_feedback_blocks from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID +from danswer.danswerbot.slack.constants import FeedbackVisibility from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID from danswer.danswerbot.slack.constants import VIEW_DOC_FEEDBACK_ID from danswer.danswerbot.slack.utils import build_feedback_id @@ -22,6 +23,7 @@ from danswer.danswerbot.slack.utils import fetch_groupids_from_names from danswer.danswerbot.slack.utils import fetch_userids_from_emails from danswer.danswerbot.slack.utils import get_channel_name_from_id +from danswer.danswerbot.slack.utils import get_feedback_visibility from danswer.danswerbot.slack.utils import respond_in_thread from danswer.danswerbot.slack.utils import update_emote_react from danswer.db.engine import get_sqlalchemy_engine @@ -120,13 +122,33 @@ def handle_slack_feedback( else: logger_base.error(f"Feedback type '{feedback_type}' not supported") - # post message to slack confirming that feedback was received - client.chat_postEphemeral( - channel=channel_id_to_post_confirmation, - user=user_id_to_post_confirmation, - thread_ts=thread_ts_to_post_confirmation, - text="Thanks for your feedback!", - ) + if get_feedback_visibility() == FeedbackVisibility.PRIVATE or feedback_type not in [ + LIKE_BLOCK_ACTION_ID, + DISLIKE_BLOCK_ACTION_ID, + ]: + client.chat_postEphemeral( + channel=channel_id_to_post_confirmation, + user=user_id_to_post_confirmation, + thread_ts=thread_ts_to_post_confirmation, + text="Thanks for your feedback!", + ) + else: + feedback_response_txt = ( + "liked" if feedback_type == LIKE_BLOCK_ACTION_ID else "disliked" + ) + + if get_feedback_visibility() == FeedbackVisibility.ANONYMOUS: + msg = f"A user has {feedback_response_txt} the AI Answer" + else: + msg = f"<@{user_id_to_post_confirmation}> has {feedback_response_txt} the AI Answer" + + respond_in_thread( + client=client, + channel=channel_id_to_post_confirmation, + text=msg, + thread_ts=thread_ts_to_post_confirmation, + unfurl=False, + ) def handle_followup_button( diff --git 
a/backend/danswer/danswerbot/slack/handlers/handle_message.py b/backend/danswer/danswerbot/slack/handlers/handle_message.py index 6720ea86316..22ad323e54c 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_message.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_message.py @@ -12,7 +12,6 @@ from slack_sdk.models.blocks import DividerBlock from sqlalchemy.orm import Session -from danswer.chat.chat_utils import compute_max_document_tokens from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER @@ -39,6 +38,9 @@ from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import SlackBotConfig from danswer.db.models import SlackBotResponseType +from danswer.llm.answering.prompts.citations_prompt import ( + compute_max_document_tokens_for_persona, +) from danswer.llm.utils import check_number_of_tokens from danswer.llm.utils import get_default_llm_version from danswer.llm.utils import get_max_input_tokens @@ -49,6 +51,7 @@ from danswer.search.models import OptionalSearchSetting from danswer.search.models import RetrievalDetails from danswer.utils.logger import setup_logger +from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW logger_base = setup_logger() @@ -247,7 +250,7 @@ def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse: query_text = new_message_request.messages[0].message if persona: - max_document_tokens = compute_max_document_tokens( + max_document_tokens = compute_max_document_tokens_for_persona( persona=persona, actual_user_input=query_text, max_llm_token_override=remaining_tokens, @@ -308,6 +311,7 @@ def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse: persona_id=persona.id if persona is not None else 0, retrieval_options=retrieval_details, chain_of_thought=not disable_cot, + skip_rerank=not ENABLE_RERANKING_ASYNC_FLOW, ) ) except Exception as e: diff --git a/backend/danswer/danswerbot/slack/listener.py b/backend/danswer/danswerbot/slack/listener.py index 12db7dff957..460ecd32d1f 100644 --- a/backend/danswer/danswerbot/slack/listener.py +++ b/backend/danswer/danswerbot/slack/listener.py @@ -13,7 +13,6 @@ from danswer.configs.constants import MessageType from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER -from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID @@ -43,9 +42,11 @@ from danswer.db.engine import get_sqlalchemy_engine from danswer.dynamic_configs.interface import ConfigNotFoundError from danswer.one_shot_answer.models import ThreadMessage -from danswer.search.search_nlp_models import warm_up_models +from danswer.search.search_nlp_models import warm_up_encoders from danswer.server.manage.models import SlackBotTokens from danswer.utils.logger import setup_logger +from shared_configs.configs import MODEL_SERVER_HOST +from shared_configs.configs import MODEL_SERVER_PORT logger = setup_logger() @@ -390,10 +391,11 @@ def _initialize_socket_client(socket_client: SocketModeClient) -> None: with Session(get_sqlalchemy_engine()) as db_session: embedding_model = 
get_current_db_embedding_model(db_session) - warm_up_models( + warm_up_encoders( model_name=embedding_model.model_name, normalize=embedding_model.normalize, - skip_cross_encoders=not ENABLE_RERANKING_ASYNC_FLOW, + model_server_host=MODEL_SERVER_HOST, + model_server_port=MODEL_SERVER_PORT, ) slack_bot_tokens = latest_slack_bot_tokens diff --git a/backend/danswer/danswerbot/slack/tokens.py b/backend/danswer/danswerbot/slack/tokens.py index c9c12862820..34d2b79a303 100644 --- a/backend/danswer/danswerbot/slack/tokens.py +++ b/backend/danswer/danswerbot/slack/tokens.py @@ -1,7 +1,7 @@ import os from typing import cast -from danswer.dynamic_configs import get_dynamic_config_store +from danswer.dynamic_configs.factory import get_dynamic_config_store from danswer.server.manage.models import SlackBotTokens diff --git a/backend/danswer/danswerbot/slack/utils.py b/backend/danswer/danswerbot/slack/utils.py index 5d761dec0ee..5895dc52f91 100644 --- a/backend/danswer/danswerbot/slack/utils.py +++ b/backend/danswer/danswerbot/slack/utils.py @@ -18,11 +18,13 @@ from danswer.configs.app_configs import DISABLE_TELEMETRY from danswer.configs.constants import ID_SEPARATOR from danswer.configs.constants import MessageType +from danswer.configs.danswerbot_configs import DANSWER_BOT_FEEDBACK_VISIBILITY from danswer.configs.danswerbot_configs import DANSWER_BOT_MAX_QPM from danswer.configs.danswerbot_configs import DANSWER_BOT_MAX_WAIT_TIME from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES from danswer.connectors.slack.utils import make_slack_api_rate_limited from danswer.connectors.slack.utils import SlackTextCleaner +from danswer.danswerbot.slack.constants import FeedbackVisibility from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID from danswer.danswerbot.slack.tokens import fetch_tokens from danswer.db.engine import get_sqlalchemy_engine @@ -449,3 +451,10 @@ def waiter(self, func_randid: int) -> None: self.refill() del self.waiting_questions[0] + + +def get_feedback_visibility() -> FeedbackVisibility: + try: + return FeedbackVisibility(DANSWER_BOT_FEEDBACK_VISIBILITY.lower()) + except ValueError: + return FeedbackVisibility.PRIVATE diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index cc080031976..738d02a1657 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -18,13 +18,19 @@ from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import ChatMessage from danswer.db.models import ChatSession +from danswer.db.models import ChatSessionSharedStatus from danswer.db.models import DocumentSet as DBDocumentSet from danswer.db.models import Persona +from danswer.db.models import Persona__User +from danswer.db.models import Persona__UserGroup from danswer.db.models import Prompt from danswer.db.models import SearchDoc from danswer.db.models import SearchDoc as DBSearchDoc from danswer.db.models import StarterMessage -from danswer.search.models import RecencyBiasSetting +from danswer.db.models import User__UserGroup +from danswer.llm.override_models import LLMOverride +from danswer.llm.override_models import PromptOverride +from danswer.search.enums import RecencyBiasSetting from danswer.search.models import RetrievalDocs from danswer.search.models import SavedSearchDoc from danswer.search.models import SearchDoc as ServerSearchDoc @@ -35,11 +41,23 @@ def get_chat_session_by_id( - chat_session_id: int, user_id: UUID | None, db_session: Session + chat_session_id: int, + user_id: UUID | None, + db_session: Session, + 
include_deleted: bool = False, + is_shared: bool = False, ) -> ChatSession: - stmt = select(ChatSession).where( - ChatSession.id == chat_session_id, ChatSession.user_id == user_id - ) + stmt = select(ChatSession).where(ChatSession.id == chat_session_id) + + if is_shared: + stmt = stmt.where(ChatSession.shared_status == ChatSessionSharedStatus.PUBLIC) + else: + # if user_id is None, assume this is an admin who should be able + # to view all chat sessions + if user_id is not None: + stmt = stmt.where( + or_(ChatSession.user_id == user_id, ChatSession.user_id.is_(None)) + ) result = db_session.execute(stmt) chat_session = result.scalar_one_or_none() @@ -47,7 +65,7 @@ def get_chat_session_by_id( if not chat_session: raise ValueError("Invalid Chat Session ID provided") - if chat_session.deleted: + if not include_deleted and chat_session.deleted: raise ValueError("Chat session has been deleted") return chat_session @@ -78,12 +96,16 @@ def create_chat_session( description: str, user_id: UUID | None, persona_id: int | None = None, + llm_override: LLMOverride | None = None, + prompt_override: PromptOverride | None = None, one_shot: bool = False, ) -> ChatSession: chat_session = ChatSession( user_id=user_id, persona_id=persona_id, description=description, + llm_override=llm_override, + prompt_override=prompt_override, one_shot=one_shot, ) @@ -94,7 +116,11 @@ def create_chat_session( def update_chat_session( - user_id: UUID | None, chat_session_id: int, description: str, db_session: Session + db_session: Session, + user_id: UUID | None, + chat_session_id: int, + description: str | None = None, + sharing_status: ChatSessionSharedStatus | None = None, ) -> ChatSession: chat_session = get_chat_session_by_id( chat_session_id=chat_session_id, user_id=user_id, db_session=db_session @@ -103,7 +129,10 @@ def update_chat_session( if chat_session.deleted: raise ValueError("Trying to rename a deleted chat session") - chat_session.description = description + if description is not None: + chat_session.description = description + if sharing_status is not None: + chat_session.shared_status = sharing_status db_session.commit() @@ -468,6 +497,7 @@ def upsert_persona( llm_model_version_override: str | None, starter_messages: list[StarterMessage] | None, shared: bool, + is_public: bool, db_session: Session, persona_id: int | None = None, default_persona: bool = False, @@ -494,6 +524,7 @@ def upsert_persona( persona.llm_model_version_override = llm_model_version_override persona.starter_messages = starter_messages persona.deleted = False # Un-delete if previously deleted + persona.is_public = is_public # Do not delete any associations manually added unless # a new updated list is provided @@ -509,6 +540,7 @@ def upsert_persona( persona = Persona( id=persona_id, user_id=None if shared else user_id, + is_public=is_public, name=name, description=description, num_chunks=num_chunks, @@ -638,9 +670,28 @@ def get_personas( include_slack_bot_personas: bool = False, include_deleted: bool = False, ) -> Sequence[Persona]: - stmt = select(Persona) + stmt = select(Persona).distinct() if user_id is not None: - stmt = stmt.where(or_(Persona.user_id == user_id, Persona.user_id.is_(None))) + # Subquery to find all groups the user belongs to + user_groups_subquery = ( + select(User__UserGroup.user_group_id) + .where(User__UserGroup.user_id == user_id) + .subquery() + ) + + # Include personas where the user is directly related or part of a user group that has access + access_conditions = or_( + Persona.is_public == True, # noqa: E712 + 
Persona.id.in_( # User has access through list of users with access + select(Persona__User.persona_id).where(Persona__User.user_id == user_id) + ), + Persona.id.in_( # User is part of a group that has access + select(Persona__UserGroup.persona_id).where( + Persona__UserGroup.user_group_id.in_(user_groups_subquery) # type: ignore + ) + ), + ) + stmt = stmt.where(access_conditions) if not include_default: stmt = stmt.where(Persona.default_persona.is_(False)) @@ -693,7 +744,8 @@ def create_db_search_doc( boost=server_search_doc.boost, hidden=server_search_doc.hidden, doc_metadata=server_search_doc.metadata, - score=server_search_doc.score, + # For docs further down that aren't reranked, we can't use the retrieval score + score=server_search_doc.score or 0.0, match_highlights=server_search_doc.match_highlights, updated_at=server_search_doc.updated_at, primary_owners=server_search_doc.primary_owners, @@ -714,6 +766,7 @@ def get_db_search_doc_by_id(doc_id: int, db_session: Session) -> DBSearchDoc | N def translate_db_search_doc_to_server_search_doc( db_search_doc: SearchDoc, + remove_doc_content: bool = False, ) -> SavedSearchDoc: return SavedSearchDoc( db_doc_id=db_search_doc.id, @@ -721,22 +774,30 @@ def translate_db_search_doc_to_server_search_doc( chunk_ind=db_search_doc.chunk_ind, semantic_identifier=db_search_doc.semantic_id, link=db_search_doc.link, - blurb=db_search_doc.blurb, + blurb=db_search_doc.blurb if not remove_doc_content else "", source_type=db_search_doc.source_type, boost=db_search_doc.boost, hidden=db_search_doc.hidden, - metadata=db_search_doc.doc_metadata, + metadata=db_search_doc.doc_metadata if not remove_doc_content else {}, score=db_search_doc.score, - match_highlights=db_search_doc.match_highlights, - updated_at=db_search_doc.updated_at, - primary_owners=db_search_doc.primary_owners, - secondary_owners=db_search_doc.secondary_owners, + match_highlights=db_search_doc.match_highlights + if not remove_doc_content + else [], + updated_at=db_search_doc.updated_at if not remove_doc_content else None, + primary_owners=db_search_doc.primary_owners if not remove_doc_content else [], + secondary_owners=db_search_doc.secondary_owners + if not remove_doc_content + else [], ) -def get_retrieval_docs_from_chat_message(chat_message: ChatMessage) -> RetrievalDocs: +def get_retrieval_docs_from_chat_message( + chat_message: ChatMessage, remove_doc_content: bool = False +) -> RetrievalDocs: top_documents = [ - translate_db_search_doc_to_server_search_doc(db_doc) + translate_db_search_doc_to_server_search_doc( + db_doc, remove_doc_content=remove_doc_content + ) for db_doc in chat_message.search_docs ] top_documents = sorted(top_documents, key=lambda doc: doc.score, reverse=True) # type: ignore @@ -744,7 +805,7 @@ def get_retrieval_docs_from_chat_message(chat_message: ChatMessage) -> Retrieval def translate_db_message_to_chat_message_detail( - chat_message: ChatMessage, + chat_message: ChatMessage, remove_doc_content: bool = False ) -> ChatMessageDetail: chat_msg_detail = ChatMessageDetail( message_id=chat_message.id, @@ -752,7 +813,9 @@ def translate_db_message_to_chat_message_detail( latest_child_message=chat_message.latest_child_message, message=chat_message.message, rephrased_query=chat_message.rephrased_query, - context_docs=get_retrieval_docs_from_chat_message(chat_message), + context_docs=get_retrieval_docs_from_chat_message( + chat_message, remove_doc_content=remove_doc_content + ), message_type=chat_message.message_type, time_sent=chat_message.time_sent, 
citations=chat_message.citations, diff --git a/backend/danswer/db/document_set.py b/backend/danswer/db/document_set.py index 848f5088377..c3bab1e741a 100644 --- a/backend/danswer/db/document_set.py +++ b/backend/danswer/db/document_set.py @@ -16,6 +16,7 @@ from danswer.db.models import DocumentSet__ConnectorCredentialPair from danswer.server.features.document_set.models import DocumentSetCreationRequest from danswer.server.features.document_set.models import DocumentSetUpdateRequest +from danswer.utils.variable_functionality import fetch_versioned_implementation def _delete_document_set_cc_pairs__no_commit( @@ -41,6 +42,12 @@ def _mark_document_set_cc_pairs_as_outdated__no_commit( row.is_current = False +def delete_document_set_privacy__no_commit( + document_set_id: int, db_session: Session +) -> None: + """No private document sets in Danswer MIT""" + + def get_document_set_by_id( db_session: Session, document_set_id: int ) -> DocumentSetDBModel | None: @@ -67,6 +74,17 @@ def get_document_sets_by_ids( ).all() +def make_doc_set_private( + document_set_id: int, + user_ids: list[UUID] | None, + group_ids: list[int] | None, + db_session: Session, +) -> None: + # May cause error if someone switches down to MIT from EE + if user_ids or group_ids: + raise NotImplementedError("Danswer MIT does not support private Document Sets") + + def insert_document_set( document_set_creation_request: DocumentSetCreationRequest, user_id: UUID | None, @@ -83,6 +101,7 @@ def insert_document_set( name=document_set_creation_request.name, description=document_set_creation_request.description, user_id=user_id, + is_public=document_set_creation_request.is_public, ) db_session.add(new_document_set_row) db_session.flush() # ensure the new document set gets assigned an ID @@ -96,6 +115,19 @@ def insert_document_set( for cc_pair_id in document_set_creation_request.cc_pair_ids ] db_session.add_all(ds_cc_pairs) + + versioned_private_doc_set_fn = fetch_versioned_implementation( + "danswer.db.document_set", "make_doc_set_private" + ) + + # Private Document Sets + versioned_private_doc_set_fn( + document_set_id=new_document_set_row.id, + user_ids=document_set_creation_request.users, + group_ids=document_set_creation_request.groups, + db_session=db_session, + ) + db_session.commit() except: db_session.rollback() @@ -130,6 +162,19 @@ def update_document_set( document_set_row.description = document_set_update_request.description document_set_row.is_up_to_date = False + document_set_row.is_public = document_set_update_request.is_public + + versioned_private_doc_set_fn = fetch_versioned_implementation( + "danswer.db.document_set", "make_doc_set_private" + ) + + # Private Document Sets + versioned_private_doc_set_fn( + document_set_id=document_set_row.id, + user_ids=document_set_update_request.users, + group_ids=document_set_update_request.groups, + db_session=db_session, + ) # update the attached CC pairs # first, mark all existing CC pairs as not current @@ -205,6 +250,15 @@ def mark_document_set_as_to_be_deleted( _delete_document_set_cc_pairs__no_commit( db_session=db_session, document_set_id=document_set_id ) + + # delete all private document set information + versioned_delete_private_fn = fetch_versioned_implementation( + "danswer.db.document_set", "delete_document_set_privacy__no_commit" + ) + versioned_delete_private_fn( + document_set_id=document_set_id, db_session=db_session + ) + # mark the row as needing a sync, it will be deleted there since there # are no more relationships to cc pairs document_set_row.is_up_to_date = 
False
@@ -248,7 +302,7 @@ def mark_cc_pair__document_set_relationships_to_be_deleted__no_commit(


 def fetch_document_sets(
-    db_session: Session, include_outdated: bool = False
+    user_id: UUID | None, db_session: Session, include_outdated: bool = False
 ) -> list[tuple[DocumentSetDBModel, list[ConnectorCredentialPair]]]:
     """Return is a list where each element contains a tuple of:
     1. The document set itself
@@ -301,6 +355,31 @@ def fetch_document_sets(
     ]


+def fetch_all_document_sets(db_session: Session) -> Sequence[DocumentSetDBModel]:
+    """Used for Admin UI where they should have visibility into all document sets"""
+    return db_session.scalars(select(DocumentSetDBModel)).all()
+
+
+def fetch_user_document_sets(
+    user_id: UUID | None, db_session: Session
+) -> list[tuple[DocumentSetDBModel, list[ConnectorCredentialPair]]]:
+    # If Auth is turned off, all document sets become visible.
+    # Document sets are not permission enforced, they exist only for organizational purposes;
+    # the documents themselves are permission enforced.
+    if user_id is None:
+        return fetch_document_sets(
+            user_id=user_id, db_session=db_session, include_outdated=True
+        )
+
+    versioned_fetch_doc_sets_fn = fetch_versioned_implementation(
+        "danswer.db.document_set", "fetch_document_sets"
+    )
+
+    return versioned_fetch_doc_sets_fn(
+        user_id=user_id, db_session=db_session, include_outdated=True
+    )
+
+
 def fetch_documents_for_document_set(
     document_set_id: int, db_session: Session, current_only: bool = True
 ) -> Sequence[Document]:
@@ -404,6 +483,8 @@ def check_document_sets_are_public(
     db_session: Session,
     document_set_ids: list[int],
 ) -> bool:
+    """Checks if any of the CC-Pairs are non-public (meaning that some documents in this
+    document set are not public)."""
     connector_credential_pair_ids = (
         db_session.query(
             DocumentSet__ConnectorCredentialPair.connector_credential_pair_id
diff --git a/backend/danswer/db/engine.py b/backend/danswer/db/engine.py
index 146f11ef81b..1be57179c70 100644
--- a/backend/danswer/db/engine.py
+++ b/backend/danswer/db/engine.py
@@ -1,8 +1,9 @@
+import contextlib
 from collections.abc import AsyncGenerator
 from collections.abc import Generator
 from datetime import datetime
+from typing import ContextManager

-from ddtrace import tracer
 from sqlalchemy import text
 from sqlalchemy.engine import create_engine
 from sqlalchemy.engine import Engine
@@ -10,6 +11,7 @@
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.ext.asyncio import create_async_engine
 from sqlalchemy.orm import Session
+from sqlalchemy.orm import sessionmaker

 from danswer.configs.app_configs import POSTGRES_DB
 from danswer.configs.app_configs import POSTGRES_HOST
@@ -69,10 +71,16 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
     return _ASYNC_ENGINE


+def get_session_context_manager() -> ContextManager:
+    return contextlib.contextmanager(get_session)()
+
+
 def get_session() -> Generator[Session, None, None]:
-    with tracer.trace("db.get_session"):
-        with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session:
-            yield session
+    # The line below can be uncommented to monitor the latency caused by Postgres
+    # connections during API calls.
+    # with tracer.trace("db.get_session"):
+    with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session:
+        yield session


 async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
@@ -80,3 +88,6 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
         get_sqlalchemy_async_engine(), expire_on_commit=False
     ) as async_session:
         yield async_session
+
+
+SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
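# ---------------------------------------------------------------------------
# Editor's note (usage sketch, not part of the diff): engine.py above adds
# `get_session_context_manager()`, which wraps the `get_session()` generator for
# use outside FastAPI dependency injection, plus a module-level `SessionFactory`.
# A minimal sketch for a script or background job; the SELECT is a placeholder:

from sqlalchemy import text

from danswer.db.engine import get_session_context_manager

with get_session_context_manager() as db_session:
    db_session.execute(text("SELECT 1"))  # any session-scoped ORM/SQL work
# ---------------------------------------------------------------------------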
diff --git a/backend/danswer/db/enums.py b/backend/danswer/db/enums.py
new file mode 100644
index 00000000000..2a02e078c60
--- /dev/null
+++ b/backend/danswer/db/enums.py
@@ -0,0 +1,35 @@
+from enum import Enum as PyEnum
+
+
+class IndexingStatus(str, PyEnum):
+    NOT_STARTED = "not_started"
+    IN_PROGRESS = "in_progress"
+    SUCCESS = "success"
+    FAILED = "failed"
+
+
+# these may differ in the future, which is why we're okay with this duplication
+class DeletionStatus(str, PyEnum):
+    NOT_STARTED = "not_started"
+    IN_PROGRESS = "in_progress"
+    SUCCESS = "success"
+    FAILED = "failed"
+
+
+# Consistent with Celery task statuses
+class TaskStatus(str, PyEnum):
+    PENDING = "PENDING"
+    STARTED = "STARTED"
+    SUCCESS = "SUCCESS"
+    FAILURE = "FAILURE"
+
+
+class IndexModelStatus(str, PyEnum):
+    PAST = "PAST"
+    PRESENT = "PRESENT"
+    FUTURE = "FUTURE"
+
+
+class ChatSessionSharedStatus(str, PyEnum):
+    PUBLIC = "public"
+    PRIVATE = "private"
diff --git a/backend/danswer/db/file_store.py b/backend/danswer/db/file_store.py
new file mode 100644
index 00000000000..f0a44bf5da6
--- /dev/null
+++ b/backend/danswer/db/file_store.py
@@ -0,0 +1,96 @@
+from abc import ABC
+from abc import abstractmethod
+from typing import IO
+
+from sqlalchemy.orm import Session
+
+from danswer.db.pg_file_store import create_populate_lobj
+from danswer.db.pg_file_store import delete_lobj_by_id
+from danswer.db.pg_file_store import delete_pgfilestore_by_file_name
+from danswer.db.pg_file_store import get_pgfilestore_by_file_name
+from danswer.db.pg_file_store import read_lobj
+from danswer.db.pg_file_store import upsert_pgfilestore
+
+
+class FileStore(ABC):
+    """
+    An abstraction for storing files and large binary objects.
+    """
+
+    @abstractmethod
+    def save_file(self, file_name: str, content: IO) -> None:
+        """
+        Save a file to the blob store
+
+        Parameters:
+        - file_name: Name of the file to save
+        - content: Contents of the file
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def read_file(self, file_name: str, mode: str | None) -> IO:
+        """
+        Read the content of a file by its name
+
+        Parameters:
+        - file_name: Name of file to read
+
+        Returns:
+        Contents of the file
+        """
+
+    @abstractmethod
+    def delete_file(self, file_name: str) -> None:
+        """
+        Delete a file by its name.
+
+        Parameters:
+        - file_name: Name of file to delete
+        """
+
+
+class PostgresBackedFileStore(FileStore):
+    def __init__(self, db_session: Session):
+        self.db_session = db_session
+
+    def save_file(self, file_name: str, content: IO) -> None:
+        try:
+            # The large objects in postgres are saved as special objects that can be listed with
+            # SELECT * FROM pg_largeobject_metadata;
+            obj_id = create_populate_lobj(content=content, db_session=self.db_session)
+            upsert_pgfilestore(
+                file_name=file_name, lobj_oid=obj_id, db_session=self.db_session
+            )
+            self.db_session.commit()
+        except Exception:
+            self.db_session.rollback()
+            raise
+
+    def read_file(self, file_name: str, mode: str | None = None) -> IO:
+        file_record = get_pgfilestore_by_file_name(
+            file_name=file_name, db_session=self.db_session
+        )
+        return read_lobj(
+            lobj_oid=file_record.lobj_oid, db_session=self.db_session, mode=mode
+        )
+
+    def delete_file(self, file_name: str) -> None:
+        try:
+            file_record = get_pgfilestore_by_file_name(
+                file_name=file_name, db_session=self.db_session
+            )
+            delete_lobj_by_id(file_record.lobj_oid, db_session=self.db_session)
+            delete_pgfilestore_by_file_name(
+                file_name=file_name, db_session=self.db_session
+            )
+            self.db_session.commit()
+        except Exception:
+            self.db_session.rollback()
+            raise
+
+
+def get_default_file_store(db_session: Session) -> FileStore:
+    # The only supported file store now is the Postgres File Store
+    return PostgresBackedFileStore(db_session=db_session)
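# ---------------------------------------------------------------------------
# Editor's note (usage sketch, not part of the diff): the FileStore API defined
# above, exercised end to end. The file name and bytes are illustrative only:

from io import BytesIO

from danswer.db.engine import get_session_context_manager
from danswer.db.file_store import get_default_file_store

with get_session_context_manager() as db_session:
    store = get_default_file_store(db_session)
    store.save_file(file_name="example.txt", content=BytesIO(b"hello"))
    data = store.read_file(file_name="example.txt", mode="rb").read()  # b"hello"
    store.delete_file(file_name="example.txt")
# ---------------------------------------------------------------------------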
diff --git a/backend/danswer/db/index_attempt.py b/backend/danswer/db/index_attempt.py
index ce913098eb3..4580140a5f1 100644
--- a/backend/danswer/db/index_attempt.py
+++ b/backend/danswer/db/index_attempt.py
@@ -291,7 +291,7 @@ def cancel_indexing_attempts_past_model(
     db_session.commit()


-def count_unique_cc_pairs_with_index_attempts(
+def count_unique_cc_pairs_with_successful_index_attempts(
     embedding_model_id: int | None,
     db_session: Session,
 ) -> int:
@@ -299,12 +299,7 @@
         db_session.query(IndexAttempt.connector_id, IndexAttempt.credential_id)
         .filter(
             IndexAttempt.embedding_model_id == embedding_model_id,
-            # Should not be able to hang since indexing jobs expire after a limit
-            # It will then be marked failed, and the next cycle it will be in a completed state
-            or_(
-                IndexAttempt.status == IndexingStatus.SUCCESS,
-                IndexAttempt.status == IndexingStatus.FAILED,
-            ),
+            IndexAttempt.status == IndexingStatus.SUCCESS,
         )
         .distinct()
         .count()
diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py
index 5ca3bdbe94a..7fb6bbaa774 100644
--- a/backend/danswer/db/models.py
+++ b/backend/danswer/db/models.py
@@ -35,37 +35,16 @@
 from danswer.configs.constants import MessageType
 from danswer.configs.constants import SearchFeedbackType
 from danswer.connectors.models import InputType
-from danswer.search.models import RecencyBiasSetting
-from danswer.search.models import SearchType
-
-
-class IndexingStatus(str, PyEnum):
-    NOT_STARTED = "not_started"
-    IN_PROGRESS = "in_progress"
-    SUCCESS = "success"
-    FAILED = "failed"
-
-
-# these may differ in the future, which is why we're okay with this duplication
-class DeletionStatus(str, PyEnum):
-    NOT_STARTED = "not_started"
-    IN_PROGRESS = "in_progress"
-    SUCCESS = "success"
-    FAILED = "failed"
-
-
-# Consistent with Celery task statuses
-class TaskStatus(str, PyEnum):
-    PENDING = "PENDING"
-    STARTED = "STARTED"
-    SUCCESS = "SUCCESS"
-    FAILURE = "FAILURE"
-
-
-class IndexModelStatus(str, PyEnum):
-    PAST = "PAST"
-    PRESENT = 
"PRESENT" - FUTURE = "FUTURE" +from danswer.db.enums import ChatSessionSharedStatus +from danswer.db.enums import IndexingStatus +from danswer.db.enums import IndexModelStatus +from danswer.db.enums import TaskStatus +from danswer.db.pydantic_type import PydanticType +from danswer.dynamic_configs.interface import JSON_ro +from danswer.llm.override_models import LLMOverride +from danswer.llm.override_models import PromptOverride +from danswer.search.enums import RecencyBiasSetting +from danswer.search.enums import SearchType class Base(DeclarativeBase): @@ -96,6 +75,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base): "ChatSession", back_populates="user" ) prompts: Mapped[List["Prompt"]] = relationship("Prompt", back_populates="user") + # Personas owned by this user personas: Mapped[List["Persona"]] = relationship("Persona", back_populates="user") @@ -140,6 +120,22 @@ class Persona__Prompt(Base): prompt_id: Mapped[int] = mapped_column(ForeignKey("prompt.id"), primary_key=True) +class Persona__User(Base): + __tablename__ = "persona__user" + + persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"), primary_key=True) + user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True) + + +class DocumentSet__User(Base): + __tablename__ = "document_set__user" + + document_set_id: Mapped[int] = mapped_column( + ForeignKey("document_set.id"), primary_key=True + ) + user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True) + + class DocumentSet__ConnectorCredentialPair(Base): __tablename__ = "document_set__connector_credential_pair" @@ -224,7 +220,7 @@ class ConnectorCredentialPair(Base): DateTime(timezone=True), default=None ) last_attempt_status: Mapped[IndexingStatus | None] = mapped_column( - Enum(IndexingStatus) + Enum(IndexingStatus, native_enum=False) ) total_docs_indexed: Mapped[int] = mapped_column(Integer, default=0) @@ -291,7 +287,9 @@ class Tag(Base): id: Mapped[int] = mapped_column(primary_key=True) tag_key: Mapped[str] = mapped_column(String) tag_value: Mapped[str] = mapped_column(String) - source: Mapped[DocumentSource] = mapped_column(Enum(DocumentSource)) + source: Mapped[DocumentSource] = mapped_column( + Enum(DocumentSource, native_enum=False) + ) documents = relationship( "Document", @@ -378,7 +376,9 @@ class EmbeddingModel(Base): normalize: Mapped[bool] = mapped_column(Boolean) query_prefix: Mapped[str] = mapped_column(String) passage_prefix: Mapped[str] = mapped_column(String) - status: Mapped[IndexModelStatus] = mapped_column(Enum(IndexModelStatus)) + status: Mapped[IndexModelStatus] = mapped_column( + Enum(IndexModelStatus, native_enum=False) + ) index_name: Mapped[str] = mapped_column(String) index_attempts: Mapped[List["IndexAttempt"]] = relationship( @@ -423,7 +423,9 @@ class IndexAttempt(Base): # This is only for attempts that are explicitly marked as from the start via # the run once API from_beginning: Mapped[bool] = mapped_column(Boolean) - status: Mapped[IndexingStatus] = mapped_column(Enum(IndexingStatus)) + status: Mapped[IndexingStatus] = mapped_column( + Enum(IndexingStatus, native_enum=False) + ) # The two below may be slightly out of sync if user switches Embedding Model new_docs_indexed: Mapped[int | None] = mapped_column(Integer, default=0) total_docs_indexed: Mapped[int | None] = mapped_column(Integer, default=0) @@ -526,7 +528,9 @@ class SearchDoc(Base): link: Mapped[str | None] = mapped_column(String, nullable=True) blurb: Mapped[str] = mapped_column(String) boost: Mapped[int] = mapped_column(Integer) - 
source_type: Mapped[DocumentSource] = mapped_column(Enum(DocumentSource)) + source_type: Mapped[DocumentSource] = mapped_column( + Enum(DocumentSource, native_enum=False) + ) hidden: Mapped[bool] = mapped_column(Boolean) doc_metadata: Mapped[dict[str, str | list[str]]] = mapped_column(postgresql.JSONB()) score: Mapped[float] = mapped_column(Float) @@ -560,6 +564,25 @@ class ChatSession(Base): one_shot: Mapped[bool] = mapped_column(Boolean, default=False) # Only ever set to True if system is set to not hard-delete chats deleted: Mapped[bool] = mapped_column(Boolean, default=False) + # controls whether or not this conversation is viewable by others + shared_status: Mapped[ChatSessionSharedStatus] = mapped_column( + Enum(ChatSessionSharedStatus, native_enum=False), + default=ChatSessionSharedStatus.PRIVATE, + ) + + # the latest "overrides" specified by the user. These take precedence over + # the attached persona. However, overrides specified directly in the + # `send-message` call will take precedence over these. + # NOTE: currently only used by the chat seeding flow, will be used in the + # future once we allow users to override default values via the Chat UI + # itself + llm_override: Mapped[LLMOverride | None] = mapped_column( + PydanticType(LLMOverride), nullable=True + ) + prompt_override: Mapped[PromptOverride | None] = mapped_column( + PydanticType(PromptOverride), nullable=True + ) + time_updated: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), @@ -599,7 +622,9 @@ class ChatMessage(Base): # If prompt is None, then token_count is 0 as this message won't be passed into # the LLM's context (not included in the history of messages) token_count: Mapped[int] = mapped_column(Integer) - message_type: Mapped[MessageType] = mapped_column(Enum(MessageType)) + message_type: Mapped[MessageType] = mapped_column( + Enum(MessageType, native_enum=False) + ) # Maps the citation numbers to a SearchDoc id citations: Mapped[dict[int, int]] = mapped_column(postgresql.JSONB(), nullable=True) # Only applies for LLM @@ -616,7 +641,7 @@ class ChatMessage(Base): document_feedbacks: Mapped[List["DocumentRetrievalFeedback"]] = relationship( "DocumentRetrievalFeedback", back_populates="chat_message" ) - search_docs = relationship( + search_docs: Mapped[list["SearchDoc"]] = relationship( "SearchDoc", secondary="chat_message__search_doc", back_populates="chat_messages", @@ -638,7 +663,7 @@ class DocumentRetrievalFeedback(Base): document_rank: Mapped[int] = mapped_column(Integer) clicked: Mapped[bool] = mapped_column(Boolean, default=False) feedback: Mapped[SearchFeedbackType | None] = mapped_column( - Enum(SearchFeedbackType), nullable=True + Enum(SearchFeedbackType, native_enum=False), nullable=True ) chat_message: Mapped[ChatMessage] = relationship( @@ -677,6 +702,9 @@ class DocumentSet(Base): user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True) # Whether changes to the document set have been propagated is_up_to_date: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + # If `False`, then the document set is not visible to users who are not explicitly + # given access to it either via the `users` or `groups` relationships + is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) connector_credential_pairs: Mapped[list[ConnectorCredentialPair]] = relationship( "ConnectorCredentialPair", @@ -689,6 +717,18 @@ class DocumentSet(Base): secondary=Persona__DocumentSet.__table__, 
back_populates="document_sets", ) + # Other users with access + users: Mapped[list[User]] = relationship( + "User", + secondary=DocumentSet__User.__table__, + viewonly=True, + ) + # EE only + groups: Mapped[list["UserGroup"]] = relationship( + "UserGroup", + secondary="document_set__user_group", + viewonly=True, + ) class Prompt(Base): @@ -735,7 +775,7 @@ class Persona(Base): description: Mapped[str] = mapped_column(String) # Currently stored but unused, all flows use hybrid search_type: Mapped[SearchType] = mapped_column( - Enum(SearchType), default=SearchType.HYBRID + Enum(SearchType, native_enum=False), default=SearchType.HYBRID ) # Number of chunks to pass to the LLM for generation. num_chunks: Mapped[float | None] = mapped_column(Float, nullable=True) @@ -745,7 +785,9 @@ class Persona(Base): # Enables using LLM to extract time and source type filters # Can also be admin disabled globally llm_filter_extraction: Mapped[bool] = mapped_column(Boolean) - recency_bias: Mapped[RecencyBiasSetting] = mapped_column(Enum(RecencyBiasSetting)) + recency_bias: Mapped[RecencyBiasSetting] = mapped_column( + Enum(RecencyBiasSetting, native_enum=False) + ) # Allows the Persona to specify a different LLM version than is controlled # globablly via env variables. For flexibility, validity is not currently enforced # NOTE: only is applied on the actual response generation - is not used for things like @@ -766,6 +808,7 @@ class Persona(Base): # where lower value IDs (e.g. created earlier) are displayed first display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=None) deleted: Mapped[bool] = mapped_column(Boolean, default=False) + is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) # These are only defaults, users can select from all if desired prompts: Mapped[list[Prompt]] = relationship( @@ -779,7 +822,20 @@ class Persona(Base): secondary=Persona__DocumentSet.__table__, back_populates="personas", ) + # Owner user: Mapped[User] = relationship("User", back_populates="personas") + # Other users with access + users: Mapped[list[User]] = relationship( + "User", + secondary=Persona__User.__table__, + viewonly=True, + ) + # EE only + groups: Mapped[list["UserGroup"]] = relationship( + "UserGroup", + secondary="persona__user_group", + viewonly=True, + ) # Default personas loaded via yaml cannot have the same name __table_args__ = ( @@ -844,10 +900,143 @@ class TaskQueueState(Base): # For any job type, this would be the same task_name: Mapped[str] = mapped_column(String) # Note that if the task dies, this won't necessarily be marked FAILED correctly - status: Mapped[TaskStatus] = mapped_column(Enum(TaskStatus)) + status: Mapped[TaskStatus] = mapped_column(Enum(TaskStatus, native_enum=False)) start_time: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True) ) register_time: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now() ) + + +class KVStore(Base): + __tablename__ = "key_value_store" + + key: Mapped[str] = mapped_column(String, primary_key=True) + value: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=False) + + +class PGFileStore(Base): + __tablename__ = "file_store" + file_name = mapped_column(String, primary_key=True) + lobj_oid = mapped_column(Integer, nullable=False) + + +""" +************************************************************************ +Enterprise Edition Models +************************************************************************ + +These models are only used in 
Enterprise Edition features of Danswer.
+They are kept here to simplify the codebase and avoid having different assumptions
+on the shape of data being passed around between the MIT and EE versions of Danswer.
+
+In the MIT version of Danswer, assume these tables are always empty.
+"""
+
+
+class SamlAccount(Base):
+    __tablename__ = "saml"
+
+    id: Mapped[int] = mapped_column(primary_key=True)
+    user_id: Mapped[int] = mapped_column(ForeignKey("user.id"), unique=True)
+    encrypted_cookie: Mapped[str] = mapped_column(Text, unique=True)
+    expires_at: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
+    updated_at: Mapped[datetime.datetime] = mapped_column(
+        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
+    )
+
+    user: Mapped[User] = relationship("User")
+
+
+class User__UserGroup(Base):
+    __tablename__ = "user__user_group"
+
+    user_group_id: Mapped[int] = mapped_column(
+        ForeignKey("user_group.id"), primary_key=True
+    )
+    user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True)
+
+
+class UserGroup__ConnectorCredentialPair(Base):
+    __tablename__ = "user_group__connector_credential_pair"
+
+    user_group_id: Mapped[int] = mapped_column(
+        ForeignKey("user_group.id"), primary_key=True
+    )
+    cc_pair_id: Mapped[int] = mapped_column(
+        ForeignKey("connector_credential_pair.id"), primary_key=True
+    )
+    # if `True`, then is part of the current state of the UserGroup
+    # if `False`, then is part of the prior state of the UserGroup
+    # rows with `is_current=False` should be deleted when the UserGroup
+    # is updated and should not exist for a given UserGroup if
+    # `UserGroup.is_up_to_date == True`
+    is_current: Mapped[bool] = mapped_column(
+        Boolean,
+        default=True,
+        primary_key=True,
+    )
+
+    cc_pair: Mapped[ConnectorCredentialPair] = relationship(
+        "ConnectorCredentialPair",
+    )
+
+
+class Persona__UserGroup(Base):
+    __tablename__ = "persona__user_group"
+
+    persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"), primary_key=True)
+    user_group_id: Mapped[int] = mapped_column(
+        ForeignKey("user_group.id"), primary_key=True
+    )
+
+
+class DocumentSet__UserGroup(Base):
+    __tablename__ = "document_set__user_group"
+
+    document_set_id: Mapped[int] = mapped_column(
+        ForeignKey("document_set.id"), primary_key=True
+    )
+    user_group_id: Mapped[int] = mapped_column(
+        ForeignKey("user_group.id"), primary_key=True
+    )
+
+
+class UserGroup(Base):
+    __tablename__ = "user_group"
+
+    id: Mapped[int] = mapped_column(primary_key=True)
+    name: Mapped[str] = mapped_column(String, unique=True)
+    # whether or not changes to the UserGroup have been propagated to Vespa
+    is_up_to_date: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
+    # tells the sync job to clean up the group
+    is_up_for_deletion: Mapped[bool] = mapped_column(
+        Boolean, nullable=False, default=False
+    )
+
+    users: Mapped[list[User]] = relationship(
+        "User",
+        secondary=User__UserGroup.__table__,
+    )
+    cc_pairs: Mapped[list[ConnectorCredentialPair]] = relationship(
+        "ConnectorCredentialPair",
+        secondary=UserGroup__ConnectorCredentialPair.__table__,
+        viewonly=True,
+    )
+    cc_pair_relationships: Mapped[
+        list[UserGroup__ConnectorCredentialPair]
+    ] = relationship(
+        "UserGroup__ConnectorCredentialPair",
+        viewonly=True,
+    )
+    personas: Mapped[list[Persona]] = relationship(
+        "Persona",
+        secondary=Persona__UserGroup.__table__,
+        viewonly=True,
+    )
+    document_sets: Mapped[list[DocumentSet]] = relationship(
+        "DocumentSet",
+        secondary=DocumentSet__UserGroup.__table__,
+        viewonly=True,
+    )
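# ---------------------------------------------------------------------------
# Editor's note (not part of the diff): models.py above persists LLMOverride and
# PromptOverride on ChatSession via the PydanticType TypeDecorator (added later
# in this diff as backend/danswer/db/pydantic_type.py). A minimal sketch of the
# round-trip it performs, using a hypothetical Pydantic v1 model:

import json

from pydantic import BaseModel


class ExampleOverride(BaseModel):
    model_version: str | None = None
    temperature: float | None = None


# bind: model -> JSON-safe dict, what process_bind_param stores in the JSONB column
stored = json.loads(ExampleOverride(temperature=0.2).json())
# result: dict -> model, what process_result_value returns when the row is loaded
loaded = ExampleOverride.parse_obj(stored)
assert loaded.temperature == 0.2
# ---------------------------------------------------------------------------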
diff --git a/backend/danswer/db/persona.py b/backend/danswer/db/persona.py
new file mode 100644
index 00000000000..38351b18b02
--- /dev/null
+++ b/backend/danswer/db/persona.py
@@ -0,0 +1,85 @@
+from uuid import UUID
+
+from fastapi import HTTPException
+from sqlalchemy.orm import Session
+
+from danswer.db.chat import get_prompts_by_ids
+from danswer.db.chat import upsert_persona
+from danswer.db.document_set import get_document_sets_by_ids
+from danswer.db.models import User
+from danswer.server.features.persona.models import CreatePersonaRequest
+from danswer.server.features.persona.models import PersonaSnapshot
+from danswer.utils.logger import setup_logger
+from danswer.utils.variable_functionality import fetch_versioned_implementation
+
+logger = setup_logger()
+
+
+def make_persona_private(
+    persona_id: int,
+    user_ids: list[UUID] | None,
+    group_ids: list[int] | None,
+    db_session: Session,
+) -> None:
+    # May cause error if someone switches down to MIT from EE
+    if user_ids or group_ids:
+        raise NotImplementedError("Danswer MIT does not support private Personas")
+
+
+def create_update_persona(
+    persona_id: int | None,
+    create_persona_request: CreatePersonaRequest,
+    user: User | None,
+    db_session: Session,
+) -> PersonaSnapshot:
+    user_id = user.id if user is not None else None
+
+    # Permission to actually use these is checked later
+    document_sets = list(
+        get_document_sets_by_ids(
+            document_set_ids=create_persona_request.document_set_ids,
+            db_session=db_session,
+        )
+    )
+    prompts = list(
+        get_prompts_by_ids(
+            prompt_ids=create_persona_request.prompt_ids,
+            db_session=db_session,
+        )
+    )
+
+    try:
+        persona = upsert_persona(
+            persona_id=persona_id,
+            user_id=user_id,
+            name=create_persona_request.name,
+            description=create_persona_request.description,
+            num_chunks=create_persona_request.num_chunks,
+            llm_relevance_filter=create_persona_request.llm_relevance_filter,
+            llm_filter_extraction=create_persona_request.llm_filter_extraction,
+            recency_bias=create_persona_request.recency_bias,
+            prompts=prompts,
+            document_sets=document_sets,
+            llm_model_version_override=create_persona_request.llm_model_version_override,
+            starter_messages=create_persona_request.starter_messages,
+            shared=create_persona_request.shared,
+            is_public=create_persona_request.is_public,
+            db_session=db_session,
+        )
+
+        versioned_make_persona_private = fetch_versioned_implementation(
+            "danswer.db.persona", "make_persona_private"
+        )
+
+        # Privatize Persona
+        versioned_make_persona_private(
+            persona_id=persona.id,
+            user_ids=create_persona_request.users,
+            group_ids=create_persona_request.groups,
+            db_session=db_session,
+        )
+
+    except ValueError as e:
+        logger.exception("Failed to create persona")
+        raise HTTPException(status_code=400, detail=str(e))
+    return PersonaSnapshot.from_model(persona)
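# ---------------------------------------------------------------------------
# Editor's note (not part of the diff): persona.py above, like document_set.py
# earlier in this diff, defers MIT-vs-EE behavior to a runtime lookup via
# fetch_versioned_implementation. The call shape, exactly as used above:

from danswer.utils.variable_functionality import fetch_versioned_implementation

versioned_fn = fetch_versioned_implementation(
    "danswer.db.persona", "make_persona_private"
)
# In an MIT build this resolves to the guard defined above (which raises for any
# user/group restrictions); an EE build is presumed to swap in an implementation
# that persists the Persona__User / Persona__UserGroup access rows.
# ---------------------------------------------------------------------------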
diff --git a/backend/danswer/db/pg_file_store.py b/backend/danswer/db/pg_file_store.py
new file mode 100644
index 00000000000..91a57adab7f
--- /dev/null
+++ b/backend/danswer/db/pg_file_store.py
@@ -0,0 +1,93 @@
+from io import BytesIO
+from typing import IO
+
+from psycopg2.extensions import connection
+from sqlalchemy.orm import Session
+
+from danswer.db.models import PGFileStore
+from danswer.utils.logger import setup_logger
+
+logger = setup_logger()
+
+
+def get_pg_conn_from_session(db_session: Session) -> connection:
+    return db_session.connection().connection.connection  # type: ignore
+
+
+def create_populate_lobj(
+    content: IO,
+    db_session: Session,
+) -> int:
+    """Note: this does not commit the changes to the DB.
+    This is because the commit should happen with the PGFileStore row creation;
+    that step finalizes both the Large Object and the table tracking it.
+    """
+    pg_conn = get_pg_conn_from_session(db_session)
+    large_object = pg_conn.lobject()
+
+    large_object.write(content.read())
+    large_object.close()
+
+    return large_object.oid
+
+
+def read_lobj(lobj_oid: int, db_session: Session, mode: str | None = None) -> IO:
+    pg_conn = get_pg_conn_from_session(db_session)
+    large_object = (
+        pg_conn.lobject(lobj_oid, mode=mode) if mode else pg_conn.lobject(lobj_oid)
+    )
+    return BytesIO(large_object.read())
+
+
+def delete_lobj_by_id(
+    lobj_oid: int,
+    db_session: Session,
+) -> None:
+    pg_conn = get_pg_conn_from_session(db_session)
+    pg_conn.lobject(lobj_oid).unlink()
+
+
+def upsert_pgfilestore(
+    file_name: str, lobj_oid: int, db_session: Session, commit: bool = False
+) -> PGFileStore:
+    pgfilestore = db_session.query(PGFileStore).filter_by(file_name=file_name).first()
+
+    if pgfilestore:
+        try:
+            # This should not happen in normal execution
+            delete_lobj_by_id(lobj_oid=pgfilestore.lobj_oid, db_session=db_session)
+        except Exception:
+            # If the delete fails as well, the large object doesn't exist anyway and even if it
+            # fails to delete, it's not too terrible as most file sizes are insignificant
+            logger.error(
+                f"Failed to delete large object with oid {pgfilestore.lobj_oid}"
+            )
+
+        pgfilestore.lobj_oid = lobj_oid
+    else:
+        pgfilestore = PGFileStore(file_name=file_name, lobj_oid=lobj_oid)
+        db_session.add(pgfilestore)
+
+    if commit:
+        db_session.commit()
+
+    return pgfilestore
+
+
+def get_pgfilestore_by_file_name(
+    file_name: str,
+    db_session: Session,
+) -> PGFileStore:
+    pgfilestore = db_session.query(PGFileStore).filter_by(file_name=file_name).first()
+
+    if not pgfilestore:
+        raise RuntimeError(f"File by name {file_name} does not exist or was deleted")
+
+    return pgfilestore
+
+
+def delete_pgfilestore_by_file_name(
+    file_name: str,
+    db_session: Session,
+) -> None:
+    db_session.query(PGFileStore).filter_by(file_name=file_name).delete()
diff --git a/backend/danswer/db/pydantic_type.py b/backend/danswer/db/pydantic_type.py
new file mode 100644
index 00000000000..1f37152a851
--- /dev/null
+++ b/backend/danswer/db/pydantic_type.py
@@ -0,0 +1,32 @@
+import json
+from typing import Any
+from typing import Optional
+from typing import Type
+
+from pydantic import BaseModel
+from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.types import TypeDecorator
+
+
+class PydanticType(TypeDecorator):
+    impl = JSONB
+
+    def __init__(
+        self, pydantic_model: Type[BaseModel], *args: Any, **kwargs: Any
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self.pydantic_model = pydantic_model
+
+    def process_bind_param(
+        self, value: Optional[BaseModel], dialect: Any
+    ) -> Optional[dict]:
+        if value is not None:
+            return json.loads(value.json())
+        return None
+
+    def process_result_value(
+        self, value: Optional[dict], dialect: Any
+    ) -> Optional[BaseModel]:
+        if value is not None:
+            return self.pydantic_model.parse_obj(value)
+        return None
diff --git a/backend/danswer/db/slack_bot_config.py b/backend/danswer/db/slack_bot_config.py
index f2aeae7b312..c3b463e35d2 100644
--- a/backend/danswer/db/slack_bot_config.py
+++ b/backend/danswer/db/slack_bot_config.py
@@ -12,7 +12,7 @@
 from danswer.db.models import Persona__DocumentSet
 from danswer.db.models import SlackBotConfig
 from danswer.db.models import 
SlackBotResponseType -from danswer.search.models import RecencyBiasSetting +from danswer.search.enums import RecencyBiasSetting def _build_persona_name(channel_names: list[str]) -> str: @@ -62,6 +62,7 @@ def create_slack_bot_persona( llm_model_version_override=None, starter_messages=None, shared=True, + is_public=True, default_persona=False, db_session=db_session, commit=False, diff --git a/backend/danswer/document_index/vespa/index.py b/backend/danswer/document_index/vespa/index.py index 178aadf3eea..56c36d1e41e 100644 --- a/backend/danswer/document_index/vespa/index.py +++ b/backend/danswer/document_index/vespa/index.py @@ -64,8 +64,8 @@ from danswer.indexing.models import DocMetadataAwareIndexChunk from danswer.indexing.models import InferenceChunk from danswer.search.models import IndexFilters -from danswer.search.search_runner import query_processing -from danswer.search.search_runner import remove_stop_words_and_punctuation +from danswer.search.retrieval.search_runner import query_processing +from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation from danswer.utils.batching import batch_generator from danswer.utils.logger import setup_logger from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel @@ -112,13 +112,13 @@ def _does_document_exist( """Returns whether the document already exists and the users/group whitelists Specifically in this case, document refers to a vespa document which is equivalent to a Danswer chunk. This checks for whether the chunk exists already in the index""" - doc_fetch_response = http_client.get( - f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}" - ) + doc_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}" + doc_fetch_response = http_client.get(doc_url) if doc_fetch_response.status_code == 404: return False if doc_fetch_response.status_code != 200: + logger.debug(f"Failed to check for document with URL {doc_url}") raise RuntimeError( f"Unexpected fetch document by ID value from Vespa " f"with error {doc_fetch_response.status_code}" @@ -157,7 +157,24 @@ def _get_vespa_chunk_ids_by_document_id( "hits": hits_per_page, } while True: - results = requests.post(SEARCH_ENDPOINT, json=params).json() + res = requests.post(SEARCH_ENDPOINT, json=params) + try: + res.raise_for_status() + except requests.HTTPError as e: + request_info = f"Headers: {res.request.headers}\nPayload: {params}" + response_info = ( + f"Status Code: {res.status_code}\nResponse Content: {res.text}" + ) + error_base = f"Error occurred getting chunk by Document ID {document_id}" + logger.error( + f"{error_base}:\n" + f"{request_info}\n" + f"{response_info}\n" + f"Exception: {e}" + ) + raise requests.HTTPError(error_base) from e + + results = res.json() hits = results["root"].get("children", []) doc_chunk_ids.extend( @@ -179,10 +196,14 @@ def _delete_vespa_doc_chunks( ) for chunk_id in doc_chunk_ids: - res = http_client.delete( - f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{chunk_id}" - ) - res.raise_for_status() + try: + res = http_client.delete( + f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{chunk_id}" + ) + res.raise_for_status() + except httpx.HTTPStatusError as e: + logger.error(f"Failed to delete chunk, details: {e.response.text}") + raise def _delete_vespa_docs( @@ -559,18 +580,35 @@ def _query_vespa(query_params: Mapping[str, str | int | float]) -> list[Inferenc if "query" in query_params and not cast(str, query_params["query"]).strip(): raise ValueError("No/empty 
query received") + params = dict( + **query_params, + **{ + "presentation.timing": True, + } + if LOG_VESPA_TIMING_INFORMATION + else {}, + ) + response = requests.post( SEARCH_ENDPOINT, - json=dict( - **query_params, - **{ - "presentation.timing": True, - } - if LOG_VESPA_TIMING_INFORMATION - else {}, - ), + json=params, ) - response.raise_for_status() + try: + response.raise_for_status() + except requests.HTTPError as e: + request_info = f"Headers: {response.request.headers}\nPayload: {params}" + response_info = ( + f"Status Code: {response.status_code}\n" + f"Response Content: {response.text}" + ) + error_base = "Failed to query Vespa" + logger.error( + f"{error_base}:\n" + f"{request_info}\n" + f"{response_info}\n" + f"Exception: {e}" + ) + raise requests.HTTPError(error_base) from e response_json: dict[str, Any] = response.json() if LOG_VESPA_TIMING_INFORMATION: diff --git a/backend/danswer/dynamic_configs/__init__.py b/backend/danswer/dynamic_configs/__init__.py index 0fc2233fa9b..e69de29bb2d 100644 --- a/backend/danswer/dynamic_configs/__init__.py +++ b/backend/danswer/dynamic_configs/__init__.py @@ -1,13 +0,0 @@ -from danswer.configs.app_configs import DYNAMIC_CONFIG_DIR_PATH -from danswer.configs.app_configs import DYNAMIC_CONFIG_STORE -from danswer.dynamic_configs.file_system.store import FileSystemBackedDynamicConfigStore -from danswer.dynamic_configs.interface import DynamicConfigStore - - -def get_dynamic_config_store() -> DynamicConfigStore: - dynamic_config_store_type = DYNAMIC_CONFIG_STORE - if dynamic_config_store_type == FileSystemBackedDynamicConfigStore.__name__: - return FileSystemBackedDynamicConfigStore(DYNAMIC_CONFIG_DIR_PATH) - - # TODO: change exception type - raise Exception("Unknown dynamic config store type") diff --git a/backend/danswer/dynamic_configs/factory.py b/backend/danswer/dynamic_configs/factory.py new file mode 100644 index 00000000000..a82bc315c8b --- /dev/null +++ b/backend/danswer/dynamic_configs/factory.py @@ -0,0 +1,16 @@ +from danswer.configs.app_configs import DYNAMIC_CONFIG_DIR_PATH +from danswer.configs.app_configs import DYNAMIC_CONFIG_STORE +from danswer.dynamic_configs.interface import DynamicConfigStore +from danswer.dynamic_configs.store import FileSystemBackedDynamicConfigStore +from danswer.dynamic_configs.store import PostgresBackedDynamicConfigStore + + +def get_dynamic_config_store() -> DynamicConfigStore: + dynamic_config_store_type = DYNAMIC_CONFIG_STORE + if dynamic_config_store_type == FileSystemBackedDynamicConfigStore.__name__: + return FileSystemBackedDynamicConfigStore(DYNAMIC_CONFIG_DIR_PATH) + if dynamic_config_store_type == PostgresBackedDynamicConfigStore.__name__: + return PostgresBackedDynamicConfigStore() + + # TODO: change exception type + raise Exception("Unknown dynamic config store type") diff --git a/backend/danswer/dynamic_configs/port_configs.py b/backend/danswer/dynamic_configs/port_configs.py new file mode 100644 index 00000000000..34abcff7412 --- /dev/null +++ b/backend/danswer/dynamic_configs/port_configs.py @@ -0,0 +1,40 @@ +import json +from pathlib import Path + +from danswer.configs.app_configs import DYNAMIC_CONFIG_DIR_PATH +from danswer.dynamic_configs.factory import PostgresBackedDynamicConfigStore +from danswer.dynamic_configs.interface import ConfigNotFoundError + + +def read_file_system_store(directory_path: str) -> dict: + store = {} + base_path = Path(directory_path) + for file_path in base_path.iterdir(): + if file_path.is_file() and "." 
not in file_path.name: + with open(file_path, "r") as file: + key = file_path.stem + value = json.load(file) + + if value: + store[key] = value + return store + + +def insert_into_postgres(store_data: dict) -> None: + port_once_key = "file_store_ported" + config_store = PostgresBackedDynamicConfigStore() + try: + config_store.load(port_once_key) + return + except ConfigNotFoundError: + pass + + for key, value in store_data.items(): + config_store.store(key, value) + + config_store.store(port_once_key, True) + + +def port_filesystem_to_postgres(directory_path: str = DYNAMIC_CONFIG_DIR_PATH) -> None: + store_data = read_file_system_store(directory_path) + insert_into_postgres(store_data) diff --git a/backend/danswer/dynamic_configs/file_system/store.py b/backend/danswer/dynamic_configs/store.py similarity index 52% rename from backend/danswer/dynamic_configs/file_system/store.py rename to backend/danswer/dynamic_configs/store.py index 75cc0d7407e..043d762d479 100644 --- a/backend/danswer/dynamic_configs/file_system/store.py +++ b/backend/danswer/dynamic_configs/store.py @@ -1,10 +1,15 @@ import json import os +from collections.abc import Iterator +from contextlib import contextmanager from pathlib import Path from typing import cast from filelock import FileLock +from sqlalchemy.orm import Session +from danswer.db.engine import SessionFactory +from danswer.db.models import KVStore from danswer.dynamic_configs.interface import ConfigNotFoundError from danswer.dynamic_configs.interface import DynamicConfigStore from danswer.dynamic_configs.interface import JSON_ro @@ -46,3 +51,38 @@ def delete(self, key: str) -> None: lock = _get_file_lock(file_path) with lock.acquire(timeout=FILE_LOCK_TIMEOUT): os.remove(file_path) + + +class PostgresBackedDynamicConfigStore(DynamicConfigStore): + @contextmanager + def get_session(self) -> Iterator[Session]: + session: Session = SessionFactory() + try: + yield session + finally: + session.close() + + def store(self, key: str, val: JSON_ro) -> None: + with self.get_session() as session: + obj = session.query(KVStore).filter_by(key=key).first() + if obj: + obj.value = val + else: + obj = KVStore(key=key, value=val) # type: ignore + session.query(KVStore).filter_by(key=key).delete() + session.add(obj) + session.commit() + + def load(self, key: str) -> JSON_ro: + with self.get_session() as session: + obj = session.query(KVStore).filter_by(key=key).first() + if not obj: + raise ConfigNotFoundError + return cast(JSON_ro, obj.value) + + def delete(self, key: str) -> None: + with self.get_session() as session: + result = session.query(KVStore).filter_by(key=key).delete() # type: ignore + if result == 0: + raise ConfigNotFoundError + session.commit() diff --git a/backend/danswer/indexing/chunker.py b/backend/danswer/indexing/chunker.py index 9be9348b9f9..b6f59d18901 100644 --- a/backend/danswer/indexing/chunker.py +++ b/backend/danswer/indexing/chunker.py @@ -5,18 +5,22 @@ from danswer.configs.app_configs import BLURB_SIZE from danswer.configs.app_configs import CHUNK_OVERLAP from danswer.configs.app_configs import MINI_CHUNK_SIZE +from danswer.configs.constants import DocumentSource from danswer.configs.constants import SECTION_SEPARATOR from danswer.configs.constants import TITLE_SEPARATOR from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE from danswer.connectors.models import Document from danswer.indexing.models import DocAwareChunk from danswer.search.search_nlp_models import get_default_tokenizer +from danswer.utils.logger import setup_logger 
from danswer.utils.text_processing import shared_precompare_cleanup

-
 if TYPE_CHECKING:
     from transformers import AutoTokenizer  # type:ignore

+
+logger = setup_logger()
+
 ChunkFunc = Callable[[Document], list[DocAwareChunk]]
@@ -178,4 +182,7 @@ def chunk(self, document: Document) -> list[DocAwareChunk]:

 class DefaultChunker(Chunker):
     def chunk(self, document: Document) -> list[DocAwareChunk]:
+        # Specifically for reproducing an issue with Gmail
+        if document.source == DocumentSource.GMAIL:
+            logger.debug(f"Chunking {document.semantic_identifier}")
         return chunk_document(document)
diff --git a/backend/danswer/indexing/embedder.py b/backend/danswer/indexing/embedder.py
index 3be10f5b41c..20a8690e366 100644
--- a/backend/danswer/indexing/embedder.py
+++ b/backend/danswer/indexing/embedder.py
@@ -4,8 +4,6 @@
 from sqlalchemy.orm import Session

 from danswer.configs.app_configs import ENABLE_MINI_CHUNK
-from danswer.configs.app_configs import INDEXING_MODEL_SERVER_HOST
-from danswer.configs.app_configs import MODEL_SERVER_PORT
 from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.db.embedding_model import get_current_db_embedding_model
@@ -16,9 +14,12 @@
 from danswer.indexing.models import ChunkEmbedding
 from danswer.indexing.models import DocAwareChunk
 from danswer.indexing.models import IndexChunk
+from danswer.search.enums import EmbedTextType
 from danswer.search.search_nlp_models import EmbeddingModel
-from danswer.search.search_nlp_models import EmbedTextType
+from danswer.utils.batching import batch_list
 from danswer.utils.logger import setup_logger
+from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
+from shared_configs.configs import MODEL_SERVER_PORT

 logger = setup_logger()

@@ -73,6 +74,8 @@ def embed_chunks(
         title_embed_dict: dict[str, list[float]] = {}
         embedded_chunks: list[IndexChunk] = []

+        # Create Mini Chunks for more precise matching of details
+        # Off by default with unedited settings
         chunk_texts = []
         chunk_mini_chunks_count = {}
         for chunk_ind, chunk in enumerate(chunks):
@@ -85,23 +88,43 @@
             chunk_texts.extend(mini_chunk_texts)
             chunk_mini_chunks_count[chunk_ind] = 1 + len(mini_chunk_texts)

-        text_batches = [
-            chunk_texts[i : i + batch_size]
-            for i in range(0, len(chunk_texts), batch_size)
-        ]
+        # Batching for embedding
+        text_batches = batch_list(chunk_texts, batch_size)

         embeddings: list[list[float]] = []
         len_text_batches = len(text_batches)
         for idx, text_batch in enumerate(text_batches, start=1):
-            logger.debug(f"Embedding text batch {idx} of {len_text_batches}")
-            # Normalize embeddings is only configured via model_configs.py, be sure to use right value for the set loss
+            logger.debug(f"Embedding Content Texts batch {idx} of {len_text_batches}")
+            # Whether embeddings are normalized is only configured via model_configs.py;
+            # be sure to use the right value for the set loss
             embeddings.extend(
                 self.embedding_model.encode(text_batch, text_type=EmbedTextType.PASSAGE)
             )
-            # Replace line above with the line below for easy debugging of indexing flow, skipping the actual model
+            # Replace the line above with the line below for easy debugging of the
+            # indexing flow, skipping the actual model
             # embeddings.extend([[0.0] * 384 for _ in range(len(text_batch))])

+        chunk_titles = {
+            chunk.source_document.get_title_for_document_index() for chunk in chunks
+        }
+
+        # Drop any None or empty strings
+        chunk_titles_list = [title for title in chunk_titles if title]
+
+        # Embed Titles in batches
+        
title_batches = batch_list(chunk_titles_list, batch_size) + len_title_batches = len(title_batches) + for ind_batch, title_batch in enumerate(title_batches, start=1): + logger.debug(f"Embedding Titles batch {ind_batch} of {len_title_batches}") + title_embeddings = self.embedding_model.encode( + title_batch, text_type=EmbedTextType.PASSAGE + ) + title_embed_dict.update( + {title: vector for title, vector in zip(title_batch, title_embeddings)} + ) + + # Mapping embeddings to chunks embedding_ind_start = 0 for chunk_ind, chunk in enumerate(chunks): num_embeddings = chunk_mini_chunks_count[chunk_ind] @@ -114,9 +137,12 @@ def embed_chunks( title_embedding = None if title: if title in title_embed_dict: - # Using cached value for speedup + # Using cached value to avoid recalculating for every chunk title_embedding = title_embed_dict[title] else: + logger.error( + "Title had to be embedded separately, this should not happen!" + ) title_embedding = self.embedding_model.encode( [title], text_type=EmbedTextType.PASSAGE )[0] diff --git a/backend/danswer/indexing/models.py b/backend/danswer/indexing/models.py index c875c88bdd2..68f9e3886ae 100644 --- a/backend/danswer/indexing/models.py +++ b/backend/danswer/indexing/models.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from dataclasses import fields from datetime import datetime +from typing import TYPE_CHECKING from pydantic import BaseModel @@ -9,6 +10,9 @@ from danswer.connectors.models import Document from danswer.utils.logger import setup_logger +if TYPE_CHECKING: + from danswer.db.models import EmbeddingModel + logger = setup_logger() @@ -130,3 +134,13 @@ class EmbeddingModelDetail(BaseModel): normalize: bool query_prefix: str | None passage_prefix: str | None + + @classmethod + def from_model(cls, embedding_model: "EmbeddingModel") -> "EmbeddingModelDetail": + return cls( + model_name=embedding_model.model_name, + model_dim=embedding_model.model_dim, + normalize=embedding_model.normalize, + query_prefix=embedding_model.query_prefix, + passage_prefix=embedding_model.passage_prefix, + ) diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py new file mode 100644 index 00000000000..44eae76df23 --- /dev/null +++ b/backend/danswer/llm/answering/answer.py @@ -0,0 +1,178 @@ +from collections.abc import Iterator +from typing import cast + +from langchain.schema.messages import BaseMessage + +from danswer.chat.models import AnswerQuestionPossibleReturn +from danswer.chat.models import AnswerQuestionStreamReturn +from danswer.chat.models import CitationInfo +from danswer.chat.models import DanswerAnswerPiece +from danswer.chat.models import LlmDoc +from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE +from danswer.configs.chat_configs import QA_TIMEOUT +from danswer.llm.answering.doc_pruning import prune_documents +from danswer.llm.answering.models import AnswerStyleConfig +from danswer.llm.answering.models import LLMConfig +from danswer.llm.answering.models import PreviousMessage +from danswer.llm.answering.models import PromptConfig +from danswer.llm.answering.models import StreamProcessor +from danswer.llm.answering.prompts.citations_prompt import build_citations_prompt +from danswer.llm.answering.prompts.quotes_prompt import ( + build_quotes_prompt, +) +from danswer.llm.answering.stream_processing.citation_processing import ( + build_citation_processor, +) +from danswer.llm.answering.stream_processing.quotes_processing import ( + build_quotes_processor, +) +from danswer.llm.factory import 
get_default_llm +from danswer.llm.utils import get_default_llm_tokenizer + + +def _get_stream_processor( + docs: list[LlmDoc], answer_style_configs: AnswerStyleConfig +) -> StreamProcessor: + if answer_style_configs.citation_config: + return build_citation_processor( + context_docs=docs, + ) + if answer_style_configs.quotes_config: + return build_quotes_processor( + context_docs=docs, is_json_prompt=not (QA_PROMPT_OVERRIDE == "weak") + ) + + raise RuntimeError("Not implemented yet") + + +class Answer: + def __init__( + self, + question: str, + docs: list[LlmDoc], + answer_style_config: AnswerStyleConfig, + llm_config: LLMConfig, + prompt_config: PromptConfig, + # must be the same length as `docs`. If None, all docs are considered "relevant" + doc_relevance_list: list[bool] | None = None, + message_history: list[PreviousMessage] | None = None, + single_message_history: str | None = None, + timeout: int = QA_TIMEOUT, + ) -> None: + if single_message_history and message_history: + raise ValueError( + "Cannot provide both `message_history` and `single_message_history`" + ) + + self.question = question + self.docs = docs + self.doc_relevance_list = doc_relevance_list + self.message_history = message_history or [] + # used for QA flow where we only want to send a single message + self.single_message_history = single_message_history + + self.answer_style_config = answer_style_config + self.llm_config = llm_config + self.prompt_config = prompt_config + + self.llm = get_default_llm( + gen_ai_model_provider=self.llm_config.model_provider, + gen_ai_model_version_override=self.llm_config.model_version, + timeout=timeout, + temperature=self.llm_config.temperature, + ) + self.llm_tokenizer = get_default_llm_tokenizer() + + self.process_stream_fn = _get_stream_processor(docs, answer_style_config) + + self._final_prompt: list[BaseMessage] | None = None + + self._pruned_docs: list[LlmDoc] | None = None + + self._streamed_output: list[str] | None = None + self._processed_stream: list[AnswerQuestionPossibleReturn] | None = None + + @property + def pruned_docs(self) -> list[LlmDoc]: + if self._pruned_docs is not None: + return self._pruned_docs + + self._pruned_docs = prune_documents( + docs=self.docs, + doc_relevance_list=self.doc_relevance_list, + prompt_config=self.prompt_config, + llm_config=self.llm_config, + question=self.question, + document_pruning_config=self.answer_style_config.document_pruning_config, + ) + return self._pruned_docs + + @property + def final_prompt(self) -> list[BaseMessage]: + if self._final_prompt is not None: + return self._final_prompt + + if self.answer_style_config.citation_config: + self._final_prompt = build_citations_prompt( + question=self.question, + message_history=self.message_history, + llm_config=self.llm_config, + prompt_config=self.prompt_config, + context_docs=self.pruned_docs, + all_doc_useful=self.answer_style_config.citation_config.all_docs_useful, + llm_tokenizer_encode_func=self.llm_tokenizer.encode, + history_message=self.single_message_history or "", + ) + elif self.answer_style_config.quotes_config: + self._final_prompt = build_quotes_prompt( + question=self.question, + context_docs=self.pruned_docs, + history_str=self.single_message_history or "", + prompt=self.prompt_config, + ) + + return cast(list[BaseMessage], self._final_prompt) + + @property + def raw_streamed_output(self) -> Iterator[str]: + if self._streamed_output is not None: + yield from self._streamed_output + return + + streamed_output = [] + for message in self.llm.stream(self.final_prompt): 
+            streamed_output.append(message)
+            yield message
+
+        self._streamed_output = streamed_output
+
+    @property
+    def processed_streamed_output(self) -> AnswerQuestionStreamReturn:
+        if self._processed_stream is not None:
+            yield from self._processed_stream
+            return
+
+        processed_stream = []
+        for processed_packet in self.process_stream_fn(self.raw_streamed_output):
+            processed_stream.append(processed_packet)
+            yield processed_packet
+
+        self._processed_stream = processed_stream
+
+    @property
+    def llm_answer(self) -> str:
+        answer = ""
+        for packet in self.processed_streamed_output:
+            if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
+                answer += packet.answer_piece
+
+        return answer
+
+    @property
+    def citations(self) -> list[CitationInfo]:
+        citations: list[CitationInfo] = []
+        for packet in self.processed_streamed_output:
+            if isinstance(packet, CitationInfo):
+                citations.append(packet)
+
+        return citations
diff --git a/backend/danswer/llm/answering/doc_pruning.py b/backend/danswer/llm/answering/doc_pruning.py
new file mode 100644
index 00000000000..f1007d19e59
--- /dev/null
+++ b/backend/danswer/llm/answering/doc_pruning.py
@@ -0,0 +1,209 @@
+from copy import deepcopy
+from typing import TypeVar
+
+from danswer.chat.models import (
+    LlmDoc,
+)
+from danswer.configs.constants import IGNORE_FOR_QA
+from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
+from danswer.indexing.models import InferenceChunk
+from danswer.llm.answering.models import DocumentPruningConfig
+from danswer.llm.answering.models import LLMConfig
+from danswer.llm.answering.models import PromptConfig
+from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
+from danswer.llm.utils import get_default_llm_tokenizer
+from danswer.llm.utils import tokenizer_trim_content
+from danswer.prompts.prompt_utils import build_doc_context_str
+from danswer.utils.logger import setup_logger
+
+
+logger = setup_logger()
+
+T = TypeVar("T", bound=LlmDoc | InferenceChunk)
+
+_METADATA_TOKEN_ESTIMATE = 75
+
+
+class PruningError(Exception):
+    pass
+
+
+def _compute_limit(
+    prompt_config: PromptConfig,
+    llm_config: LLMConfig,
+    question: str,
+    max_chunks: int | None,
+    max_window_percentage: float | None,
+    max_tokens: int | None,
+) -> int:
+    llm_max_document_tokens = compute_max_document_tokens(
+        prompt_config=prompt_config, llm_config=llm_config, actual_user_input=question
+    )
+
+    window_percentage_based_limit = (
+        max_window_percentage * llm_max_document_tokens
+        if max_window_percentage
+        else None
+    )
+    chunk_count_based_limit = (
+        max_chunks * DOC_EMBEDDING_CONTEXT_SIZE if max_chunks else None
+    )
+
+    limit_options = [
+        lim
+        for lim in [
+            window_percentage_based_limit,
+            chunk_count_based_limit,
+            max_tokens,
+            llm_max_document_tokens,
+        ]
+        if lim
+    ]
+    return int(min(limit_options))
+
+
+def reorder_docs(
+    docs: list[T],
+    doc_relevance_list: list[bool] | None,
+) -> list[T]:
+    if doc_relevance_list is None:
+        return docs
+
+    # put all the "relevant" docs first, then the rest, preserving order within each group
+    reordered_docs: list[T] = []
+    for selection_target in [True, False]:
+        for doc, is_relevant in zip(docs, doc_relevance_list):
+            if is_relevant == selection_target:
+                reordered_docs.append(doc)
+    return reordered_docs
+
+
+def _remove_docs_to_ignore(docs: list[LlmDoc]) -> list[LlmDoc]:
+    return [doc for doc in docs if not doc.metadata.get(IGNORE_FOR_QA)]
+
+
+def _apply_pruning(
+    docs: list[LlmDoc],
+    doc_relevance_list: list[bool] | None,
+    token_limit: int,
+    is_manually_selected_docs: bool,
+) -> list[LlmDoc]:
+    llm_tokenizer = get_default_llm_tokenizer()
+    docs = deepcopy(docs)  # don't modify in place
+
+    # re-order docs with all the "relevant" docs at the front
+    docs = reorder_docs(docs=docs, doc_relevance_list=doc_relevance_list)
+    # remove docs that are explicitly marked as not for QA
+    docs = _remove_docs_to_ignore(docs=docs)
+
+    tokens_per_doc: list[int] = []
+    final_doc_ind = None
+    total_tokens = 0
+    for ind, llm_doc in enumerate(docs):
+        doc_tokens = len(
+            llm_tokenizer.encode(
+                build_doc_context_str(
+                    semantic_identifier=llm_doc.semantic_identifier,
+                    source_type=llm_doc.source_type,
+                    content=llm_doc.content,
+                    metadata_dict=llm_doc.metadata,
+                    updated_at=llm_doc.updated_at,
+                    ind=ind,
+                )
+            )
+        )
+        # if these docs are chunks, truncate any that are way too long; this can
+        # happen if the embedding model tokenizer is different from the LLM tokenizer
+        if (
+            not is_manually_selected_docs
+            and doc_tokens > DOC_EMBEDDING_CONTEXT_SIZE + _METADATA_TOKEN_ESTIMATE
+        ):
+            logger.warning(
+                "Found more tokens in chunk than expected, "
+                "likely mismatch between embedding and LLM tokenizers. Trimming content..."
+            )
+            llm_doc.content = tokenizer_trim_content(
+                content=llm_doc.content,
+                desired_length=DOC_EMBEDDING_CONTEXT_SIZE,
+                tokenizer=llm_tokenizer,
+            )
+            doc_tokens = DOC_EMBEDDING_CONTEXT_SIZE
+        tokens_per_doc.append(doc_tokens)
+        total_tokens += doc_tokens
+        if total_tokens > token_limit:
+            final_doc_ind = ind
+            break
+
+    if final_doc_ind is not None:
+        if is_manually_selected_docs:
+            # for document selection, only allow the final document to get truncated;
+            # if more than that would be cut, then the user message is too long
+            if final_doc_ind != len(docs) - 1:
+                raise PruningError(
+                    "LLM context window exceeded. Please de-select some documents or shorten your query."
+                )
+
+            final_doc_desired_length = tokens_per_doc[final_doc_ind] - (
+                total_tokens - token_limit
+            )
+            final_doc_content_length = (
+                final_doc_desired_length - _METADATA_TOKEN_ESTIMATE
+            )
+            # this could occur if we only have space for the title / metadata;
+            # not ideal, but it's the most reasonable thing to do
+            # NOTE: the frontend prevents documents from being selected if
+            # fewer than 75 tokens are available, to try to avoid this situation
+            # in the first place
+            if final_doc_content_length <= 0:
+                logger.error(
+                    f"Final doc ({docs[final_doc_ind].semantic_identifier}) content "
+                    "length is less than or equal to 0. Removing this doc from the final prompt."
+                )
+                docs.pop()
+            else:
+                docs[final_doc_ind].content = tokenizer_trim_content(
+                    content=docs[final_doc_ind].content,
+                    desired_length=final_doc_content_length,
+                    tokenizer=llm_tokenizer,
+                )
+        else:
+            # for regular search, don't truncate the final document unless it's the only one
+            if final_doc_ind != 0:
+                docs = docs[:final_doc_ind]
+            else:
+                docs[0].content = tokenizer_trim_content(
+                    content=docs[0].content,
+                    desired_length=token_limit - _METADATA_TOKEN_ESTIMATE,
+                    tokenizer=llm_tokenizer,
+                )
+                docs = [docs[0]]
+
+    return docs
+
+
+def prune_documents(
+    docs: list[LlmDoc],
+    doc_relevance_list: list[bool] | None,
+    prompt_config: PromptConfig,
+    llm_config: LLMConfig,
+    question: str,
+    document_pruning_config: DocumentPruningConfig,
+) -> list[LlmDoc]:
+    if doc_relevance_list is not None:
+        assert len(docs) == len(doc_relevance_list)
+
+    doc_token_limit = _compute_limit(
+        prompt_config=prompt_config,
+        llm_config=llm_config,
+        question=question,
+        max_chunks=document_pruning_config.max_chunks,
+        max_window_percentage=document_pruning_config.max_window_percentage,
+        max_tokens=document_pruning_config.max_tokens,
+    )
+    return _apply_pruning(
+        docs=docs,
+        doc_relevance_list=doc_relevance_list,
+        token_limit=doc_token_limit,
+        is_manually_selected_docs=document_pruning_config.is_manually_selected_docs,
+    )
diff --git a/backend/danswer/llm/answering/models.py b/backend/danswer/llm/answering/models.py
new file mode 100644
index 00000000000..71ea66661a4
--- /dev/null
+++ b/backend/danswer/llm/answering/models.py
@@ -0,0 +1,143 @@
+from collections.abc import Callable
+from collections.abc import Iterator
+from typing import Any
+from typing import TYPE_CHECKING
+
+from pydantic import BaseModel
+from pydantic import Field
+from pydantic import root_validator
+
+from danswer.chat.models import AnswerQuestionStreamReturn
+from danswer.configs.constants import MessageType
+from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
+from danswer.llm.override_models import LLMOverride
+from danswer.llm.override_models import PromptOverride
+from danswer.llm.utils import get_default_llm_version
+
+if TYPE_CHECKING:
+    from danswer.db.models import ChatMessage
+    from danswer.db.models import Prompt
+    from danswer.db.models import Persona
+
+
+StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]
+
+
+class PreviousMessage(BaseModel):
+    """Simplified version of `ChatMessage`"""
+
+    message: str
+    token_count: int
+    message_type: MessageType
+
+    @classmethod
+    def from_chat_message(cls, chat_message: "ChatMessage") -> "PreviousMessage":
+        return cls(
+            message=chat_message.message,
+            token_count=chat_message.token_count,
+            message_type=chat_message.message_type,
+        )
+
+
+class DocumentPruningConfig(BaseModel):
+    max_chunks: int | None = None
+    max_window_percentage: float | None = None
+    max_tokens: int | None = None
+    # different pruning behavior is expected when the
+    # user manually selects documents they want to chat with,
+    # e.g. we don't want to truncate each document to be no more
+    # than one chunk long
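+    # (concretely, in `doc_pruning._apply_pruning`: manually selected docs are
+    # never silently dropped and only the final one may be trimmed, while for
+    # retrieved chunks, overflow docs are dropped and oversized chunks trimmed)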
+    is_manually_selected_docs: bool = False
+
+
+class CitationConfig(BaseModel):
+    all_docs_useful: bool = False
+
+
+class QuotesConfig(BaseModel):
+    pass
+
+
+class AnswerStyleConfig(BaseModel):
+    citation_config: CitationConfig | None = None
+    quotes_config: QuotesConfig | None = None
+    document_pruning_config: DocumentPruningConfig = Field(
+        default_factory=DocumentPruningConfig
+    )
+
+    @root_validator
+    def check_quotes_and_citation(cls, values: dict[str, Any]) -> dict[str, Any]:
+        citation_config = values.get("citation_config")
+        quotes_config = values.get("quotes_config")
+
+        if citation_config is None and quotes_config is None:
+            raise ValueError(
+                "One of `citation_config` or `quotes_config` must be provided"
+            )
+
+        if citation_config is not None and quotes_config is not None:
+            raise ValueError(
+                "Only one of `citation_config` or `quotes_config` may be provided"
+            )
+
+        return values
+
+
+class LLMConfig(BaseModel):
+    """Final representation of the LLM configuration passed into
+    the `Answer` object."""
+
+    model_provider: str
+    model_version: str
+    temperature: float
+
+    @classmethod
+    def from_persona(
+        cls, persona: "Persona", llm_override: LLMOverride | None = None
+    ) -> "LLMConfig":
+        model_provider_override = llm_override.model_provider if llm_override else None
+        model_version_override = llm_override.model_version if llm_override else None
+        temperature_override = llm_override.temperature if llm_override else None
+
+        return cls(
+            model_provider=model_provider_override or GEN_AI_MODEL_PROVIDER,
+            model_version=(
+                model_version_override
+                or persona.llm_model_version_override
+                or get_default_llm_version()[0]
+            ),
+            temperature=temperature_override or 0.0,
+        )
+
+    class Config:
+        frozen = True
+
+
+class PromptConfig(BaseModel):
+    """Final representation of the Prompt configuration passed
+    into the `Answer` object."""
+
+    system_prompt: str
+    task_prompt: str
+    datetime_aware: bool
+    include_citations: bool
+
+    @classmethod
+    def from_model(
+        cls, model: "Prompt", prompt_override: PromptOverride | None = None
+    ) -> "PromptConfig":
+        override_system_prompt = (
+            prompt_override.system_prompt if prompt_override else None
+        )
+        override_task_prompt = prompt_override.task_prompt if prompt_override else None
+
+        return cls(
+            system_prompt=override_system_prompt or model.system_prompt,
+            task_prompt=override_task_prompt or model.task_prompt,
+            datetime_aware=model.datetime_aware,
+            include_citations=model.include_citations,
+        )
+
+    # needed so that this can be passed into lru_cache funcs
+    class Config:
+        frozen = True
diff --git a/backend/danswer/llm/answering/prompts/citations_prompt.py b/backend/danswer/llm/answering/prompts/citations_prompt.py
new file mode 100644
index 00000000000..60f1e1098fa
--- /dev/null
+++ b/backend/danswer/llm/answering/prompts/citations_prompt.py
@@ -0,0 +1,287 @@
+from collections.abc import Callable
+from functools import lru_cache
+from typing import cast
+
+from langchain.schema.messages import BaseMessage
+from langchain.schema.messages import HumanMessage
+from langchain.schema.messages import SystemMessage
+
+from danswer.chat.models import LlmDoc
+from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
+from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
+from danswer.db.chat import get_default_prompt
+from danswer.db.models import Persona
+from danswer.indexing.models import InferenceChunk
+from danswer.llm.answering.models import LLMConfig
+from danswer.llm.answering.models import PreviousMessage
+from danswer.llm.answering.models import PromptConfig
+from danswer.llm.utils import check_number_of_tokens
+from danswer.llm.utils import get_default_llm_tokenizer
+from danswer.llm.utils import get_max_input_tokens
+from danswer.llm.utils import translate_history_to_basemessages
+from danswer.prompts.chat_prompts import ADDITIONAL_INFO
+from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
+from danswer.prompts.chat_prompts import NO_CITATION_STATEMENT
+from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
+from danswer.prompts.constants import DEFAULT_IGNORE_STATEMENT
+from danswer.prompts.direct_qa_prompts import (
+    CITATIONS_PROMPT,
+)
+from danswer.prompts.prompt_utils import build_complete_context_str
+from danswer.prompts.prompt_utils import build_task_prompt_reminders
+from danswer.prompts.prompt_utils import get_current_llm_day_time
+from danswer.prompts.token_counts import ADDITIONAL_INFO_TOKEN_CNT
+from danswer.prompts.token_counts import (
+    CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT,
+)
+from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT
+from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT
+from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT
+
+
+_PER_MESSAGE_TOKEN_BUFFER = 7
+
+
+def find_last_index(lst: list[int], max_prompt_tokens: int) -> int:
+    """From the back, find the index of the last element to include
+    before the list exceeds the maximum"""
+    running_sum = 0
+
+    last_ind = 0
+    for i in range(len(lst) - 1, -1, -1):
+        running_sum += lst[i] + _PER_MESSAGE_TOKEN_BUFFER
+        if running_sum > max_prompt_tokens:
+            last_ind = i + 1
+            break
+    if last_ind >= len(lst):
+        raise ValueError("Last message alone is too large!")
+    return last_ind
+
+
+def drop_messages_history_overflow(
+    system_msg: BaseMessage | None,
+    system_token_count: int,
+    history_msgs: list[BaseMessage],
+    history_token_counts: list[int],
+    final_msg: BaseMessage,
+    final_msg_token_count: int,
+    max_allowed_tokens: int,
+) -> list[BaseMessage]:
+    """As message history grows, messages need to be dropped starting from the furthest in the past.
+    The system message should be kept if at all possible, and the latest user input,
+    which is inserted into the prompt template, must always be included."""
+    if len(history_msgs) != len(history_token_counts):
+        # This should never happen
+        raise ValueError("Need exactly 1 token count per message for tracking overflow")
+
+    prompt: list[BaseMessage] = []
+
+    # Start dropping from the history if necessary
+    all_tokens = history_token_counts + [system_token_count, final_msg_token_count]
+    ind_prev_msg_start = find_last_index(
+        all_tokens, max_prompt_tokens=max_allowed_tokens
+    )
+
+    if system_msg and ind_prev_msg_start <= len(history_msgs):
+        prompt.append(system_msg)
+
+    prompt.extend(history_msgs[ind_prev_msg_start:])
+
+    prompt.append(final_msg)
+
+    return prompt
+
+
+def get_prompt_tokens(prompt_config: PromptConfig) -> int:
+    # Note: currently custom prompts do not allow datetime awareness; only default prompts do
+    return (
+        check_number_of_tokens(prompt_config.system_prompt)
+        + check_number_of_tokens(prompt_config.task_prompt)
+        + CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT
+        + CITATION_STATEMENT_TOKEN_CNT
+        + CITATION_REMINDER_TOKEN_CNT
+        + (LANGUAGE_HINT_TOKEN_CNT if bool(MULTILINGUAL_QUERY_EXPANSION) else 0)
+        + (ADDITIONAL_INFO_TOKEN_CNT if prompt_config.datetime_aware else 0)
+    )
+
+
+# buffer just to be safe so that we don't overflow the token limit due to
+# a small miscalculation
+_MISC_BUFFER = 40
+
+
+def compute_max_document_tokens(
+    prompt_config: PromptConfig,
+    llm_config: LLMConfig,
+    actual_user_input: str | None = None,
+    max_llm_token_override: int | None = None,
+) -> int:
+    """Estimates the number of tokens available for context documents. Formula is roughly:
+
+    (
+        model_context_window - reserved_output_tokens - prompt_tokens
+        - (actual_user_input OR reserved_user_message_tokens) - buffer (just to be safe)
+    )
+
+    The actual_user_input is used at query time. If we are calculating this before knowing the exact input (e.g.
+    if we're trying to determine if the user should be able to select another document) then we just set an
+    arbitrary "upper bound".
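+
+    Worked example with illustrative numbers: given 4096 usable input tokens
+    (context window minus reserved output tokens), a 300-token prompt, a
+    100-token user input, and the 40-token safety buffer, roughly
+    4096 - 300 - 100 - 40 = 3656 tokens remain for documents.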
+    """
+    # if we can't find a number of tokens, just assume some common default
+    max_input_tokens = (
+        max_llm_token_override
+        if max_llm_token_override
+        else get_max_input_tokens(model_name=llm_config.model_version)
+    )
+    prompt_tokens = get_prompt_tokens(prompt_config)
+
+    user_input_tokens = (
+        check_number_of_tokens(actual_user_input)
+        if actual_user_input is not None
+        else GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
+    )
+
+    return max_input_tokens - prompt_tokens - user_input_tokens - _MISC_BUFFER
+
+
+def compute_max_document_tokens_for_persona(
+    persona: Persona,
+    actual_user_input: str | None = None,
+    max_llm_token_override: int | None = None,
+) -> int:
+    prompt = persona.prompts[0] if persona.prompts else get_default_prompt()
+    return compute_max_document_tokens(
+        prompt_config=PromptConfig.from_model(prompt),
+        llm_config=LLMConfig.from_persona(persona),
+        actual_user_input=actual_user_input,
+        max_llm_token_override=max_llm_token_override,
+    )
+
+
+def compute_max_llm_input_tokens(llm_config: LLMConfig) -> int:
+    """Maximum number of tokens allowed in the input to the LLM (of any type)."""
+
+    input_tokens = get_max_input_tokens(
+        model_name=llm_config.model_version, model_provider=llm_config.model_provider
+    )
+    return input_tokens - _MISC_BUFFER
+
+
+@lru_cache()
+def build_system_message(
+    prompt_config: PromptConfig,
+    context_exists: bool,
+    llm_tokenizer_encode_func: Callable,
+    citation_line: str = REQUIRE_CITATION_STATEMENT,
+    no_citation_line: str = NO_CITATION_STATEMENT,
+) -> tuple[SystemMessage | None, int]:
+    system_prompt = prompt_config.system_prompt.strip()
+    if prompt_config.include_citations:
+        if context_exists:
+            system_prompt += citation_line
+        else:
+            system_prompt += no_citation_line
+    if prompt_config.datetime_aware:
+        if system_prompt:
+            system_prompt += ADDITIONAL_INFO.format(
+                datetime_info=get_current_llm_day_time()
+            )
+        else:
+            system_prompt = get_current_llm_day_time()
+
+    if not system_prompt:
+        return None, 0
+
+    token_count = len(llm_tokenizer_encode_func(system_prompt))
+    system_msg = SystemMessage(content=system_prompt)
+
+    return system_msg, token_count
+
+
+def build_user_message(
+    question: str,
+    prompt_config: PromptConfig,
+    context_docs: list[LlmDoc] | list[InferenceChunk],
+    all_doc_useful: bool,
+    history_message: str,
+) -> tuple[HumanMessage, int]:
+    llm_tokenizer = get_default_llm_tokenizer()
+    llm_tokenizer_encode_func = cast(Callable[[str], list[int]], llm_tokenizer.encode)
+
+    if not context_docs:
+        # Simpler prompt for cases where there is no context
+        user_prompt = (
+            CHAT_USER_CONTEXT_FREE_PROMPT.format(
+                task_prompt=prompt_config.task_prompt, user_query=question
+            )
+            if prompt_config.task_prompt
+            else question
+        )
+        user_prompt = user_prompt.strip()
+        token_count = len(llm_tokenizer_encode_func(user_prompt))
+        user_msg = HumanMessage(content=user_prompt)
+        return user_msg, token_count
+
+    context_docs_str = build_complete_context_str(context_docs)
+    optional_ignore = "" if all_doc_useful else DEFAULT_IGNORE_STATEMENT
+
+    task_prompt_with_reminder = build_task_prompt_reminders(prompt_config)
+
+    user_prompt = CITATIONS_PROMPT.format(
+        optional_ignore_statement=optional_ignore,
+        context_docs_str=context_docs_str,
+        task_prompt=task_prompt_with_reminder,
+        user_query=question,
+        history_block=history_message,
+    )
+
+    user_prompt = user_prompt.strip()
+    token_count = len(llm_tokenizer_encode_func(user_prompt))
+    user_msg = HumanMessage(content=user_prompt)
+
+    return user_msg, token_count
+
+
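+# `build_citations_prompt` below stitches the pieces together: it builds the
+# system message, converts the chat history to base messages, builds the final
+# user message, then drops the oldest history messages via
+# `drop_messages_history_overflow` so the total stays within
+# `compute_max_llm_input_tokens`.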
+def build_citations_prompt(
+    question: str,
+    message_history: list[PreviousMessage],
+    prompt_config: PromptConfig,
+    llm_config: LLMConfig,
+    context_docs: list[LlmDoc] | list[InferenceChunk],
+    all_doc_useful: bool,
+    history_message: str,
+    llm_tokenizer_encode_func: Callable,
+) -> list[BaseMessage]:
+    context_exists = len(context_docs) > 0
+
+    system_message_or_none, system_tokens = build_system_message(
+        prompt_config=prompt_config,
+        context_exists=context_exists,
+        llm_tokenizer_encode_func=llm_tokenizer_encode_func,
+    )
+
+    history_basemessages, history_token_counts = translate_history_to_basemessages(
+        message_history
+    )
+
+    # be sure the context_docs passed to build_user_message here are the same
+    # ones passed in later for extracting citations
+    user_message, user_tokens = build_user_message(
+        question=question,
+        prompt_config=prompt_config,
+        context_docs=context_docs,
+        all_doc_useful=all_doc_useful,
+        history_message=history_message,
+    )
+
+    final_prompt_msgs = drop_messages_history_overflow(
+        system_msg=system_message_or_none,
+        system_token_count=system_tokens,
+        history_msgs=history_basemessages,
+        history_token_counts=history_token_counts,
+        final_msg=user_message,
+        final_msg_token_count=user_tokens,
+        max_allowed_tokens=compute_max_llm_input_tokens(llm_config),
+    )
+
+    return final_prompt_msgs
diff --git a/backend/danswer/llm/answering/prompts/quotes_prompt.py b/backend/danswer/llm/answering/prompts/quotes_prompt.py
new file mode 100644
index 00000000000..0824ffa6464
--- /dev/null
+++ b/backend/danswer/llm/answering/prompts/quotes_prompt.py
@@ -0,0 +1,88 @@
+from langchain.schema.messages import BaseMessage
+from langchain.schema.messages import HumanMessage
+
+from danswer.chat.models import LlmDoc
+from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
+from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
+from danswer.indexing.models import InferenceChunk
+from danswer.llm.answering.models import PromptConfig
+from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
+from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK
+from danswer.prompts.direct_qa_prompts import JSON_PROMPT
+from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
+from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT
+from danswer.prompts.prompt_utils import build_complete_context_str
+
+
+def _build_weak_llm_quotes_prompt(
+    question: str,
+    context_docs: list[LlmDoc] | list[InferenceChunk],
+    history_str: str,
+    prompt: PromptConfig,
+    use_language_hint: bool,
+) -> list[BaseMessage]:
+    """Since Danswer supports a variety of LLMs, this less demanding prompt is provided
+    as an option for weaker LLMs such as smaller, low-float-precision, quantized, or
+    distilled models. It only uses one context document and places very weak
+    requirements on the output format.
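+    (A heavily quantized small model, for example, will often fail to emit the
+    strict JSON that the standard quotes prompt expects.)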
+    """
+    context_block = ""
+    if context_docs:
+        context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs[0].content)
+
+    prompt_str = WEAK_LLM_PROMPT.format(
+        system_prompt=prompt.system_prompt,
+        context_block=context_block,
+        task_prompt=prompt.task_prompt,
+        user_query=question,
+    )
+    return [HumanMessage(content=prompt_str)]
+
+
+def _build_strong_llm_quotes_prompt(
+    question: str,
+    context_docs: list[LlmDoc] | list[InferenceChunk],
+    history_str: str,
+    prompt: PromptConfig,
+    use_language_hint: bool,
+) -> list[BaseMessage]:
+    context_block = ""
+    if context_docs:
+        context_docs_str = build_complete_context_str(context_docs)
+        context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs_str)
+
+    history_block = ""
+    if history_str:
+        history_block = HISTORY_BLOCK.format(history_str=history_str)
+
+    full_prompt = JSON_PROMPT.format(
+        system_prompt=prompt.system_prompt,
+        context_block=context_block,
+        history_block=history_block,
+        task_prompt=prompt.task_prompt,
+        user_query=question,
+        language_hint_or_none=LANGUAGE_HINT.strip() if use_language_hint else "",
+    ).strip()
+    return [HumanMessage(content=full_prompt)]
+
+
+def build_quotes_prompt(
+    question: str,
+    context_docs: list[LlmDoc] | list[InferenceChunk],
+    history_str: str,
+    prompt: PromptConfig,
+    use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
+) -> list[BaseMessage]:
+    prompt_builder = (
+        _build_weak_llm_quotes_prompt
+        if QA_PROMPT_OVERRIDE == "weak"
+        else _build_strong_llm_quotes_prompt
+    )
+
+    return prompt_builder(
+        question=question,
+        context_docs=context_docs,
+        history_str=history_str,
+        prompt=prompt,
+        use_language_hint=use_language_hint,
+    )
diff --git a/backend/danswer/llm/answering/prompts/utils.py b/backend/danswer/llm/answering/prompts/utils.py
new file mode 100644
index 00000000000..bcc8b891815
--- /dev/null
+++ b/backend/danswer/llm/answering/prompts/utils.py
@@ -0,0 +1,20 @@
+from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
+from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT
+
+
+def build_dummy_prompt(
+    system_prompt: str, task_prompt: str, retrieval_disabled: bool
+) -> str:
+    if retrieval_disabled:
+        return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
+            user_query="<USER_QUERY>",
+            system_prompt=system_prompt,
+            task_prompt=task_prompt,
+        ).strip()
+
+    return PARAMATERIZED_PROMPT.format(
+        context_docs_str="<CONTEXT_DOCS>",
+        user_query="<USER_QUERY>",
+        system_prompt=system_prompt,
+        task_prompt=task_prompt,
+    ).strip()
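Taken together, these modules let callers drive retrieval-augmented answering through a single `Answer` object. Below is a minimal sketch of citation-style usage; the import path for `Answer` is assumed from this diff's file layout, and the `question`/`docs`/`persona`/`prompt` arguments stand in for values normally produced by retrieval and loaded from the DB:

```python
from collections.abc import Iterator

from danswer.chat.models import LlmDoc
from danswer.db.models import Persona
from danswer.db.models import Prompt
from danswer.llm.answering.answer import Answer  # path assumed from this diff's layout
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import CitationConfig
from danswer.llm.answering.models import LLMConfig
from danswer.llm.answering.models import PromptConfig


def stream_citation_answer(
    question: str, docs: list[LlmDoc], persona: Persona, prompt: Prompt
) -> Iterator[str]:
    answer = Answer(
        question=question,
        docs=docs,
        answer_style_config=AnswerStyleConfig(citation_config=CitationConfig()),
        llm_config=LLMConfig.from_persona(persona),
        prompt_config=PromptConfig.from_model(prompt),
    )
    # pruning, prompt building, and stream post-processing all happen lazily
    # inside `Answer`; packets are DanswerAnswerPiece / CitationInfo objects
    for packet in answer.processed_streamed_output:
        yield str(packet)
    # once the stream has been consumed, `answer.llm_answer` and
    # `answer.citations` are served from the cached packets
```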
diff --git a/web/src/app/admin/personas/[personaId]/page.tsx b/web/src/app/admin/personas/[personaId]/page.tsx
index 0b521319272..e9102d219bb 100644
--- a/web/src/app/admin/personas/[personaId]/page.tsx
+++ b/web/src/app/admin/personas/[personaId]/page.tsx
@@ -6,7 +6,6 @@ import { DocumentSet } from "@/lib/types";
 import { BackButton } from "@/components/BackButton";
 import { Card, Title } from "@tremor/react";
 import { DeletePersonaButton } from "./DeletePersonaButton";
-import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
 
 export default async function Page({
   params,
@@ -68,8 +67,6 @@ export default async function Page({
   return (
-      <InstantSSRAutoRefresh />
-