diff --git a/docker-compose-no-searxng.yaml b/docker-compose-no-searxng.yaml deleted file mode 100644 index 683a596..0000000 --- a/docker-compose-no-searxng.yaml +++ /dev/null @@ -1,46 +0,0 @@ -services: - backend: - build: - context: . - dockerfile: ./src/backend/Dockerfile - restart: always - ports: - - "8000:8000" - environment: - - OLLAMA_API_BASE=${OLLAMA_API_BASE:-http://host.docker.internal:11434} - - TAVILY_API_KEY=${TAVILY_API_KEY} - - BING_API_KEY=${BING_API_KEY} - - SERPER_API_KEY=${SERPER_API_KEY} - - OPENAI_API_KEY=${OPENAI_API_KEY} - - GROQ_API_KEY=${GROQ_API_KEY} - - ENABLE_LOCAL_MODELS=${ENABLE_LOCAL_MODELS:-True} - - SEARCH_PROVIDER=${SEARCH_PROVIDER:-tavily} - - SEARXNG_BASE_URL=${SEARXNG_BASE_URL:-http://host.docker.internal:8080} - - CUSTOM_MODEL=${CUSTOM_MODEL} - - REDIS_URL=${REDIS_URL} - develop: - watch: - - action: sync - path: ./src/backend - target: /workspace/src/backend - extra_hosts: - - "host.docker.internal:host-gateway" - frontend: - depends_on: - - backend - build: - context: . - dockerfile: ./src/frontend/Dockerfile - restart: always - environment: - - NEXT_PUBLIC_API_URL=${NEXT_PUBLIC_API_URL:-http://localhost:8000} - - NEXT_PUBLIC_LOCAL_MODE_ENABLED=${NEXT_PUBLIC_LOCAL_MODE_ENABLED:-true} - ports: - - "3000:3000" - develop: - watch: - - action: sync - path: ./src/frontend - target: /app - ignore: - - node_modules/ diff --git a/docker-compose.dev.yaml b/docker-compose.dev.yaml index 1799b77..f14f830 100644 --- a/docker-compose.dev.yaml +++ b/docker-compose.dev.yaml @@ -1,5 +1,7 @@ services: backend: + depends_on: + - postgres build: context: . dockerfile: ./src/backend/Dockerfile @@ -18,6 +20,14 @@ services: - SEARXNG_BASE_URL=${SEARXNG_BASE_URL:-http://host.docker.internal:8080} - CUSTOM_MODEL=${CUSTOM_MODEL} - REDIS_URL=${REDIS_URL} + - DB_ENABLED=${DB_ENABLED:-True} + - POSTGRES_HOST=postgres + entrypoint: > + /bin/sh -c " + cd /workspace/src/backend && + alembic upgrade head && + uvicorn main:app --host 0.0.0.0 --port 8000 + " develop: watch: - action: sync @@ -58,5 +68,20 @@ services: environment: - SEARXNG_BASE_URL=https://${SEARXNG_BASE_URL:-localhost}/ + postgres: + image: postgres:15.2-alpine + restart: always + environment: + - POSTGRES_USER=${POSTGRES_USER:-postgres} + - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password} + - POSTGRES_DB=${POSTGRES_DB:-postgres} + ports: + - "5432:5432" + volumes: + - db_volume:/var/lib/postgresql/data + networks: searxng: + +volumes: + db_volume: diff --git a/docker-scripts/env-defaults b/docker-scripts/env-defaults index c6b54c4..32707c4 100644 --- a/docker-scripts/env-defaults +++ b/docker-scripts/env-defaults @@ -21,3 +21,6 @@ export NEXT_PUBLIC_LOCAL_MODE_ENABLED=${NEXT_PUBLIC_LOCAL_MODE_ENABLED:-true} # Redis export REDIS_URL=${REDIS_URL} + +# Database +export DB_ENABLED=${DB_ENABLED:-True} diff --git a/poetry.lock b/poetry.lock index 040197e..7782766 100644 --- a/poetry.lock +++ b/poetry.lock @@ -109,6 +109,25 @@ files = [ [package.dependencies] frozenlist = ">=1.1.0" +[[package]] +name = "alembic" +version = "1.13.1" +description = "A database migration tool for SQLAlchemy." +optional = false +python-versions = ">=3.8" +files = [ + {file = "alembic-1.13.1-py3-none-any.whl", hash = "sha256:2edcc97bed0bd3272611ce3a98d98279e9c209e7186e43e75bbb1b2bdfdbcc43"}, + {file = "alembic-1.13.1.tar.gz", hash = "sha256:4932c8558bf68f2ee92b9bbcb8218671c627064d5b08939437af6d77dc05e595"}, +] + +[package.dependencies] +Mako = "*" +SQLAlchemy = ">=1.3.0" +typing-extensions = ">=4" + +[package.extras] +tz = ["backports.zoneinfo"] + [[package]] name = "annotated-types" version = "0.6.0" @@ -1967,6 +1986,25 @@ sqlalchemy = ["opentelemetry-instrumentation-sqlalchemy (>=0.42b0)"] starlette = ["opentelemetry-instrumentation-starlette (>=0.42b0)"] system-metrics = ["opentelemetry-instrumentation-system-metrics (>=0.42b0)"] +[[package]] +name = "mako" +version = "1.3.5" +description = "A super-fast templating language that borrows the best ideas from the existing templating languages." +optional = false +python-versions = ">=3.8" +files = [ + {file = "Mako-1.3.5-py3-none-any.whl", hash = "sha256:260f1dbc3a519453a9c856dedfe4beb4e50bd5a26d96386cb6c80856556bb91a"}, + {file = "Mako-1.3.5.tar.gz", hash = "sha256:48dbc20568c1d276a2698b36d968fa76161bf127194907ea6fc594fa81f943bc"}, +] + +[package.dependencies] +MarkupSafe = ">=0.9.2" + +[package.extras] +babel = ["Babel"] +lingua = ["lingua"] +testing = ["pytest"] + [[package]] name = "marisa-trie" version = "1.1.1" @@ -3032,6 +3070,87 @@ files = [ {file = "protobuf-4.25.3.tar.gz", hash = "sha256:25b5d0b42fd000320bd7830b349e3b696435f3b329810427a6bcce6a5492cc5c"}, ] +[[package]] +name = "psycopg2-binary" +version = "2.9.9" +description = "psycopg2 - Python-PostgreSQL Database Adapter" +optional = false +python-versions = ">=3.7" +files = [ + {file = "psycopg2-binary-2.9.9.tar.gz", hash = "sha256:7f01846810177d829c7692f1f5ada8096762d9172af1b1a28d4ab5b77c923c1c"}, + {file = "psycopg2_binary-2.9.9-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c2470da5418b76232f02a2fcd2229537bb2d5a7096674ce61859c3229f2eb202"}, + {file = "psycopg2_binary-2.9.9-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c6af2a6d4b7ee9615cbb162b0738f6e1fd1f5c3eda7e5da17861eacf4c717ea7"}, + {file = "psycopg2_binary-2.9.9-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:75723c3c0fbbf34350b46a3199eb50638ab22a0228f93fb472ef4d9becc2382b"}, + {file = "psycopg2_binary-2.9.9-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:83791a65b51ad6ee6cf0845634859d69a038ea9b03d7b26e703f94c7e93dbcf9"}, + {file = "psycopg2_binary-2.9.9-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0ef4854e82c09e84cc63084a9e4ccd6d9b154f1dbdd283efb92ecd0b5e2b8c84"}, + {file = "psycopg2_binary-2.9.9-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ed1184ab8f113e8d660ce49a56390ca181f2981066acc27cf637d5c1e10ce46e"}, + {file = "psycopg2_binary-2.9.9-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d2997c458c690ec2bc6b0b7ecbafd02b029b7b4283078d3b32a852a7ce3ddd98"}, + {file = "psycopg2_binary-2.9.9-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:b58b4710c7f4161b5e9dcbe73bb7c62d65670a87df7bcce9e1faaad43e715245"}, + {file = "psycopg2_binary-2.9.9-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:0c009475ee389757e6e34611d75f6e4f05f0cf5ebb76c6037508318e1a1e0d7e"}, + {file = "psycopg2_binary-2.9.9-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8dbf6d1bc73f1d04ec1734bae3b4fb0ee3cb2a493d35ede9badbeb901fb40f6f"}, + {file = "psycopg2_binary-2.9.9-cp310-cp310-win32.whl", hash = "sha256:3f78fd71c4f43a13d342be74ebbc0666fe1f555b8837eb113cb7416856c79682"}, + {file = "psycopg2_binary-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:876801744b0dee379e4e3c38b76fc89f88834bb15bf92ee07d94acd06ec890a0"}, + {file = "psycopg2_binary-2.9.9-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ee825e70b1a209475622f7f7b776785bd68f34af6e7a46e2e42f27b659b5bc26"}, + {file = "psycopg2_binary-2.9.9-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1ea665f8ce695bcc37a90ee52de7a7980be5161375d42a0b6c6abedbf0d81f0f"}, + {file = "psycopg2_binary-2.9.9-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:143072318f793f53819048fdfe30c321890af0c3ec7cb1dfc9cc87aa88241de2"}, + {file = "psycopg2_binary-2.9.9-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c332c8d69fb64979ebf76613c66b985414927a40f8defa16cf1bc028b7b0a7b0"}, + {file = "psycopg2_binary-2.9.9-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f7fc5a5acafb7d6ccca13bfa8c90f8c51f13d8fb87d95656d3950f0158d3ce53"}, + {file = "psycopg2_binary-2.9.9-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:977646e05232579d2e7b9c59e21dbe5261f403a88417f6a6512e70d3f8a046be"}, + {file = "psycopg2_binary-2.9.9-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:b6356793b84728d9d50ead16ab43c187673831e9d4019013f1402c41b1db9b27"}, + {file = "psycopg2_binary-2.9.9-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:bc7bb56d04601d443f24094e9e31ae6deec9ccb23581f75343feebaf30423359"}, + {file = "psycopg2_binary-2.9.9-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:77853062a2c45be16fd6b8d6de2a99278ee1d985a7bd8b103e97e41c034006d2"}, + {file = "psycopg2_binary-2.9.9-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:78151aa3ec21dccd5cdef6c74c3e73386dcdfaf19bced944169697d7ac7482fc"}, + {file = "psycopg2_binary-2.9.9-cp311-cp311-win32.whl", hash = "sha256:dc4926288b2a3e9fd7b50dc6a1909a13bbdadfc67d93f3374d984e56f885579d"}, + {file = "psycopg2_binary-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:b76bedd166805480ab069612119ea636f5ab8f8771e640ae103e05a4aae3e417"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8532fd6e6e2dc57bcb3bc90b079c60de896d2128c5d9d6f24a63875a95a088cf"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0605eaed3eb239e87df0d5e3c6489daae3f7388d455d0c0b4df899519c6a38d"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f8544b092a29a6ddd72f3556a9fcf249ec412e10ad28be6a0c0d948924f2212"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d423c8d8a3c82d08fe8af900ad5b613ce3632a1249fd6a223941d0735fce493"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e5afae772c00980525f6d6ecf7cbca55676296b580c0e6abb407f15f3706996"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e6f98446430fdf41bd36d4faa6cb409f5140c1c2cf58ce0bbdaf16af7d3f119"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:c77e3d1862452565875eb31bdb45ac62502feabbd53429fdc39a1cc341d681ba"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cb16c65dcb648d0a43a2521f2f0a2300f40639f6f8c1ecbc662141e4e3e1ee07"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:911dda9c487075abd54e644ccdf5e5c16773470a6a5d3826fda76699410066fb"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57fede879f08d23c85140a360c6a77709113efd1c993923c59fde17aa27599fe"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-win32.whl", hash = "sha256:64cf30263844fa208851ebb13b0732ce674d8ec6a0c86a4e160495d299ba3c93"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:81ff62668af011f9a48787564ab7eded4e9fb17a4a6a74af5ffa6a457400d2ab"}, + {file = "psycopg2_binary-2.9.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2293b001e319ab0d869d660a704942c9e2cce19745262a8aba2115ef41a0a42a"}, + {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ef7df18daf2c4c07e2695e8cfd5ee7f748a1d54d802330985a78d2a5a6dca9"}, + {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a602ea5aff39bb9fac6308e9c9d82b9a35c2bf288e184a816002c9fae930b77"}, + {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8359bf4791968c5a78c56103702000105501adb557f3cf772b2c207284273984"}, + {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:275ff571376626195ab95a746e6a04c7df8ea34638b99fc11160de91f2fef503"}, + {file = "psycopg2_binary-2.9.9-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:f9b5571d33660d5009a8b3c25dc1db560206e2d2f89d3df1cb32d72c0d117d52"}, + {file = "psycopg2_binary-2.9.9-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:420f9bbf47a02616e8554e825208cb947969451978dceb77f95ad09c37791dae"}, + {file = "psycopg2_binary-2.9.9-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:4154ad09dac630a0f13f37b583eae260c6aa885d67dfbccb5b02c33f31a6d420"}, + {file = "psycopg2_binary-2.9.9-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:a148c5d507bb9b4f2030a2025c545fccb0e1ef317393eaba42e7eabd28eb6041"}, + {file = "psycopg2_binary-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:68fc1f1ba168724771e38bee37d940d2865cb0f562380a1fb1ffb428b75cb692"}, + {file = "psycopg2_binary-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:281309265596e388ef483250db3640e5f414168c5a67e9c665cafce9492eda2f"}, + {file = "psycopg2_binary-2.9.9-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:60989127da422b74a04345096c10d416c2b41bd7bf2a380eb541059e4e999980"}, + {file = "psycopg2_binary-2.9.9-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:246b123cc54bb5361588acc54218c8c9fb73068bf227a4a531d8ed56fa3ca7d6"}, + {file = "psycopg2_binary-2.9.9-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:34eccd14566f8fe14b2b95bb13b11572f7c7d5c36da61caf414d23b91fcc5d94"}, + {file = "psycopg2_binary-2.9.9-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18d0ef97766055fec15b5de2c06dd8e7654705ce3e5e5eed3b6651a1d2a9a152"}, + {file = "psycopg2_binary-2.9.9-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d3f82c171b4ccd83bbaf35aa05e44e690113bd4f3b7b6cc54d2219b132f3ae55"}, + {file = "psycopg2_binary-2.9.9-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ead20f7913a9c1e894aebe47cccf9dc834e1618b7aa96155d2091a626e59c972"}, + {file = "psycopg2_binary-2.9.9-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:ca49a8119c6cbd77375ae303b0cfd8c11f011abbbd64601167ecca18a87e7cdd"}, + {file = "psycopg2_binary-2.9.9-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:323ba25b92454adb36fa425dc5cf6f8f19f78948cbad2e7bc6cdf7b0d7982e59"}, + {file = "psycopg2_binary-2.9.9-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:1236ed0952fbd919c100bc839eaa4a39ebc397ed1c08a97fc45fee2a595aa1b3"}, + {file = "psycopg2_binary-2.9.9-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:729177eaf0aefca0994ce4cffe96ad3c75e377c7b6f4efa59ebf003b6d398716"}, + {file = "psycopg2_binary-2.9.9-cp38-cp38-win32.whl", hash = "sha256:804d99b24ad523a1fe18cc707bf741670332f7c7412e9d49cb5eab67e886b9b5"}, + {file = "psycopg2_binary-2.9.9-cp38-cp38-win_amd64.whl", hash = "sha256:a6cdcc3ede532f4a4b96000b6362099591ab4a3e913d70bcbac2b56c872446f7"}, + {file = "psycopg2_binary-2.9.9-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:72dffbd8b4194858d0941062a9766f8297e8868e1dd07a7b36212aaa90f49472"}, + {file = "psycopg2_binary-2.9.9-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:30dcc86377618a4c8f3b72418df92e77be4254d8f89f14b8e8f57d6d43603c0f"}, + {file = "psycopg2_binary-2.9.9-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:31a34c508c003a4347d389a9e6fcc2307cc2150eb516462a7a17512130de109e"}, + {file = "psycopg2_binary-2.9.9-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:15208be1c50b99203fe88d15695f22a5bed95ab3f84354c494bcb1d08557df67"}, + {file = "psycopg2_binary-2.9.9-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1873aade94b74715be2246321c8650cabf5a0d098a95bab81145ffffa4c13876"}, + {file = "psycopg2_binary-2.9.9-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a58c98a7e9c021f357348867f537017057c2ed7f77337fd914d0bedb35dace7"}, + {file = "psycopg2_binary-2.9.9-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:4686818798f9194d03c9129a4d9a702d9e113a89cb03bffe08c6cf799e053291"}, + {file = "psycopg2_binary-2.9.9-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:ebdc36bea43063116f0486869652cb2ed7032dbc59fbcb4445c4862b5c1ecf7f"}, + {file = "psycopg2_binary-2.9.9-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:ca08decd2697fdea0aea364b370b1249d47336aec935f87b8bbfd7da5b2ee9c1"}, + {file = "psycopg2_binary-2.9.9-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ac05fb791acf5e1a3e39402641827780fe44d27e72567a000412c648a85ba860"}, + {file = "psycopg2_binary-2.9.9-cp39-cp39-win32.whl", hash = "sha256:9dba73be7305b399924709b91682299794887cbbd88e38226ed9f6712eabee90"}, + {file = "psycopg2_binary-2.9.9-cp39-cp39-win_amd64.whl", hash = "sha256:f7ae5d65ccfbebdfa761585228eb4d0df3a8b15cfb53bd953e713e09fbb12957"}, +] + [[package]] name = "pycparser" version = "2.22" @@ -5100,4 +5219,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "aec7c9df55261292eeffe9c97b6f7d80b88ea785808949a546a6a1cf574144e7" +content-hash = "c56e51f4917453a73263f42e73596ac0ddd4fdc88ddbf57049bf90444211f235" diff --git a/pyproject.toml b/pyproject.toml index 4113a08..868ce09 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,8 @@ slowapi = "^0.1.9" redis = "^5.0.4" llama-index-llms-ollama = "^0.1.3" llama-index-llms-litellm = "^0.1.4" +alembic = "^1.13.1" +psycopg2-binary = "^2.9.9" [tool.poetry.group.dev.dependencies] pre-commit = "^3.7.1" diff --git a/src/backend/alembic.ini b/src/backend/alembic.ini new file mode 100644 index 0000000..7bb0089 --- /dev/null +++ b/src/backend/alembic.ini @@ -0,0 +1,110 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +script_location = alembic + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python-dateutil library that can be +# installed by adding `alembic[tz]` to the pip requirements +# string value is passed to dateutil.tz.gettz() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to alembic/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +; sqlalchemy.url = driver://user:pass@localhost/dbname + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/src/backend/alembic/README b/src/backend/alembic/README new file mode 100644 index 0000000..98e4f9c --- /dev/null +++ b/src/backend/alembic/README @@ -0,0 +1 @@ +Generic single-database configuration. \ No newline at end of file diff --git a/src/backend/alembic/env.py b/src/backend/alembic/env.py new file mode 100644 index 0000000..55f790b --- /dev/null +++ b/src/backend/alembic/env.py @@ -0,0 +1,79 @@ +from logging.config import fileConfig + +from alembic import context +from sqlalchemy import engine_from_config, pool + +from backend.db.engine import create_connection_string +from backend.db.models import Base + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +target_metadata = Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + +config.set_main_option("sqlalchemy.url", create_connection_string()) + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = create_connection_string() + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/src/backend/alembic/script.py.mako b/src/backend/alembic/script.py.mako new file mode 100644 index 0000000..fbc4b07 --- /dev/null +++ b/src/backend/alembic/script.py.mako @@ -0,0 +1,26 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/src/backend/alembic/versions/17892ab566d1_initial_models.py b/src/backend/alembic/versions/17892ab566d1_initial_models.py new file mode 100644 index 0000000..533705b --- /dev/null +++ b/src/backend/alembic/versions/17892ab566d1_initial_models.py @@ -0,0 +1,83 @@ +"""initial models + +Revision ID: 17892ab566d1 +Revises: +Create Date: 2024-06-25 23:10:27.366511 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "17892ab566d1" +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "chat_thread", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("model_name", sa.String(), nullable=False), + sa.Column( + "time_updated", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column( + "time_created", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "chat_message", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column( + "role", sa.Enum("USER", "ASSISTANT", name="messagerole"), nullable=False + ), + sa.Column("content", sa.String(), nullable=False), + sa.Column("parent_message_id", sa.Integer(), nullable=True), + sa.Column("chat_thread_id", sa.Integer(), nullable=False), + sa.Column("related_queries", sa.ARRAY(sa.String()), nullable=True), + sa.Column("image_results", sa.ARRAY(sa.String()), nullable=True), + sa.ForeignKeyConstraint( + ["chat_thread_id"], + ["chat_thread.id"], + ), + sa.ForeignKeyConstraint( + ["parent_message_id"], + ["chat_message.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "search_result", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("title", sa.String(), nullable=False), + sa.Column("url", sa.String(), nullable=False), + sa.Column("content", sa.String(), nullable=False), + sa.Column("chat_message_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["chat_message_id"], + ["chat_message.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("search_result") + op.drop_table("chat_message") + op.drop_table("chat_thread") + # ### end Alembic commands ### diff --git a/src/backend/chat.py b/src/backend/chat.py index a9c7213..bc7c559 100644 --- a/src/backend/chat.py +++ b/src/backend/chat.py @@ -1,9 +1,12 @@ import asyncio +import os from typing import AsyncIterator, List from fastapi import HTTPException +from sqlalchemy.orm import Session from backend.constants import get_model_string +from backend.db.chat import append_message, create_chat_thread, create_message from backend.llm.base import BaseLLM, EveryLLM from backend.prompts import CHAT_PROMPT, HISTORY_QUERY_REPHRASE from backend.related_queries import generate_related_queries @@ -13,6 +16,7 @@ ChatResponseEvent, FinalResponseStream, Message, + MessageRole, RelatedQueriesStream, SearchResult, SearchResultStream, @@ -21,7 +25,7 @@ TextChunkStream, ) from backend.search.search_service import perform_search -from backend.utils import is_local_model +from backend.utils import is_local_model, strtobool def rephrase_query_with_history( @@ -49,9 +53,12 @@ def format_context(search_results: List[SearchResult]) -> str: ) -async def stream_qa_objects(request: ChatRequest) -> AsyncIterator[ChatResponseEvent]: +async def stream_qa_objects( + request: ChatRequest, session: Session +) -> AsyncIterator[ChatResponseEvent]: try: - llm = EveryLLM(model=get_model_string(request.model)) + model_name = get_model_string(request.model) + llm = EveryLLM(model=model_name) yield ChatResponseEvent( event=StreamEvent.BEGIN_STREAM, @@ -105,15 +112,43 @@ async def stream_qa_objects(request: ChatRequest) -> AsyncIterator[ChatResponseE data=RelatedQueriesStream(related_queries=related_queries), ) - yield ChatResponseEvent( - event=StreamEvent.STREAM_END, - data=StreamEndStream(), - ) + thread_id = None + DB_ENABLED = strtobool(os.environ.get("DB_ENABLED", "true")) + if DB_ENABLED: + if request.thread_id is None: + thread = create_chat_thread(session=session, model_name=request.model) + thread_id = thread.id + else: + thread_id = request.thread_id + + user_message = append_message( + session=session, + thread_id=thread_id, + role=MessageRole.USER, + content=request.query, + ) + + _assistant_message = create_message( + session=session, + thread_id=thread_id, + role=MessageRole.ASSISTANT, + content=full_response, + parent_message_id=user_message.id, + search_results=search_results, + image_results=images, + related_queries=related_queries, + ) yield ChatResponseEvent( event=StreamEvent.FINAL_RESPONSE, data=FinalResponseStream(message=full_response), ) + + yield ChatResponseEvent( + event=StreamEvent.STREAM_END, + data=StreamEndStream(thread_id=thread_id), + ) + except Exception as e: detail = str(e) raise HTTPException(status_code=500, detail=detail) diff --git a/src/backend/db/chat.py b/src/backend/db/chat.py new file mode 100644 index 0000000..92374fb --- /dev/null +++ b/src/backend/db/chat.py @@ -0,0 +1,167 @@ +import re + +from sqlalchemy import select +from sqlalchemy.orm import Session, contains_eager + +from backend.db.models import ChatMessage as DBChatMessage +from backend.db.models import ChatThread as DBChatThread +from backend.db.models import SearchResult as DBSearchResult +from backend.schemas import ( + ChatMessage, + ChatSnapshot, + MessageRole, + SearchResult, + ThreadResponse, +) + + +def create_chat_thread(*, session: Session, model_name: str): + chat_thread = DBChatThread(model_name=model_name) + session.add(chat_thread) + session.commit() + return chat_thread + + +def create_search_results( + *, session: Session, search_results: list[SearchResult], chat_message_id: int +) -> list[DBSearchResult]: + db_search_results = [ + DBSearchResult( + url=result.url, + title=result.title, + content=result.content, + chat_message_id=chat_message_id, + ) + for result in search_results + ] + session.add_all(db_search_results) + session.commit() + return db_search_results + + +def append_message( + *, + session: Session, + thread_id: int, + role: MessageRole, + content: str, + search_results: list[SearchResult] | None = None, + image_results: list[str] | None = None, + related_queries: list[str] | None = None, +): + last_message = ( + session.query(DBChatMessage) + .filter(DBChatMessage.chat_thread_id == thread_id) + .order_by(DBChatMessage.id.desc()) + .first() + ) + + return create_message( + session=session, + thread_id=thread_id, + role=role, + content=content, + parent_message_id=last_message.id if last_message else None, + search_results=search_results, + image_results=image_results, + related_queries=related_queries, + ) + + +def create_message( + *, + session: Session, + thread_id: int, + role: MessageRole, + content: str, + parent_message_id: int | None = None, + search_results: list[SearchResult] | None = None, + image_results: list[str] | None = None, + related_queries: list[str] | None = None, +): + message = DBChatMessage( + chat_thread_id=thread_id, + role=role, + content=content, + parent_message_id=parent_message_id, + image_results=image_results or [], + related_queries=related_queries or [], + ) + + session.add(message) + session.flush() + + db_search_results = None + if search_results is not None: + db_search_results = create_search_results( + session=session, search_results=search_results, chat_message_id=message.id + ) + message.search_results = db_search_results or [] + + session.add(message) + session.commit() + return message + + +def get_chat_history(*, session: Session) -> list[ChatSnapshot]: + threads = ( + session.query(DBChatThread) + .join(DBChatThread.messages) + .options(contains_eager(DBChatThread.messages)) + .order_by(DBChatThread.time_created.desc(), DBChatMessage.id.asc()) + .all() + ) + threads = [thread for thread in threads if len(thread.messages) > 1] + + snapshots = [] + for thread in threads: + title = thread.messages[0].content + preview = thread.messages[1].content + + # Remove citations from the preview + citation_regex = re.compile(r"\[[0-9]+\]") + preview = citation_regex.sub("", preview) + + snapshots.append( + ChatSnapshot( + id=thread.id, + title=title, + date=thread.time_created, + preview=preview, + model_name=thread.model_name, + ) + ) + return snapshots + + +def map_search_result(search_result: DBSearchResult) -> SearchResult: + return SearchResult( + url=search_result.url, + title=search_result.title, + content=search_result.content, + ) + + +def get_thread(*, session: Session, thread_id: int) -> ThreadResponse: + stmt = ( + select(DBChatMessage) + .where(DBChatMessage.chat_thread_id == thread_id) + .order_by(DBChatMessage.id.asc()) + ) + db_messages = session.execute(stmt).scalars().all() + if len(db_messages) == 0: + raise ValueError(f"Thread with id {thread_id} not found") + + messages = [ + ChatMessage( + content=message.content, + role=message.role, + related_queries=message.related_queries or [], + sources=[ + map_search_result(result) for result in message.search_results or [] + ], + images=message.image_results or [], + ) + for message in db_messages + ] + return ThreadResponse(thread_id=thread_id, messages=messages) diff --git a/src/backend/db/engine.py b/src/backend/db/engine.py new file mode 100644 index 0000000..6f35a93 --- /dev/null +++ b/src/backend/db/engine.py @@ -0,0 +1,25 @@ +import os + +from dotenv import load_dotenv +from sqlalchemy import create_engine +from sqlalchemy.orm import Session + +load_dotenv() + +POSTGRES_USER = os.environ.get("POSTGRES_USER") or "postgres" +POSTGRES_PASSWORD = os.environ.get("POSTGRES_PASSWORD") or "password" +POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost" +POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432" +POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres" + + +def create_connection_string(): + return f"postgresql+psycopg2://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}:{POSTGRES_PORT}/{POSTGRES_DB}" + + +engine = create_engine(create_connection_string()) + + +def get_session(): + with Session(engine) as session: + yield session diff --git a/src/backend/db/models.py b/src/backend/db/models.py new file mode 100644 index 0000000..58caab6 --- /dev/null +++ b/src/backend/db/models.py @@ -0,0 +1,67 @@ +import datetime + +from sqlalchemy import ARRAY, DateTime, Enum, ForeignKey, String, func +from sqlalchemy.orm import Mapped, declarative_base, mapped_column, relationship + +from backend.schemas import MessageRole + +Base = declarative_base() + + +class ChatThread(Base): + __tablename__ = "chat_thread" + id: Mapped[int] = mapped_column(primary_key=True) + + messages: Mapped[list["ChatMessage"]] = relationship( + "ChatMessage", back_populates="chat_thread" + ) + model_name: Mapped[str] = mapped_column(String) + + time_updated: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), + server_default=func.now(), + onupdate=func.now(), + ) + time_created: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now() + ) + + +class SearchResult(Base): + __tablename__ = "search_result" + id: Mapped[int] = mapped_column(primary_key=True) + title: Mapped[str] = mapped_column(String) + url: Mapped[str] = mapped_column(String) + content: Mapped[str] = mapped_column(String) + + chat_message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id")) + chat_message: Mapped["ChatMessage"] = relationship( + "ChatMessage", back_populates="search_results" + ) + + +class ChatMessage(Base): + __tablename__ = "chat_message" + id: Mapped[int] = mapped_column(primary_key=True) + role: Mapped[MessageRole] = mapped_column(Enum(MessageRole)) + content: Mapped[str] = mapped_column(String) + parent_message_id: Mapped[int | None] = mapped_column( + ForeignKey("chat_message.id"), nullable=True + ) + + chat_thread_id: Mapped[int] = mapped_column(ForeignKey("chat_thread.id")) + chat_thread: Mapped[ChatThread] = relationship( + ChatThread, back_populates="messages" + ) + + # AI Only + related_queries: Mapped[list[str] | None] = mapped_column( + ARRAY(String), nullable=True + ) + image_results: Mapped[list[str] | None] = mapped_column( + ARRAY(String), nullable=True + ) + + search_results: Mapped[list[SearchResult] | None] = relationship( + SearchResult, back_populates="chat_message" + ) diff --git a/src/backend/main.py b/src/backend/main.py index 694d25d..eb60151 100644 --- a/src/backend/main.py +++ b/src/backend/main.py @@ -6,16 +6,26 @@ import logfire from dotenv import load_dotenv -from fastapi import FastAPI, Request +from fastapi import Depends, FastAPI, HTTPException, Request from fastapi.encoders import jsonable_encoder from fastapi.middleware.cors import CORSMiddleware from slowapi import Limiter from slowapi.errors import RateLimitExceeded from slowapi.util import get_ipaddr +from sqlalchemy.orm import Session from sse_starlette.sse import EventSourceResponse, ServerSentEvent from backend.chat import stream_qa_objects -from backend.schemas import ChatRequest, ChatResponseEvent, ErrorStream, StreamEvent +from backend.db.chat import get_chat_history, get_thread +from backend.db.engine import get_session +from backend.schemas import ( + ChatHistoryResponse, + ChatRequest, + ChatResponseEvent, + ErrorStream, + StreamEvent, + ThreadResponse, +) from backend.utils import strtobool from backend.validators import validate_model @@ -58,7 +68,7 @@ def configure_rate_limiting( storage_uri=redis_url, ) app.state.limiter = limiter - app.add_exception_handler(RateLimitExceeded, rate_limit_exceeded_handler) + app.add_exception_handler(RateLimitExceeded, rate_limit_exceeded_handler) # type: ignore def configure_middleware(app: FastAPI): @@ -89,12 +99,12 @@ def create_app() -> FastAPI: @app.post("/chat") @app.state.limiter.limit("4/min") async def chat( - chat_request: ChatRequest, request: Request + chat_request: ChatRequest, request: Request, session: Session = Depends(get_session) ) -> Generator[ChatResponseEvent, None, None]: async def generator(): try: validate_model(chat_request.model) - async for obj in stream_qa_objects(chat_request): + async for obj in stream_qa_objects(request=chat_request, session=session): if await request.is_disconnected(): break yield json.dumps(jsonable_encoder(obj)) @@ -105,4 +115,28 @@ async def generator(): await asyncio.sleep(0) return - return EventSourceResponse(generator(), media_type="text/event-stream") + return EventSourceResponse(generator(), media_type="text/event-stream") # type: ignore + + +@app.get("/history") +async def recents(session: Session = Depends(get_session)) -> ChatHistoryResponse: + DB_ENABLED = strtobool(os.environ.get("DB_ENABLED", "true")) + if DB_ENABLED: + try: + history = get_chat_history(session=session) + return ChatHistoryResponse(snapshots=history) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + else: + raise HTTPException( + status_code=400, + detail="Chat history is not available when DB is disabled. Please try self-hosting the app by following the instructions here: https://github.com/rashadphz/farfalle", + ) + + +@app.get("/thread/{thread_id}") +async def thread( + thread_id: int, session: Session = Depends(get_session) +) -> ThreadResponse: + thread = get_thread(session=session, thread_id=thread_id) + return thread diff --git a/src/backend/schemas.py b/src/backend/schemas.py index 89d9406..7870054 100644 --- a/src/backend/schemas.py +++ b/src/backend/schemas.py @@ -1,6 +1,7 @@ # Some of the code here is based on github.com/cohere-ai/cohere-toolkit/ import os +from datetime import datetime from enum import Enum from typing import List, Union @@ -31,6 +32,7 @@ class Message(BaseModel): class ChatRequest(BaseModel, plugin_settings=record_all): + thread_id: int | None = None query: str history: List[Message] = Field(default_factory=list) model: ChatModel = ChatModel.GPT_3_5_TURBO @@ -90,6 +92,7 @@ class RelatedQueriesStream(ChatObject, plugin_settings=record_all): class StreamEndStream(ChatObject, plugin_settings=record_all): + thread_id: int | None = None event_type: StreamEvent = StreamEvent.STREAM_END @@ -114,3 +117,29 @@ class ChatResponseEvent(BaseModel): FinalResponseStream, ErrorStream, ] + + +class ChatSnapshot(BaseModel): + id: int + title: str + date: datetime + preview: str + model_name: str + + +class ChatHistoryResponse(BaseModel): + snapshots: List[ChatSnapshot] = Field(default_factory=list) + + +class ChatMessage(BaseModel): + content: str + role: MessageRole + related_queries: List[str] | None = None + sources: List[SearchResult] | None = None + images: List[str] | None = None + is_error_message: bool = False + + +class ThreadResponse(BaseModel): + thread_id: int + messages: List[ChatMessage] = Field(default_factory=list) diff --git a/src/backend/search/search_service.py b/src/backend/search/search_service.py index 54e16d6..51f58a7 100644 --- a/src/backend/search/search_service.py +++ b/src/backend/search/search_service.py @@ -20,7 +20,7 @@ def get_searxng_base_url(): - searxng_base_url = os.getenv("SEARXNG_BASE_URL") + searxng_base_url = os.getenv("SEARXNG_BASE_URL", "http://localhost:8080/") if not searxng_base_url: raise HTTPException( status_code=500, diff --git a/src/frontend/generated/schemas.gen.ts b/src/frontend/generated/schemas.gen.ts index f163767..b97de6f 100644 --- a/src/frontend/generated/schemas.gen.ts +++ b/src/frontend/generated/schemas.gen.ts @@ -20,6 +20,82 @@ export const $BeginStream = { title: "BeginStream", } as const; +export const $ChatHistoryResponse = { + properties: { + snapshots: { + items: { + $ref: "#/components/schemas/ChatSnapshot", + }, + type: "array", + title: "Snapshots", + }, + }, + type: "object", + title: "ChatHistoryResponse", +} as const; + +export const $ChatMessage = { + properties: { + content: { + type: "string", + title: "Content", + }, + role: { + $ref: "#/components/schemas/MessageRole", + }, + related_queries: { + anyOf: [ + { + items: { + type: "string", + }, + type: "array", + }, + { + type: "null", + }, + ], + title: "Related Queries", + }, + sources: { + anyOf: [ + { + items: { + $ref: "#/components/schemas/SearchResult", + }, + type: "array", + }, + { + type: "null", + }, + ], + title: "Sources", + }, + images: { + anyOf: [ + { + items: { + type: "string", + }, + type: "array", + }, + { + type: "null", + }, + ], + title: "Images", + }, + is_error_message: { + type: "boolean", + title: "Is Error Message", + default: false, + }, + }, + type: "object", + required: ["content", "role"], + title: "ChatMessage", +} as const; + export const $ChatModel = { type: "string", enum: [ @@ -37,6 +113,17 @@ export const $ChatModel = { export const $ChatRequest = { properties: { + thread_id: { + anyOf: [ + { + type: "integer", + }, + { + type: "null", + }, + ], + title: "Thread Id", + }, query: { type: "string", title: "Query", @@ -99,6 +186,35 @@ export const $ChatResponseEvent = { title: "ChatResponseEvent", } as const; +export const $ChatSnapshot = { + properties: { + id: { + type: "integer", + title: "Id", + }, + title: { + type: "string", + title: "Title", + }, + date: { + type: "string", + format: "date-time", + title: "Date", + }, + preview: { + type: "string", + title: "Preview", + }, + model_name: { + type: "string", + title: "Model Name", + }, + }, + type: "object", + required: ["id", "title", "date", "preview", "model_name"], + title: "ChatSnapshot", +} as const; + export const $ErrorStream = { properties: { event_type: { @@ -255,6 +371,17 @@ export const $StreamEndStream = { ], default: "stream-end", }, + thread_id: { + anyOf: [ + { + type: "integer", + }, + { + type: "null", + }, + ], + title: "Thread Id", + }, }, type: "object", title: "StreamEndStream", @@ -294,6 +421,25 @@ export const $TextChunkStream = { title: "TextChunkStream", } as const; +export const $ThreadResponse = { + properties: { + thread_id: { + type: "integer", + title: "Thread Id", + }, + messages: { + items: { + $ref: "#/components/schemas/ChatMessage", + }, + type: "array", + title: "Messages", + }, + }, + type: "object", + required: ["thread_id"], + title: "ThreadResponse", +} as const; + export const $ValidationError = { properties: { loc: { diff --git a/src/frontend/generated/types.gen.ts b/src/frontend/generated/types.gen.ts index a6b6e3f..d466083 100644 --- a/src/frontend/generated/types.gen.ts +++ b/src/frontend/generated/types.gen.ts @@ -5,6 +5,19 @@ export type BeginStream = { query: string; }; +export type ChatHistoryResponse = { + snapshots?: Array; +}; + +export type ChatMessage = { + content: string; + role: MessageRole; + related_queries?: Array | null; + sources?: Array | null; + images?: Array | null; + is_error_message?: boolean; +}; + export enum ChatModel { LLAMA_3_70B = "llama-3-70b", GPT_4O = "gpt-4o", @@ -17,6 +30,7 @@ export enum ChatModel { } export type ChatRequest = { + thread_id?: number | null; query: string; history?: Array; model?: ChatModel; @@ -34,6 +48,14 @@ export type ChatResponseEvent = { | ErrorStream; }; +export type ChatSnapshot = { + id: number; + title: string; + date: string; + preview: string; + model_name: string; +}; + export type ErrorStream = { event_type?: StreamEvent; detail: string; @@ -77,6 +99,7 @@ export type SearchResultStream = { export type StreamEndStream = { event_type?: StreamEvent; + thread_id?: number | null; }; export enum StreamEvent { @@ -94,6 +117,11 @@ export type TextChunkStream = { text: string; }; +export type ThreadResponse = { + thread_id: number; + messages?: Array; +}; + export type ValidationError = { loc: Array; msg: string; diff --git a/src/frontend/package.json b/src/frontend/package.json index f64318d..e8b35c3 100644 --- a/src/frontend/package.json +++ b/src/frontend/package.json @@ -34,6 +34,7 @@ "hast-util-from-dom": "^5.0.0", "lodash": "^4.17.21", "lucide-react": "^0.376.0", + "moment": "^2.30.1", "next": "14.2.3", "next-themes": "^0.3.0", "react": "^18", diff --git a/src/frontend/pnpm-lock.yaml b/src/frontend/pnpm-lock.yaml index a7abccc..ee40cc6 100644 --- a/src/frontend/pnpm-lock.yaml +++ b/src/frontend/pnpm-lock.yaml @@ -77,6 +77,9 @@ dependencies: lucide-react: specifier: ^0.376.0 version: 0.376.0(react@18.3.1) + moment: + specifier: ^2.30.1 + version: 2.30.1 next: specifier: 14.2.3 version: 14.2.3(react-dom@18.3.1)(react@18.3.1) @@ -3595,6 +3598,10 @@ packages: ufo: 1.5.3 dev: false + /moment@2.30.1: + resolution: {integrity: sha512-uEmtNhbDOrWPFS+hdjFCBfy9f2YoyzRpwcl+DqpC6taX21FzsTLQVbMV/W7PzNSX6x/bhC1zA3c2UQ5NzH6how==} + dev: false + /ms@2.1.2: resolution: {integrity: sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==} diff --git a/src/frontend/src/app/globals.css b/src/frontend/src/app/globals.css index bd4791f..ba696f3 100644 --- a/src/frontend/src/app/globals.css +++ b/src/frontend/src/app/globals.css @@ -25,7 +25,7 @@ --accent: 240 4.8% 95.9%; --accent-foreground: 240 5.9% 10%; - --tint: 27.5 34.7% 51.6%; + --tint: 22.4 31.5% 35.8%; --tint-foreground: 25 76% 31%; --destructive: 0 84.2% 60.2%; diff --git a/src/frontend/src/app/history/page.tsx b/src/frontend/src/app/history/page.tsx new file mode 100644 index 0000000..57bc7ac --- /dev/null +++ b/src/frontend/src/app/history/page.tsx @@ -0,0 +1,37 @@ +"use client"; + +import { ErrorMessage } from "@/components/assistant-message"; +import RecentChat from "@/components/recent-chat"; +import { Separator } from "@/components/ui/separator"; +import { useChatHistory } from "@/hooks/history"; +import { HistoryIcon } from "lucide-react"; +import React from "react"; + +export default function RecentsPage() { + const { data: chats, isLoading, error } = useChatHistory(); + + if (!error && !chats) return
Loading...
; + + return ( +
+
+
+ +

Chat History

+
+ + {error && } + {chats && ( +
    + {chats.map((chat, index) => ( + + + {index < chats.length - 1 && } + + ))} +
+ )} +
+
+ ); +} diff --git a/src/frontend/src/app/search/[slug]/page.tsx b/src/frontend/src/app/search/[slug]/page.tsx new file mode 100644 index 0000000..f587891 --- /dev/null +++ b/src/frontend/src/app/search/[slug]/page.tsx @@ -0,0 +1,22 @@ +"use client"; + +import { useState, useEffect, Suspense } from "react"; +import { useParams } from "next/navigation"; +import { Separator } from "@/components/ui/separator"; +import { ChatMessage } from "../../../../generated"; +import { ChatPanel } from "@/components/chat-panel"; + +export default function ChatPage() { + const { slug } = useParams(); + const threadId = parseInt(slug as string, 10); + + return ( +
+
+ + + +
+
+ ); +} diff --git a/src/frontend/src/components/assistant-message.tsx b/src/frontend/src/components/assistant-message.tsx index c912800..37edfbb 100644 --- a/src/frontend/src/components/assistant-message.tsx +++ b/src/frontend/src/components/assistant-message.tsx @@ -1,4 +1,3 @@ -import { AssistantMessage } from "@/types"; import { MessageComponent, MessageComponentSkeleton } from "./message"; import RelatedQuestions from "./related-questions"; import { SearchResultsSkeleton, SearchResults } from "./search-results"; @@ -6,6 +5,7 @@ import { Section } from "./section"; import { AlertCircle } from "lucide-react"; import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; import { ImageSection, ImageSectionSkeleton } from "./image-section"; +import { ChatMessage } from "../../generated"; export function ErrorMessage({ content }: { content: string }) { return ( @@ -39,19 +39,19 @@ export const AssistantMessageContent = ({ isStreaming = false, onRelatedQuestionSelect, }: { - message: AssistantMessage; + message: ChatMessage; isStreaming?: boolean; onRelatedQuestionSelect: (question: string) => void; }) => { const { sources, content, - relatedQuestions, + related_queries, images, - isErrorMessage = false, + is_error_message = false, } = message; - if (isErrorMessage) { + if (is_error_message) { return ; } @@ -80,10 +80,10 @@ export const AssistantMessageContent = ({ )} - {relatedQuestions && relatedQuestions.length > 0 && ( + {related_queries && related_queries.length > 0 && (
diff --git a/src/frontend/src/components/chat-panel.tsx b/src/frontend/src/components/chat-panel.tsx index b2c1fcb..e3120f1 100644 --- a/src/frontend/src/components/chat-panel.tsx +++ b/src/frontend/src/components/chat-panel.tsx @@ -2,8 +2,7 @@ import { useParams, useSearchParams } from "next/navigation"; import { useChat } from "@/hooks/chat"; -import { useMessageStore } from "@/stores"; -import { MessageType } from "@/types"; +import { useChatStore } from "@/stores"; import { Suspense, useCallback, useEffect, useRef, useState } from "react"; import { AskInput } from "./ask-input"; @@ -11,12 +10,15 @@ import MessagesList from "./messages-list"; import { ModelSelection } from "./model-selection"; import { StarterQuestionsList } from "./starter-questions"; import LocalToggle from "./local-toggle"; +import { useChatThread } from "@/hooks/threads"; +import { MessageRole } from "../../generated"; +import { LoaderIcon } from "lucide-react"; const useAutoScroll = (ref: React.RefObject) => { - const { messages } = useMessageStore(); + const { messages } = useChatStore(); useEffect(() => { - if (messages.at(-1)?.role === MessageType.USER) { + if (messages.at(-1)?.role === MessageRole.USER) { ref.current?.scrollIntoView({ behavior: "smooth", block: "end", @@ -29,7 +31,7 @@ const useAutoResizeInput = ( ref: React.RefObject, setWidth: (width: number) => void, ) => { - const { messages } = useMessageStore(); + const { messages } = useChatStore(); useEffect(() => { const updatePosition = () => { @@ -51,13 +53,14 @@ const useAutoFocus = (ref: React.RefObject) => { }, [ref]); }; -export const ChatPanel = () => { +export const ChatPanel = ({ threadId }: { threadId?: number }) => { const searchParams = useSearchParams(); const queryMessage = searchParams.get("q"); const hasRun = useRef(false); const { handleSend, streamingMessage } = useChat(); - const { messages } = useMessageStore(); + const { messages, setMessages, setThreadId } = useChatStore(); + const { data: thread, isLoading, error } = useChatThread(threadId); const [width, setWidth] = useState(0); const messagesRef = useRef(null); @@ -70,28 +73,47 @@ export const ChatPanel = () => { useEffect(() => { if (queryMessage && !hasRun.current) { + setThreadId(null); hasRun.current = true; handleSend(queryMessage); } }, [queryMessage]); + useEffect(() => { + if (!thread) return; + setThreadId(thread.thread_id); + setMessages(thread.messages || []); + }, [threadId, thread, setMessages, setThreadId]); + + useEffect(() => { + if (messages.length == 0) { + setThreadId(null); + } + }, [messages, setThreadId]); + return ( <> - {messages.length > 0 ? ( -
- -
-
- + {messages.length > 0 || threadId ? ( + isLoading ? ( +
+
-
+ ) : ( +
+ +
+
+ +
+
+ ) ) : (
diff --git a/src/frontend/src/components/message.tsx b/src/frontend/src/components/message.tsx index 8ee88ee..3f283f8 100644 --- a/src/frontend/src/components/message.tsx +++ b/src/frontend/src/components/message.tsx @@ -4,8 +4,8 @@ import rehypeRaw from "rehype-raw"; import _ from "lodash"; import { cn } from "@/lib/utils"; -import { AssistantMessage } from "@/types"; import { Skeleton } from "./ui/skeleton"; +import { ChatMessage } from "../../generated"; function chunkString(str: string): string[] { const words = str.split(" "); @@ -14,7 +14,7 @@ function chunkString(str: string): string[] { } export interface MessageProps { - message: AssistantMessage; + message: ChatMessage; isStreaming?: boolean; } diff --git a/src/frontend/src/components/messages-list.tsx b/src/frontend/src/components/messages-list.tsx index d603779..0a2c5a1 100644 --- a/src/frontend/src/components/messages-list.tsx +++ b/src/frontend/src/components/messages-list.tsx @@ -1,8 +1,8 @@ -import { AssistantMessage, ChatMessage, MessageType } from "@/types"; import { AssistantMessageContent } from "./assistant-message"; import { Separator } from "./ui/separator"; import { UserMessageContent } from "./user-message"; import { memo } from "react"; +import { ChatMessage, MessageRole } from "../../generated"; const MessagesList = ({ messages, @@ -10,13 +10,13 @@ const MessagesList = ({ onRelatedQuestionSelect, }: { messages: ChatMessage[]; - streamingMessage: AssistantMessage | null; + streamingMessage: ChatMessage | null; onRelatedQuestionSelect: (question: string) => void; }) => { return (
{messages.map((message, index) => - message.role === MessageType.USER ? ( + message.role === MessageRole.USER ? ( ) : ( <> diff --git a/src/frontend/src/components/mode-toggle.tsx b/src/frontend/src/components/mode-toggle.tsx index 328068b..12683c5 100644 --- a/src/frontend/src/components/mode-toggle.tsx +++ b/src/frontend/src/components/mode-toggle.tsx @@ -31,9 +31,9 @@ export function ModeToggle() { className="flex gap-2 items-center font-medium" onClick={() => setTheme(theme)} > - {theme === "light" && } - {theme === "dark" && } - {theme === "system" && } + {theme === "light" && } + {theme === "dark" && } + {theme === "system" && } {theme.charAt(0).toUpperCase() + theme.slice(1)} ))} diff --git a/src/frontend/src/components/model-selection.tsx b/src/frontend/src/components/model-selection.tsx index 3d857bf..716ac72 100644 --- a/src/frontend/src/components/model-selection.tsx +++ b/src/frontend/src/components/model-selection.tsx @@ -20,7 +20,7 @@ import { SparklesIcon, WandSparklesIcon, } from "lucide-react"; -import { useConfigStore, useMessageStore } from "@/stores"; +import { useConfigStore, useChatStore } from "@/stores"; import { ChatModel } from "../../generated"; import { isCloudModel, isLocalModel } from "@/lib/utils"; import _ from "lodash"; @@ -33,7 +33,7 @@ type Model = { icon: React.ReactNode; }; -const modelMap: Record = { +export const modelMap: Record = { [ChatModel.GPT_3_5_TURBO]: { name: "Fast", description: "OpenAI/GPT-3.5-turbo", diff --git a/src/frontend/src/components/nav.tsx b/src/frontend/src/components/nav.tsx index 1d02b5f..a8f6599 100644 --- a/src/frontend/src/components/nav.tsx +++ b/src/frontend/src/components/nav.tsx @@ -4,8 +4,9 @@ import Link from "next/link"; import { ModeToggle } from "./mode-toggle"; import { useTheme } from "next-themes"; import { Button } from "./ui/button"; -import { PlusIcon } from "lucide-react"; -import { useMessageStore } from "@/stores"; +import { HistoryIcon, PlusIcon } from "lucide-react"; +import { useChatStore } from "@/stores"; +import { useRouter } from "next/navigation"; const NewChatButton = () => { return ( @@ -21,15 +22,16 @@ const TextLogo = () => { }; export function Navbar() { + const router = useRouter(); const { theme } = useTheme(); - const { messages } = useMessageStore(); + const { messages } = useChatStore(); const onHomePage = messages.length === 0; return (
- location.reload()}> + (location.href = "/")}> Logo {onHomePage ? : }
-
+
+ +
+
+ + History +
+
+
diff --git a/src/frontend/src/components/recent-chat.tsx b/src/frontend/src/components/recent-chat.tsx new file mode 100644 index 0000000..5f74366 --- /dev/null +++ b/src/frontend/src/components/recent-chat.tsx @@ -0,0 +1,43 @@ +import { HourglassIcon } from "lucide-react"; +import { ChatModel, ChatSnapshot } from "../../generated"; +import moment from "moment"; +import Link from "next/link"; +import { modelMap } from "./model-selection"; + +export default function RecentChat({ + id, + title, + date, + preview, + model_name, +}: ChatSnapshot) { + const formattedDate = moment(date).fromNow(); + const model = + model_name in modelMap ? modelMap[model_name as ChatModel] : null; + + return ( + +
+
+

+ {title} +

+

{preview}

+
+
+
+ +

{formattedDate}

+
+
+ {model?.smallIcon} +

{model?.name}

+
+
+
+ + ); +} diff --git a/src/frontend/src/components/search-results.tsx b/src/frontend/src/components/search-results.tsx index eeb342c..42124cd 100644 --- a/src/frontend/src/components/search-results.tsx +++ b/src/frontend/src/components/search-results.tsx @@ -2,14 +2,13 @@ "use client"; import { useState } from "react"; import { Card, CardContent } from "@/components/ui/card"; -import { Button } from "@/components/ui/button"; -import { SearchResult } from "@/types"; import { Skeleton } from "./ui/skeleton"; import { HoverCard, HoverCardContent, HoverCardTrigger, } from "@/components/ui/hover-card"; +import { SearchResult } from "../../generated"; export const SearchResultsSkeleton = () => { return ( diff --git a/src/frontend/src/components/user-message.tsx b/src/frontend/src/components/user-message.tsx index e95217f..0857ae7 100644 --- a/src/frontend/src/components/user-message.tsx +++ b/src/frontend/src/components/user-message.tsx @@ -1,6 +1,6 @@ -import { UserMessage } from "@/types"; +import { ChatMessage } from "../../generated"; -export const UserMessageContent = ({ message }: { message: UserMessage }) => { +export const UserMessageContent = ({ message }: { message: ChatMessage }) => { return (
{message.content} diff --git a/src/frontend/src/hooks/chat.ts b/src/frontend/src/hooks/chat.ts index 7aeb24d..5e0d28d 100644 --- a/src/frontend/src/hooks/chat.ts +++ b/src/frontend/src/hooks/chat.ts @@ -1,5 +1,6 @@ import { useMutation } from "@tanstack/react-query"; import { + ChatMessage, ChatRequest, ChatResponseEvent, ErrorStream, @@ -8,6 +9,7 @@ import { RelatedQueriesStream, SearchResult, SearchResultStream, + StreamEndStream, StreamEvent, TextChunkStream, } from "../../generated"; @@ -17,10 +19,9 @@ import { FetchEventSourceInit, } from "@microsoft/fetch-event-source"; import { useState } from "react"; -import { AssistantMessage, ChatMessage, MessageType } from "@/types"; -import { useConfigStore, useMessageStore } from "@/stores"; -import { useToast } from "@/components/ui/use-toast"; +import { useConfigStore, useChatStore } from "@/stores"; import { env } from "../env.mjs"; +import { useRouter } from "next/navigation"; const BASE_URL = env.NEXT_PUBLIC_API_URL; @@ -47,7 +48,7 @@ const streamChat = async ({ const convertToChatRequest = (query: string, history: ChatMessage[]) => { const newHistory: Message[] = history.map((message) => ({ role: - message.role === MessageType.USER + message.role === MessageRole.USER ? MessageRole.USER : MessageRole.ASSISTANT, content: message.content, @@ -56,27 +57,29 @@ const convertToChatRequest = (query: string, history: ChatMessage[]) => { }; export const useChat = () => { - const { addMessage, messages } = useMessageStore(); + const { addMessage, messages, threadId, setThreadId } = useChatStore(); const { model } = useConfigStore(); + const router = useRouter(); - const [streamingMessage, setStreamingMessage] = - useState(null); + const [streamingMessage, setStreamingMessage] = useState( + null, + ); const handleEvent = ( eventItem: ChatResponseEvent, state: { response: string; sources: SearchResult[]; - relatedQuestions: string[]; + related_queries: string[]; images: string[]; }, ) => { switch (eventItem.event) { case StreamEvent.BEGIN_STREAM: setStreamingMessage({ - role: MessageType.ASSISTANT, + role: MessageRole.ASSISTANT, content: "", - relatedQuestions: [], + related_queries: [], sources: [], }); break; @@ -89,38 +92,45 @@ export const useChat = () => { state.response += (eventItem.data as TextChunkStream).text ?? ""; break; case StreamEvent.RELATED_QUERIES: - state.relatedQuestions = + state.related_queries = (eventItem.data as RelatedQueriesStream).related_queries ?? []; break; case StreamEvent.STREAM_END: + const endData = eventItem.data as StreamEndStream; addMessage({ - role: MessageType.ASSISTANT, + role: MessageRole.ASSISTANT, content: state.response, - relatedQuestions: state.relatedQuestions, + related_queries: state.related_queries, sources: state.sources, images: state.images, }); setStreamingMessage(null); + + // Only if the backend is using the DB + if (endData.thread_id) { + setThreadId(endData.thread_id); + window.history.pushState({}, "", `/search/${endData.thread_id}`); + } return; case StreamEvent.FINAL_RESPONSE: return; case StreamEvent.ERROR: const errorData = eventItem.data as ErrorStream; addMessage({ - role: MessageType.ASSISTANT, + role: MessageRole.ASSISTANT, content: errorData.detail, - relatedQuestions: [], + related_queries: [], sources: [], images: [], - isErrorMessage: true, + is_error_message: true, }); setStreamingMessage(null); return; } setStreamingMessage({ - role: MessageType.ASSISTANT, + role: MessageRole.ASSISTANT, content: state.response, - relatedQuestions: state.relatedQuestions, + related_queries: state.related_queries, sources: state.sources, images: state.images, }); @@ -132,13 +142,14 @@ export const useChat = () => { const state = { response: "", sources: [], - relatedQuestions: [], + related_queries: [], images: [], }; - addMessage({ role: MessageType.USER, content: request.query }); + addMessage({ role: MessageRole.USER, content: request.query }); const req = { ...request, + thread_id: threadId, model, }; await streamChat({ diff --git a/src/frontend/src/hooks/history.ts b/src/frontend/src/hooks/history.ts new file mode 100644 index 0000000..f6cad6d --- /dev/null +++ b/src/frontend/src/hooks/history.ts @@ -0,0 +1,23 @@ +import { useQuery } from "@tanstack/react-query"; +import { env } from "@/env.mjs"; +import { ChatSnapshot } from "../../generated"; + +const BASE_URL = env.NEXT_PUBLIC_API_URL; + +export const fetchChatHistory = async (): Promise => { + const response = await fetch(`${BASE_URL}/history`); + if (!response.ok) { + const errorData = await response.json(); + throw new Error(errorData.detail || "Failed to fetch chat history"); + } + const data = await response.json(); + return data.snapshots; +}; + +export const useChatHistory = () => { + return useQuery({ + queryKey: ["chatHistory"], + queryFn: fetchChatHistory, + retry: false, + }); +}; diff --git a/src/frontend/src/hooks/threads.ts b/src/frontend/src/hooks/threads.ts new file mode 100644 index 0000000..fd1684b --- /dev/null +++ b/src/frontend/src/hooks/threads.ts @@ -0,0 +1,23 @@ +import { useQuery } from "@tanstack/react-query"; +import { env } from "@/env.mjs"; +import { ThreadResponse } from "../../generated"; + +const BASE_URL = env.NEXT_PUBLIC_API_URL; + +const fetchChatThread = async (threadId: number): Promise => { + const response = await fetch(`${BASE_URL}/thread/${threadId}`); + return await response.json(); +}; + +export const useChatThread = (threadId?: number) => { + const { data, isLoading, error } = useQuery({ + queryKey: ["thread", threadId], + queryFn: async () => { + if (!threadId) { + return null; + } + return fetchChatThread(threadId); + }, + }); + return { data, isLoading, error }; +}; diff --git a/src/frontend/src/stores/index.ts b/src/frontend/src/stores/index.ts index 8ee9518..b910a98 100644 --- a/src/frontend/src/stores/index.ts +++ b/src/frontend/src/stores/index.ts @@ -1,9 +1,9 @@ import { create } from "zustand"; import { persist } from "zustand/middleware"; import { ConfigStore, createConfigSlice } from "./slices/configSlice"; -import { createMessageSlice, MessageStore } from "./slices/messageSlice"; +import { createMessageSlice, ChatStore } from "./slices/messageSlice"; -type StoreState = MessageStore & ConfigStore; +type StoreState = ChatStore & ConfigStore; const useStore = create()( persist( @@ -21,10 +21,13 @@ const useStore = create()( ), ); -export const useMessageStore = () => +export const useChatStore = () => useStore((state) => ({ messages: state.messages, addMessage: state.addMessage, + setMessages: state.setMessages, + threadId: state.threadId, + setThreadId: state.setThreadId, })); export const useConfigStore = () => diff --git a/src/frontend/src/stores/slices/messageSlice.ts b/src/frontend/src/stores/slices/messageSlice.ts index 5028978..d5c20fb 100644 --- a/src/frontend/src/stores/slices/messageSlice.ts +++ b/src/frontend/src/stores/slices/messageSlice.ts @@ -1,23 +1,26 @@ import { create, StateCreator } from "zustand"; -import { ChatMessage } from "@/types"; +import { ChatMessage } from "../../../generated"; type State = { + threadId: number | null; messages: ChatMessage[]; }; type Actions = { addMessage: (message: ChatMessage) => void; + setThreadId: (threadId: number | null) => void; + setMessages: (messages: ChatMessage[]) => void; }; -export type MessageStore = State & Actions; +export type ChatStore = State & Actions; -export const createMessageSlice: StateCreator< - MessageStore, - [], - [], - MessageStore -> = (set) => ({ +export const createMessageSlice: StateCreator = ( + set, +) => ({ + threadId: null, messages: [], addMessage: (message: ChatMessage) => set((state) => ({ messages: [...state.messages, message] })), + setThreadId: (threadId: number | null) => set((state) => ({ threadId })), + setMessages: (messages: ChatMessage[]) => set((state) => ({ messages })), }); diff --git a/src/frontend/src/types.ts b/src/frontend/src/types.ts deleted file mode 100644 index 6da6670..0000000 --- a/src/frontend/src/types.ts +++ /dev/null @@ -1,29 +0,0 @@ -export type SearchResult = { - title: string; - url: string; - content: string; -}; - -export enum MessageType { - USER = "user", - ASSISTANT = "assistant", -} - -export type BaseMessage = { - role: MessageType; - content: string; -}; - -export type UserMessage = BaseMessage & { - role: MessageType.USER; -}; - -export type AssistantMessage = BaseMessage & { - role: MessageType.ASSISTANT; - sources?: SearchResult[]; - relatedQuestions?: string[]; - images?: string[]; - isErrorMessage?: boolean; -}; - -export type ChatMessage = UserMessage | AssistantMessage;