diff --git a/django_app/poetry.lock b/django_app/poetry.lock index bc17dd116..2a3ff8d4c 100644 --- a/django_app/poetry.lock +++ b/django_app/poetry.lock @@ -339,17 +339,17 @@ files = [ [[package]] name = "boto3" -version = "1.34.162" +version = "1.35.45" description = "The AWS SDK for Python" optional = false python-versions = ">=3.8" files = [ - {file = "boto3-1.34.162-py3-none-any.whl", hash = "sha256:d6f6096bdab35a0c0deff469563b87d184a28df7689790f7fe7be98502b7c590"}, - {file = "boto3-1.34.162.tar.gz", hash = "sha256:873f8f5d2f6f85f1018cbb0535b03cceddc7b655b61f66a0a56995238804f41f"}, + {file = "boto3-1.35.45-py3-none-any.whl", hash = "sha256:f16c7edfcbbeb0a0c22d67d6ebbfcb332fa78d3ea88275e082260ba04fe65347"}, + {file = "boto3-1.35.45.tar.gz", hash = "sha256:9f4a081e1940846171b51d903000a04322f1356d53225ce1028fc1760a155a70"}, ] [package.dependencies] -botocore = ">=1.34.162,<1.35.0" +botocore = ">=1.35.45,<1.36.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.10.0,<0.11.0" @@ -767,13 +767,13 @@ xray = ["mypy-boto3-xray (>=1.35.0,<1.36.0)"] [[package]] name = "botocore" -version = "1.34.162" +version = "1.35.45" description = "Low-level, data-driven core of boto 3." optional = false python-versions = ">=3.8" files = [ - {file = "botocore-1.34.162-py3-none-any.whl", hash = "sha256:2d918b02db88d27a75b48275e6fb2506e9adaaddbec1ffa6a8a0898b34e769be"}, - {file = "botocore-1.34.162.tar.gz", hash = "sha256:adc23be4fb99ad31961236342b7cbf3c0bfc62532cd02852196032e8c0d682f3"}, + {file = "botocore-1.35.45-py3-none-any.whl", hash = "sha256:e07e170975721c94ec1e3bf71a484552ad63e2499f769dd14f9f37375b4993fd"}, + {file = "botocore-1.35.45.tar.gz", hash = "sha256:9a898bfdd6b0027fee2018711192c15c2716bf6a7096b1168bd8a896df3664a1"}, ] [package.dependencies] @@ -782,7 +782,7 @@ python-dateutil = ">=2.1,<3.0.0" urllib3 = {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""} [package.extras] -crt = ["awscrt (==0.21.2)"] +crt = ["awscrt (==0.22.0)"] [[package]] name = "botocore-stubs" @@ -2472,133 +2472,135 @@ testing = ["matplotlib (>=2.2.5)", "pytest (>=5.0.1)", "pytest-cov (>=3.0.0)"] [[package]] name = "langchain" -version = "0.2.16" +version = "0.3.4" description = "Building applications with LLMs through composability" optional = false -python-versions = "<4.0,>=3.8.1" +python-versions = "<4.0,>=3.9" files = [ - {file = "langchain-0.2.16-py3-none-any.whl", hash = "sha256:8f59ee8b45f268df4b924ea3b9c63e49286efa756d16b3f6a9de5c6e502c36e1"}, - {file = "langchain-0.2.16.tar.gz", hash = "sha256:ffb426a76a703b73ac69abad77cd16eaf03dda76b42cff55572f592d74944166"}, + {file = "langchain-0.3.4-py3-none-any.whl", hash = "sha256:7a1241d9429510d2083c62df0da998a7b2b05c730cd4255b89da9d47c57f48fd"}, + {file = "langchain-0.3.4.tar.gz", hash = "sha256:3596515fcd0157dece6ec96e0240d29f4cf542d91ecffc815d32e35198dfff37"}, ] [package.dependencies] aiohttp = ">=3.8.3,<4.0.0" -langchain-core = ">=0.2.38,<0.3.0" -langchain-text-splitters = ">=0.2.0,<0.3.0" +langchain-core = ">=0.3.12,<0.4.0" +langchain-text-splitters = ">=0.3.0,<0.4.0" langsmith = ">=0.1.17,<0.2.0" numpy = {version = ">=1.26.0,<2.0.0", markers = "python_version >= \"3.12\""} -pydantic = ">=1,<3" +pydantic = ">=2.7.4,<3.0.0" PyYAML = ">=5.3" requests = ">=2,<3" SQLAlchemy = ">=1.4,<3" -tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<9.0.0" +tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10" [[package]] name = "langchain-aws" -version = "0.1.18" +version = "0.2.3" description = "An integration package connecting AWS and LangChain" optional = false -python-versions 
= "<4.0,>=3.8.1" +python-versions = "<4.0,>=3.9" files = [ - {file = "langchain_aws-0.1.18-py3-none-any.whl", hash = "sha256:54c65157c67837cd932a6a8c536b52e2308c4e917bccd187d6f42651943cdd51"}, - {file = "langchain_aws-0.1.18.tar.gz", hash = "sha256:2e59efc95ef6758581b7e769cd610b750c76635c437a4f43c454a14cfab67c64"}, + {file = "langchain_aws-0.2.3-py3-none-any.whl", hash = "sha256:517b946802d94b12f54c8464dae91108f54a93b6679e1b302979a9b50a99229d"}, + {file = "langchain_aws-0.2.3.tar.gz", hash = "sha256:ed6c5dcc1f2e9e814db3107e4968a94680f20b66a110492e6906a05972f8e4ea"}, ] [package.dependencies] -boto3 = ">=1.34.131,<1.35.0" -langchain-core = ">=0.2.33,<0.3" +boto3 = ">=1.34.131" +langchain-core = ">=0.3.2,<0.4" numpy = {version = ">=1.26.0,<2.0.0", markers = "python_version >= \"3.12\""} +pydantic = ">=2,<3" [[package]] name = "langchain-community" -version = "0.2.17" +version = "0.3.3" description = "Community contributed LangChain integrations." optional = false -python-versions = "<4.0,>=3.8.1" +python-versions = "<4.0,>=3.9" files = [ - {file = "langchain_community-0.2.17-py3-none-any.whl", hash = "sha256:d07c31b641e425fb8c3e7148ad6a62e1b54a9adac6e1173021a7dd3148266063"}, - {file = "langchain_community-0.2.17.tar.gz", hash = "sha256:b0745c1fcf1bd532ed4388f90b47139d6a6c6ba48a87aa68aa32d4d6bb97259d"}, + {file = "langchain_community-0.3.3-py3-none-any.whl", hash = "sha256:319cfc2f923a066c91fbb8e02decd7814018af952b6b98298b8ac9d30ea1da56"}, + {file = "langchain_community-0.3.3.tar.gz", hash = "sha256:bfb3f2b219aed21087e0ecb7d2ebd1c81401c02b92239e11645c822d5be63f80"}, ] [package.dependencies] aiohttp = ">=3.8.3,<4.0.0" dataclasses-json = ">=0.5.7,<0.7" -langchain = ">=0.2.16,<0.3.0" -langchain-core = ">=0.2.39,<0.3.0" -langsmith = ">=0.1.112,<0.2.0" +langchain = ">=0.3.4,<0.4.0" +langchain-core = ">=0.3.12,<0.4.0" +langsmith = ">=0.1.125,<0.2.0" numpy = {version = ">=1.26.0,<2.0.0", markers = "python_version >= \"3.12\""} +pydantic-settings = ">=2.4.0,<3.0.0" PyYAML = ">=5.3" requests = ">=2,<3" SQLAlchemy = ">=1.4,<3" -tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<9.0.0" +tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10" [[package]] name = "langchain-core" -version = "0.2.41" +version = "0.3.12" description = "Building applications with LLMs through composability" optional = false -python-versions = "<4.0,>=3.8.1" +python-versions = "<4.0,>=3.9" files = [ - {file = "langchain_core-0.2.41-py3-none-any.whl", hash = "sha256:3278fda5ba9a05defae8bb19f1226032add6aab21917db7b3bc74e750e263e84"}, - {file = "langchain_core-0.2.41.tar.gz", hash = "sha256:bc12032c5a298d85be754ccb129bc13ea21ccb1d6e22f8d7ba18b8da64315bb5"}, + {file = "langchain_core-0.3.12-py3-none-any.whl", hash = "sha256:46050d34f5fa36dc57dca971c6a26f505643dd05ee0492c7ac286d0a78a82037"}, + {file = "langchain_core-0.3.12.tar.gz", hash = "sha256:98a3c078e375786aa84939bfd1111263af2f3bc402bbe2cac9fa18a387459cf2"}, ] [package.dependencies] jsonpatch = ">=1.33,<2.0" -langsmith = ">=0.1.112,<0.2.0" +langsmith = ">=0.1.125,<0.2.0" packaging = ">=23.2,<25" pydantic = [ - {version = ">=1,<3", markers = "python_full_version < \"3.12.4\""}, + {version = ">=2.5.2,<3.0.0", markers = "python_full_version < \"3.12.4\""}, {version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""}, ] PyYAML = ">=5.3" -tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<9.0.0" +tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10.0.0" typing-extensions = ">=4.7" [[package]] name = "langchain-elasticsearch" -version = "0.2.2" +version = "0.3.0" description = "An integration package connecting 
Elasticsearch and LangChain" optional = false -python-versions = "<4.0,>=3.8.1" +python-versions = "<4.0,>=3.9" files = [ - {file = "langchain_elasticsearch-0.2.2-py3-none-any.whl", hash = "sha256:2b0ae1637afc6890c371fc67be2151954db3c7f5382a07dafe4cf3bd762c6f26"}, - {file = "langchain_elasticsearch-0.2.2.tar.gz", hash = "sha256:7abe1cdee9f3b1a5f4152d2f79514359d0b95dc6d3923d3adb4ac431be854545"}, + {file = "langchain_elasticsearch-0.3.0-py3-none-any.whl", hash = "sha256:83d26d2a29f7628bcf2c655fe988930c90249e78e6f326c0be218ad4fa96f1f4"}, + {file = "langchain_elasticsearch-0.3.0.tar.gz", hash = "sha256:36c2963c432777e05b6613f57ee477de4b6e4be4ac4c1306601ab64de004a297"}, ] [package.dependencies] elasticsearch = {version = ">=8.13.1,<9.0.0", extras = ["vectorstore-mmr"]} -langchain-core = ">=0.1.50,<0.3" +langchain-core = ">=0.3.0,<0.4.0" [[package]] name = "langchain-openai" -version = "0.1.25" +version = "0.2.3" description = "An integration package connecting OpenAI and LangChain" optional = false -python-versions = "<4.0,>=3.8.1" +python-versions = "<4.0,>=3.9" files = [ - {file = "langchain_openai-0.1.25-py3-none-any.whl", hash = "sha256:f0b34a233d0d9cb8fce6006c903e57085c493c4f0e32862b99063b96eaedb109"}, - {file = "langchain_openai-0.1.25.tar.gz", hash = "sha256:eb116f744f820247a72f54313fb7c01524fba0927120d4e899e5e4ab41ad3928"}, + {file = "langchain_openai-0.2.3-py3-none-any.whl", hash = "sha256:f498c94817c980cb302439b95d3f3275cdf2743e022ee674692c75898523cf57"}, + {file = "langchain_openai-0.2.3.tar.gz", hash = "sha256:e142031704de1104735f503f76352c53b27ac0a2806466392993c4508c42bf0c"}, ] [package.dependencies] -langchain-core = ">=0.2.40,<0.3.0" -openai = ">=1.40.0,<2.0.0" +langchain-core = ">=0.3.12,<0.4.0" +openai = ">=1.52.0,<2.0.0" tiktoken = ">=0.7,<1" [[package]] name = "langchain-text-splitters" -version = "0.2.4" +version = "0.3.0" description = "LangChain text splitting utilities" optional = false -python-versions = "<4.0,>=3.8.1" +python-versions = "<4.0,>=3.9" files = [ - {file = "langchain_text_splitters-0.2.4-py3-none-any.whl", hash = "sha256:2702dee5b7cbdd595ccbe43b8d38d01a34aa8583f4d6a5a68ad2305ae3e7b645"}, - {file = "langchain_text_splitters-0.2.4.tar.gz", hash = "sha256:f7daa7a3b0aa8309ce248e2e2b6fc8115be01118d336c7f7f7dfacda0e89bf29"}, + {file = "langchain_text_splitters-0.3.0-py3-none-any.whl", hash = "sha256:e84243e45eaff16e5b776cd9c81b6d07c55c010ebcb1965deb3d1792b7358e83"}, + {file = "langchain_text_splitters-0.3.0.tar.gz", hash = "sha256:f9fe0b4d244db1d6de211e7343d4abc4aa90295aa22e1f0c89e51f33c55cd7ce"}, ] [package.dependencies] -langchain-core = ">=0.2.38,<0.3.0" +langchain-core = ">=0.3.0,<0.4.0" [[package]] name = "langgraph" @@ -4093,12 +4095,12 @@ develop = true boto3 = "^1.34.160" elasticsearch = "^8.15.0" kneed = "^0.8.5" -langchain = "^0.2.13" -langchain-aws = "^0.1.17" -langchain-community = "^0.2.12" -langchain-elasticsearch = "^0.2.2" -langchain_openai = "^0.1.21" -langgraph = "^0.2.15" +langchain = "^0.3.4" +langchain-aws = ">0.1.17" +langchain-community = ">0.2.12" +langchain-elasticsearch = ">0.2.2" +langchain_openai = ">0.1.21" +langgraph = "^0.2.39" opensearch-py = "^2.7.1" pydantic = "^2.7.1" pydantic-settings = "^2.3.4" @@ -4281,13 +4283,13 @@ six = ">=1.7.0" [[package]] name = "rich" -version = "13.9.2" +version = "13.9.3" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" optional = false python-versions = ">=3.8.0" files = [ - {file = "rich-13.9.2-py3-none-any.whl", hash = 
"sha256:8c82a3d3f8dcfe9e734771313e606b39d8247bb6b826e196f4914b333b743cf1"}, - {file = "rich-13.9.2.tar.gz", hash = "sha256:51a2c62057461aaf7152b4d611168f93a9fc73068f8ded2790f29fe2b5366d0c"}, + {file = "rich-13.9.3-py3-none-any.whl", hash = "sha256:9836f5096eb2172c9e77df411c1b009bace4193d6a481d534fea75ebba758283"}, + {file = "rich-13.9.3.tar.gz", hash = "sha256:bc1e01b899537598cf02579d2b9f4a415104d3fc439313a7a2c165d76557a08e"}, ] [package.dependencies] @@ -4832,13 +4834,13 @@ yaml = ["pyyaml"] [[package]] name = "tenacity" -version = "8.5.0" +version = "9.0.0" description = "Retry code until it succeeds" optional = false python-versions = ">=3.8" files = [ - {file = "tenacity-8.5.0-py3-none-any.whl", hash = "sha256:b594c2a5945830c267ce6b79a166228323ed52718f30302c1359836112346687"}, - {file = "tenacity-8.5.0.tar.gz", hash = "sha256:8bc6c0c8a09b31e6cad13c47afbed1a567518250a9a171418582ed8d9c20ca78"}, + {file = "tenacity-9.0.0-py3-none-any.whl", hash = "sha256:93de0c98785b27fcf659856aa9f54bfbd399e29969b0621bc7f762bd441b4539"}, + {file = "tenacity-9.0.0.tar.gz", hash = "sha256:807f37ca97d62aa361264d497b0e31e92b8027044942bfa756160d908320d73b"}, ] [package.extras] diff --git a/django_app/redbox_app/redbox_core/consumers.py b/django_app/redbox_app/redbox_core/consumers.py index dd75b94f8..7ea81abf8 100644 --- a/django_app/redbox_app/redbox_core/consumers.py +++ b/django_app/redbox_app/redbox_core/consumers.py @@ -4,7 +4,6 @@ from collections.abc import Mapping, Sequence from typing import Any, ClassVar from uuid import UUID -from itertools import groupby from channels.db import database_sync_to_async from channels.generic.websocket import AsyncWebsocketConsumer @@ -210,14 +209,14 @@ def save_ai_message( file=file, text=citation.page_content, page_numbers=parse_page_number(citation.metadata.get("page_number")), - source=Citation.Origin.USER_UPLOADED_DOCUMENT + source=Citation.Origin.USER_UPLOADED_DOCUMENT, ) else: Citation.objects.create( chat_message=chat_message, url=citation.metadata.get("original_resource_ref"), text=citation.page_content, - source=Citation.Origin(citation.metadata.get("creator_type")) + source=Citation.Origin(citation.metadata.get("creator_type")), ) if self.metadata: @@ -228,7 +227,6 @@ def save_ai_message( model_name=model, token_count=token_count, ) - for model, token_count in self.metadata.output_tokens.items(): ChatMessageTokenUse.objects.create( chat_message=chat_message, @@ -275,9 +273,7 @@ async def handle_documents(self, response: list[Document]): files = File.objects.filter(original_file__in=sources_by_resource_ref.keys()) async for file in files: await self.send_to_client("source", {"url": str(file.url), "original_file_name": file.original_file_name}) - self.citations.append( - (file, sources_by_resource_ref[file.unique_name]) - ) + self.citations.append((file, sources_by_resource_ref[file.unique_name])) handled_sources.add(file.unique_name) additional_sources = [doc for doc in response if doc.metadata["original_resource_ref"] not in handled_sources] diff --git a/redbox-core/.vscode/settings.json b/redbox-core/.vscode/settings.json index cf3112f71..a629cea29 100644 --- a/redbox-core/.vscode/settings.json +++ b/redbox-core/.vscode/settings.json @@ -11,6 +11,7 @@ "python.testing.pytestEnabled": true, "python.testing.pytestArgs": [ ".", + "-v", "-m not ( ai )" ], "python.testing.pytestPath": "venv/bin/python -m pytest" diff --git a/redbox-core/poetry.lock b/redbox-core/poetry.lock index 195d2cba5..db9394003 100644 --- a/redbox-core/poetry.lock +++ 
b/redbox-core/poetry.lock @@ -220,17 +220,17 @@ lxml = ["lxml"] [[package]] name = "boto3" -version = "1.34.162" +version = "1.35.45" description = "The AWS SDK for Python" optional = false python-versions = ">=3.8" files = [ - {file = "boto3-1.34.162-py3-none-any.whl", hash = "sha256:d6f6096bdab35a0c0deff469563b87d184a28df7689790f7fe7be98502b7c590"}, - {file = "boto3-1.34.162.tar.gz", hash = "sha256:873f8f5d2f6f85f1018cbb0535b03cceddc7b655b61f66a0a56995238804f41f"}, + {file = "boto3-1.35.45-py3-none-any.whl", hash = "sha256:f16c7edfcbbeb0a0c22d67d6ebbfcb332fa78d3ea88275e082260ba04fe65347"}, + {file = "boto3-1.35.45.tar.gz", hash = "sha256:9f4a081e1940846171b51d903000a04322f1356d53225ce1028fc1760a155a70"}, ] [package.dependencies] -botocore = ">=1.34.162,<1.35.0" +botocore = ">=1.35.45,<1.36.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.10.0,<0.11.0" @@ -655,13 +655,13 @@ xray = ["mypy-boto3-xray (>=1.35.0,<1.36.0)"] [[package]] name = "botocore" -version = "1.34.162" +version = "1.35.45" description = "Low-level, data-driven core of boto 3." optional = false python-versions = ">=3.8" files = [ - {file = "botocore-1.34.162-py3-none-any.whl", hash = "sha256:2d918b02db88d27a75b48275e6fb2506e9adaaddbec1ffa6a8a0898b34e769be"}, - {file = "botocore-1.34.162.tar.gz", hash = "sha256:adc23be4fb99ad31961236342b7cbf3c0bfc62532cd02852196032e8c0d682f3"}, + {file = "botocore-1.35.45-py3-none-any.whl", hash = "sha256:e07e170975721c94ec1e3bf71a484552ad63e2499f769dd14f9f37375b4993fd"}, + {file = "botocore-1.35.45.tar.gz", hash = "sha256:9a898bfdd6b0027fee2018711192c15c2716bf6a7096b1168bd8a896df3664a1"}, ] [package.dependencies] @@ -670,7 +670,7 @@ python-dateutil = ">=2.1,<3.0.0" urllib3 = {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""} [package.extras] -crt = ["awscrt (==0.21.2)"] +crt = ["awscrt (==0.22.0)"] [[package]] name = "botocore-stubs" @@ -1100,13 +1100,13 @@ vision = ["Pillow (>=9.4.0)"] [[package]] name = "deepeval" -version = "1.4.3" +version = "1.4.4" description = "The open-source evaluation framework for LLMs." 
optional = false python-versions = "*" files = [ - {file = "deepeval-1.4.3-py3-none-any.whl", hash = "sha256:ef6bf7c02b40dde2615e3fcd4a0441431c01cb3f1e65205be88f62d02db77679"}, - {file = "deepeval-1.4.3.tar.gz", hash = "sha256:8e9887c498aa660bfc76e97e579ef65a5cad60549eef61d0c74f4b8fa543aa95"}, + {file = "deepeval-1.4.4-py3-none-any.whl", hash = "sha256:97481bf13eb97f934e4b1927017793c54a5528cf73515f47fc6f30da992d7e0b"}, + {file = "deepeval-1.4.4.tar.gz", hash = "sha256:f053a0ae485c4a137ed4c417a7c4564119c534b98e9d3f6de62e943d968020d8"}, ] [package.dependencies] @@ -1878,133 +1878,135 @@ testing = ["matplotlib (>=2.2.5)", "pytest (>=5.0.1)", "pytest-cov (>=3.0.0)"] [[package]] name = "langchain" -version = "0.2.16" +version = "0.3.4" description = "Building applications with LLMs through composability" optional = false -python-versions = "<4.0,>=3.8.1" +python-versions = "<4.0,>=3.9" files = [ - {file = "langchain-0.2.16-py3-none-any.whl", hash = "sha256:8f59ee8b45f268df4b924ea3b9c63e49286efa756d16b3f6a9de5c6e502c36e1"}, - {file = "langchain-0.2.16.tar.gz", hash = "sha256:ffb426a76a703b73ac69abad77cd16eaf03dda76b42cff55572f592d74944166"}, + {file = "langchain-0.3.4-py3-none-any.whl", hash = "sha256:7a1241d9429510d2083c62df0da998a7b2b05c730cd4255b89da9d47c57f48fd"}, + {file = "langchain-0.3.4.tar.gz", hash = "sha256:3596515fcd0157dece6ec96e0240d29f4cf542d91ecffc815d32e35198dfff37"}, ] [package.dependencies] aiohttp = ">=3.8.3,<4.0.0" -langchain-core = ">=0.2.38,<0.3.0" -langchain-text-splitters = ">=0.2.0,<0.3.0" +langchain-core = ">=0.3.12,<0.4.0" +langchain-text-splitters = ">=0.3.0,<0.4.0" langsmith = ">=0.1.17,<0.2.0" numpy = {version = ">=1.26.0,<2.0.0", markers = "python_version >= \"3.12\""} -pydantic = ">=1,<3" +pydantic = ">=2.7.4,<3.0.0" PyYAML = ">=5.3" requests = ">=2,<3" SQLAlchemy = ">=1.4,<3" -tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<9.0.0" +tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10" [[package]] name = "langchain-aws" -version = "0.1.18" +version = "0.2.3" description = "An integration package connecting AWS and LangChain" optional = false -python-versions = "<4.0,>=3.8.1" +python-versions = "<4.0,>=3.9" files = [ - {file = "langchain_aws-0.1.18-py3-none-any.whl", hash = "sha256:54c65157c67837cd932a6a8c536b52e2308c4e917bccd187d6f42651943cdd51"}, - {file = "langchain_aws-0.1.18.tar.gz", hash = "sha256:2e59efc95ef6758581b7e769cd610b750c76635c437a4f43c454a14cfab67c64"}, + {file = "langchain_aws-0.2.3-py3-none-any.whl", hash = "sha256:517b946802d94b12f54c8464dae91108f54a93b6679e1b302979a9b50a99229d"}, + {file = "langchain_aws-0.2.3.tar.gz", hash = "sha256:ed6c5dcc1f2e9e814db3107e4968a94680f20b66a110492e6906a05972f8e4ea"}, ] [package.dependencies] -boto3 = ">=1.34.131,<1.35.0" -langchain-core = ">=0.2.33,<0.3" +boto3 = ">=1.34.131" +langchain-core = ">=0.3.2,<0.4" numpy = {version = ">=1.26.0,<2.0.0", markers = "python_version >= \"3.12\""} +pydantic = ">=2,<3" [[package]] name = "langchain-community" -version = "0.2.17" +version = "0.3.3" description = "Community contributed LangChain integrations." 
optional = false -python-versions = "<4.0,>=3.8.1" +python-versions = "<4.0,>=3.9" files = [ - {file = "langchain_community-0.2.17-py3-none-any.whl", hash = "sha256:d07c31b641e425fb8c3e7148ad6a62e1b54a9adac6e1173021a7dd3148266063"}, - {file = "langchain_community-0.2.17.tar.gz", hash = "sha256:b0745c1fcf1bd532ed4388f90b47139d6a6c6ba48a87aa68aa32d4d6bb97259d"}, + {file = "langchain_community-0.3.3-py3-none-any.whl", hash = "sha256:319cfc2f923a066c91fbb8e02decd7814018af952b6b98298b8ac9d30ea1da56"}, + {file = "langchain_community-0.3.3.tar.gz", hash = "sha256:bfb3f2b219aed21087e0ecb7d2ebd1c81401c02b92239e11645c822d5be63f80"}, ] [package.dependencies] aiohttp = ">=3.8.3,<4.0.0" dataclasses-json = ">=0.5.7,<0.7" -langchain = ">=0.2.16,<0.3.0" -langchain-core = ">=0.2.39,<0.3.0" -langsmith = ">=0.1.112,<0.2.0" +langchain = ">=0.3.4,<0.4.0" +langchain-core = ">=0.3.12,<0.4.0" +langsmith = ">=0.1.125,<0.2.0" numpy = {version = ">=1.26.0,<2.0.0", markers = "python_version >= \"3.12\""} +pydantic-settings = ">=2.4.0,<3.0.0" PyYAML = ">=5.3" requests = ">=2,<3" SQLAlchemy = ">=1.4,<3" -tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<9.0.0" +tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10" [[package]] name = "langchain-core" -version = "0.2.41" +version = "0.3.12" description = "Building applications with LLMs through composability" optional = false -python-versions = "<4.0,>=3.8.1" +python-versions = "<4.0,>=3.9" files = [ - {file = "langchain_core-0.2.41-py3-none-any.whl", hash = "sha256:3278fda5ba9a05defae8bb19f1226032add6aab21917db7b3bc74e750e263e84"}, - {file = "langchain_core-0.2.41.tar.gz", hash = "sha256:bc12032c5a298d85be754ccb129bc13ea21ccb1d6e22f8d7ba18b8da64315bb5"}, + {file = "langchain_core-0.3.12-py3-none-any.whl", hash = "sha256:46050d34f5fa36dc57dca971c6a26f505643dd05ee0492c7ac286d0a78a82037"}, + {file = "langchain_core-0.3.12.tar.gz", hash = "sha256:98a3c078e375786aa84939bfd1111263af2f3bc402bbe2cac9fa18a387459cf2"}, ] [package.dependencies] jsonpatch = ">=1.33,<2.0" -langsmith = ">=0.1.112,<0.2.0" +langsmith = ">=0.1.125,<0.2.0" packaging = ">=23.2,<25" pydantic = [ - {version = ">=1,<3", markers = "python_full_version < \"3.12.4\""}, + {version = ">=2.5.2,<3.0.0", markers = "python_full_version < \"3.12.4\""}, {version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""}, ] PyYAML = ">=5.3" -tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<9.0.0" +tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10.0.0" typing-extensions = ">=4.7" [[package]] name = "langchain-elasticsearch" -version = "0.2.2" +version = "0.3.0" description = "An integration package connecting Elasticsearch and LangChain" optional = false -python-versions = "<4.0,>=3.8.1" +python-versions = "<4.0,>=3.9" files = [ - {file = "langchain_elasticsearch-0.2.2-py3-none-any.whl", hash = "sha256:2b0ae1637afc6890c371fc67be2151954db3c7f5382a07dafe4cf3bd762c6f26"}, - {file = "langchain_elasticsearch-0.2.2.tar.gz", hash = "sha256:7abe1cdee9f3b1a5f4152d2f79514359d0b95dc6d3923d3adb4ac431be854545"}, + {file = "langchain_elasticsearch-0.3.0-py3-none-any.whl", hash = "sha256:83d26d2a29f7628bcf2c655fe988930c90249e78e6f326c0be218ad4fa96f1f4"}, + {file = "langchain_elasticsearch-0.3.0.tar.gz", hash = "sha256:36c2963c432777e05b6613f57ee477de4b6e4be4ac4c1306601ab64de004a297"}, ] [package.dependencies] elasticsearch = {version = ">=8.13.1,<9.0.0", extras = ["vectorstore-mmr"]} -langchain-core = ">=0.1.50,<0.3" +langchain-core = ">=0.3.0,<0.4.0" [[package]] name = "langchain-openai" -version = "0.1.25" +version = "0.2.3" description = "An integration package 
connecting OpenAI and LangChain" optional = false -python-versions = "<4.0,>=3.8.1" +python-versions = "<4.0,>=3.9" files = [ - {file = "langchain_openai-0.1.25-py3-none-any.whl", hash = "sha256:f0b34a233d0d9cb8fce6006c903e57085c493c4f0e32862b99063b96eaedb109"}, - {file = "langchain_openai-0.1.25.tar.gz", hash = "sha256:eb116f744f820247a72f54313fb7c01524fba0927120d4e899e5e4ab41ad3928"}, + {file = "langchain_openai-0.2.3-py3-none-any.whl", hash = "sha256:f498c94817c980cb302439b95d3f3275cdf2743e022ee674692c75898523cf57"}, + {file = "langchain_openai-0.2.3.tar.gz", hash = "sha256:e142031704de1104735f503f76352c53b27ac0a2806466392993c4508c42bf0c"}, ] [package.dependencies] -langchain-core = ">=0.2.40,<0.3.0" -openai = ">=1.40.0,<2.0.0" +langchain-core = ">=0.3.12,<0.4.0" +openai = ">=1.52.0,<2.0.0" tiktoken = ">=0.7,<1" [[package]] name = "langchain-text-splitters" -version = "0.2.4" +version = "0.3.0" description = "LangChain text splitting utilities" optional = false -python-versions = "<4.0,>=3.8.1" +python-versions = "<4.0,>=3.9" files = [ - {file = "langchain_text_splitters-0.2.4-py3-none-any.whl", hash = "sha256:2702dee5b7cbdd595ccbe43b8d38d01a34aa8583f4d6a5a68ad2305ae3e7b645"}, - {file = "langchain_text_splitters-0.2.4.tar.gz", hash = "sha256:f7daa7a3b0aa8309ce248e2e2b6fc8115be01118d336c7f7f7dfacda0e89bf29"}, + {file = "langchain_text_splitters-0.3.0-py3-none-any.whl", hash = "sha256:e84243e45eaff16e5b776cd9c81b6d07c55c010ebcb1965deb3d1792b7358e83"}, + {file = "langchain_text_splitters-0.3.0.tar.gz", hash = "sha256:f9fe0b4d244db1d6de211e7343d4abc4aa90295aa22e1f0c89e51f33c55cd7ce"}, ] [package.dependencies] -langchain-core = ">=0.2.38,<0.3.0" +langchain-core = ">=0.3.0,<0.4.0" [[package]] name = "langgraph" @@ -3513,13 +3515,13 @@ files = [ [[package]] name = "ragas" -version = "0.2.1" +version = "0.2.2" description = "" optional = false python-versions = "*" files = [ - {file = "ragas-0.2.1-py3-none-any.whl", hash = "sha256:05364c121dc02ea3f23bf413b6f3fcf6d757a490969ca219eafaf28df8d26f8b"}, - {file = "ragas-0.2.1.tar.gz", hash = "sha256:7c377af9d83442403c660ee47c6b23ffd8902166d151b388d7f26b769a9b1bf7"}, + {file = "ragas-0.2.2-py3-none-any.whl", hash = "sha256:32e22d355db20be2e9a4d78df6b094e6b8f0c967c3f7f489aeaa9e005545601b"}, + {file = "ragas-0.2.2.tar.gz", hash = "sha256:ccec576d635592898eed241af0ce1b7a31c2260665c5fbb1fbb6b787d51dab05"}, ] [package.dependencies] @@ -3716,13 +3718,13 @@ tests = ["coverage (>=6.0.0)", "flake8", "mypy", "pytest (>=7.0.0)", "pytest-asy [[package]] name = "rich" -version = "13.9.2" +version = "13.9.3" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" optional = false python-versions = ">=3.8.0" files = [ - {file = "rich-13.9.2-py3-none-any.whl", hash = "sha256:8c82a3d3f8dcfe9e734771313e606b39d8247bb6b826e196f4914b333b743cf1"}, - {file = "rich-13.9.2.tar.gz", hash = "sha256:51a2c62057461aaf7152b4d611168f93a9fc73068f8ded2790f29fe2b5366d0c"}, + {file = "rich-13.9.3-py3-none-any.whl", hash = "sha256:9836f5096eb2172c9e77df411c1b009bace4193d6a481d534fea75ebba758283"}, + {file = "rich-13.9.3.tar.gz", hash = "sha256:bc1e01b899537598cf02579d2b9f4a415104d3fc439313a7a2c165d76557a08e"}, ] [package.dependencies] @@ -4663,4 +4665,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = ">=3.12,<3.13" -content-hash = "cb3fd809781cac407ff2a46bdde03c266692b8c71626ec49ebf4ec962439cfb5" +content-hash = "f97d2c485b5ebb17b1ba8525dddc01c1c4adcae87acf13b665674c1ca6466c23" diff --git 
a/redbox-core/pyproject.toml b/redbox-core/pyproject.toml index d1a2bdace..445391337 100644 --- a/redbox-core/pyproject.toml +++ b/redbox-core/pyproject.toml @@ -14,17 +14,17 @@ readme = "../README.md" python = ">=3.12,<3.13" pydantic = "^2.7.1" elasticsearch = "^8.15.0" -langchain-community = "^0.2.12" -langchain = "^0.2.13" -langchain_openai = "^0.1.21" +langchain-community = ">0.2.12" +langchain = "^0.3.4" +langchain_openai = ">0.1.21" tiktoken = "^0.7.0" boto3 = "^1.34.160" pydantic-settings = "^2.3.4" -langchain-elasticsearch = "^0.2.2" +langchain-elasticsearch = ">0.2.2" pytest-dotenv = "^0.5.2" kneed = "^0.8.5" langgraph = "^0.2.39" -langchain-aws = "^0.1.17" +langchain-aws = ">0.1.17" wikipedia = "^1.4.0" opensearch-py = "^2.7.1" diff --git a/redbox-core/redbox/app.py b/redbox-core/redbox/app.py index 3f59c3caa..5b132e9f9 100644 --- a/redbox-core/redbox/app.py +++ b/redbox-core/redbox/app.py @@ -8,7 +8,10 @@ get_metadata_retriever, get_parameterised_retriever, ) -from redbox.graph.nodes.tools import build_search_documents_tool, build_search_wikipedia_tool +from redbox.graph.nodes.tools import ( + build_search_documents_tool, + build_search_wikipedia_tool, +) from redbox.graph.root import get_root_graph from redbox.models.chain import RedboxState from redbox.models.chat import ChatRoute @@ -82,7 +85,9 @@ async def run( ) -> RedboxState: final_state = None async for event in self.graph.astream_events( - input=input, version="v2", config={"recursion_limit": input["request"].ai_settings.recursion_limit} + input=input, + version="v2", + config={"recursion_limit": input["request"].ai_settings.recursion_limit}, ): kind = event["event"] tags = event.get("tags", []) @@ -119,4 +124,6 @@ def get_available_keywords(self) -> dict[ChatRoute, str]: def draw(self, output_path="RedboxAIArchitecture.png"): from langchain_core.runnables.graph import MermaidDrawMethod - self.graph.get_graph(xray=True).draw_mermaid_png(draw_method=MermaidDrawMethod.API, output_file_path=output_path) + self.graph.get_graph(xray=True).draw_mermaid_png( + draw_method=MermaidDrawMethod.API, output_file_path=output_path + ) diff --git a/redbox-core/redbox/chains/runnables.py b/redbox-core/redbox/chains/runnables.py index 83c3e779b..493b0dec1 100644 --- a/redbox-core/redbox/chains/runnables.py +++ b/redbox-core/redbox/chains/runnables.py @@ -3,13 +3,22 @@ from operator import itemgetter from typing import Any, Callable, Iterable, Iterator -from langchain_core.callbacks.manager import CallbackManagerForLLMRun, dispatch_custom_event +from langchain_core.callbacks.manager import ( + CallbackManagerForLLMRun, + dispatch_custom_event, +) from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage from langchain_core.output_parsers import StrOutputParser from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.prompts import ChatPromptTemplate -from langchain_core.runnables import Runnable, RunnableGenerator, RunnableLambda, chain +from langchain_core.runnables import ( + Runnable, + RunnableGenerator, + RunnableLambda, + RunnablePassthrough, + chain, +) from tiktoken import Encoding from redbox.api.format import format_documents, format_toolstate @@ -38,7 +47,11 @@ def _combined(obj): return _combined -def build_chat_prompt_from_messages_runnable(prompt_set: PromptSet, tokeniser: Encoding = None) -> Runnable: +def build_chat_prompt_from_messages_runnable( + prompt_set: PromptSet, + tokeniser: Encoding = None, + 
partial_variables: dict = None, +) -> Runnable: @chain def _chat_prompt_from_messages(state: RedboxState) -> Runnable: """ @@ -46,6 +59,7 @@ def _chat_prompt_from_messages(state: RedboxState) -> Runnable: Returns the PromptValue using values in the input_dict """ _tokeniser = tokeniser or get_tokeniser() + _partial_variables = partial_variables or dict() system_prompt, question_prompt = get_prompts(state, prompt_set) log.debug("Setting chat prompt") @@ -77,10 +91,13 @@ def _chat_prompt_from_messages(state: RedboxState) -> Runnable: | {"tool_calls": format_toolstate(state.get("tool_calls"))} ) - return ChatPromptTemplate.from_messages( - system_prompt_message - + [(msg["role"], msg["text"]) for msg in truncated_history] - + [("user", question_prompt)] + return ChatPromptTemplate( + messages=( + system_prompt_message + + [(msg["role"], msg["text"]) for msg in truncated_history] + + [("user", question_prompt)] + ), + partial_variables=_partial_variables, ).invoke(prompt_template_context) return _chat_prompt_from_messages @@ -90,6 +107,7 @@ def build_llm_chain( prompt_set: PromptSet, llm: BaseChatModel, output_parser: Runnable | Callable = None, + format_instructions: str | None = None, final_response_chain: bool = False, ) -> Runnable: """Builds a chain that correctly forms a text and metadata state update. @@ -99,36 +117,41 @@ def build_llm_chain( model_name = getattr(llm, "model_name", "unknown-model") _llm = llm.with_config(tags=["response_flag"]) if final_response_chain else llm _output_parser = output_parser if output_parser else StrOutputParser() - return ( - build_chat_prompt_from_messages_runnable(prompt_set) + build_chat_prompt_from_messages_runnable(prompt_set, partial_variables={"format_arg": format_instructions}) | { "text_and_tools": ( _llm | { - "text": _output_parser, + "parsed_response": _output_parser, "tool_calls": (RunnableLambda(lambda r: r.tool_calls) | tool_calls_to_toolstate), } ), "prompt": RunnableLambda(lambda prompt: prompt.to_string()), } | { - "text": combine_getters(itemgetter("text_and_tools"), itemgetter("text")), + "text": RunnableLambda(combine_getters(itemgetter("text_and_tools"), itemgetter("parsed_response"))) + | (lambda r: r if isinstance(r, str) else r.markdown_answer), "tool_calls": combine_getters(itemgetter("text_and_tools"), itemgetter("tool_calls")), - "metadata": ( - { - "prompt": itemgetter("prompt"), - "response": combine_getters(itemgetter("text_and_tools"), itemgetter("text")), - "model": lambda _: model_name, - } - | to_request_metadata - ), + "citations": RunnableLambda(combine_getters(itemgetter("text_and_tools"), itemgetter("parsed_response"))) + | (lambda r: [] if isinstance(r, str) else r.citations), + "prompt": itemgetter("prompt"), } + | RunnablePassthrough.assign( + metadata={ + "prompt": itemgetter("prompt"), + "response": itemgetter("text"), + "model": lambda _: model_name, + } + | to_request_metadata + ) ) def build_self_route_output_parser( - match_condition: Callable[[str], bool], max_tokens_to_check: int, final_response_chain: bool = False + match_condition: Callable[[str], bool], + max_tokens_to_check: int, + final_response_chain: bool = False, ) -> Runnable[Iterable[AIMessageChunk], Iterable[str]]: """ This Runnable reads the streamed responses from an LLM until the match diff --git a/redbox-core/redbox/graph/nodes/processes.py b/redbox-core/redbox/graph/nodes/processes.py index 57e129ed3..afb058cf6 100644 --- a/redbox-core/redbox/graph/nodes/processes.py +++ b/redbox-core/redbox/graph/nodes/processes.py @@ -17,8 +17,19 @@ 
from redbox.chains.runnables import CannedChatLLM, build_llm_chain from redbox.graph.nodes.tools import has_injected_state, is_valid_tool from redbox.models import ChatRoute -from redbox.models.chain import DocumentState, PromptSet, RedboxState, RequestMetadata, merge_redbox_state_updates -from redbox.models.graph import ROUTE_NAME_TAG, SOURCE_DOCUMENTS_TAG, RedboxActivityEvent, RedboxEventType +from redbox.models.chain import ( + DocumentState, + PromptSet, + RedboxState, + RequestMetadata, + merge_redbox_state_updates, +) +from redbox.models.graph import ( + ROUTE_NAME_TAG, + SOURCE_DOCUMENTS_TAG, + RedboxActivityEvent, + RedboxEventType, +) from redbox.transform import combine_documents, flatten_document_state log = logging.getLogger(__name__) @@ -133,6 +144,7 @@ def _merge(state: RedboxState) -> dict[str, Any]: def build_stuff_pattern( prompt_set: PromptSet, output_parser: Runnable = None, + format_instructions: str | None = None, tools: list[StructuredTool] | None = None, final_response_chain: bool = False, ) -> Runnable[RedboxState, dict[str, Any]]: @@ -151,6 +163,7 @@ def _stuff(state: RedboxState) -> dict[str, Any]: prompt_set=prompt_set, llm=llm, output_parser=output_parser, + format_instructions=format_instructions, final_response_chain=final_response_chain, ).stream(state) ] @@ -293,6 +306,7 @@ def _tool(state: RedboxState) -> dict[str, Any]: tool_called_state_update = {"tool_calls": {tool_id: {"called": True, "tool": tool_call}}} state_updates.append(result_state_update | tool_called_state_update) except Exception as e: + raise e state_updates.append({"tool_calls": {tool_id: {"called": True, "tool": tool_call}}}) log.warning(f"Error invoking tool {tool_call['name']}: {e} \n") return {} @@ -337,7 +351,7 @@ def _log_node(state: RedboxState): group_id: {doc_id: d.metadata for doc_id, d in group_documents.items()} for group_id, group_documents in state["documents"] }, - "text": state["text"] if len(state["text"]) < 32 else f"{state['text'][:29]}...", + "text": (state["text"] if len(state["text"]) < 32 else f"{state['text'][:29]}..."), "route": state["route_name"], "message": message, } diff --git a/redbox-core/redbox/graph/nodes/tools.py b/redbox-core/redbox/graph/nodes/tools.py index c94a929f2..adaef5e7f 100644 --- a/redbox-core/redbox/graph/nodes/tools.py +++ b/redbox-core/redbox/graph/nodes/tools.py @@ -1,17 +1,25 @@ from typing import Annotated, Any, get_args, get_origin, get_type_hints +import tiktoken from elasticsearch import Elasticsearch +from langchain_community.utilities import WikipediaAPIWrapper +from langchain_core.documents import Document from langchain_core.embeddings.embeddings import Embeddings from langchain_core.tools import StructuredTool, Tool, tool -from langchain_core.documents import Document from langgraph.prebuilt import InjectedState -from langchain_community.utilities import WikipediaAPIWrapper -import tiktoken +from redbox.models.chain import RedboxState from redbox.models.file import ChunkMetadata, ChunkResolution -from redbox.retriever.queries import add_document_filter_scores_to_query, build_document_query +from redbox.retriever.queries import ( + add_document_filter_scores_to_query, + build_document_query, +) from redbox.retriever.retrievers import query_to_documents -from redbox.transform import merge_documents, sort_documents, structure_documents_by_group_and_indices +from redbox.transform import ( + merge_documents, + sort_documents, + structure_documents_by_group_and_indices, +) def is_valid_tool(tool: StructuredTool) -> bool: @@ -70,7 +78,7 
@@ def build_search_documents_tool( """Constructs a tool that searches the index and sets state["documents"].""" @tool - def _search_documents(query: str, state: Annotated[dict, InjectedState]) -> dict[str, Any]: + def _search_documents(query: str, state: Annotated[RedboxState, InjectedState]) -> dict[str, Any]: """ Search for documents uploaded by the user based on a query string. @@ -128,12 +136,13 @@ def _search_documents(query: str, state: Annotated[dict, InjectedState]) -> dict def build_search_wikipedia_tool(number_wikipedia_results=1, max_chars_per_wiki_page=12000) -> Tool: """Constructs a tool that searches Wikipedia""" _wikipedia_wrapper = WikipediaAPIWrapper( - top_k_results=number_wikipedia_results, doc_content_chars_max=max_chars_per_wiki_page + top_k_results=number_wikipedia_results, + doc_content_chars_max=max_chars_per_wiki_page, ) tokeniser = tiktoken.encoding_for_model("gpt-4o") @tool - def _search_wikipedia(query: str, state: Annotated[dict, InjectedState]) -> dict[str, Any]: + def _search_wikipedia(query: str, state: Annotated[RedboxState, InjectedState]) -> dict[str, Any]: """ Search Wikipedia for information about the queried entity. Useful for when you need to answer general questions about people, places, objects, companies, facts, historical events, or other subjects. diff --git a/redbox-core/redbox/graph/root.py b/redbox-core/redbox/graph/root.py index 238cd8003..8f27dbc02 100644 --- a/redbox-core/redbox/graph/root.py +++ b/redbox-core/redbox/graph/root.py @@ -1,9 +1,9 @@ +from langchain.output_parsers import PydanticOutputParser from langchain_core.tools import StructuredTool from langchain_core.vectorstores import VectorStoreRetriever from langgraph.graph import END, START, StateGraph from langgraph.graph.graph import CompiledGraph - from redbox.chains.runnables import build_self_route_output_parser from redbox.graph.edges import ( build_documents_bigger_than_context_conditional, @@ -31,11 +31,18 @@ empty_process, report_sources_process, ) -from redbox.graph.nodes.sends import build_document_chunk_send, build_document_group_send, build_tool_send -from redbox.models.chain import RedboxState +from redbox.graph.nodes.sends import ( + build_document_chunk_send, + build_document_group_send, + build_tool_send, +) +from redbox.models.chain import LLM_Response, RedboxState from redbox.models.chat import ChatRoute, ErrorRoute from redbox.models.graph import ROUTABLE_KEYWORDS, RedboxActivityEvent -from redbox.transform import structure_documents_by_file_name, structure_documents_by_group_and_indices +from redbox.transform import ( + structure_documents_by_file_name, + structure_documents_by_group_and_indices, +) # Subgraphs @@ -51,7 +58,9 @@ def self_route_question_is_unanswerable(llm_response: str): builder.add_node( "p_retrieve_docs", build_retrieve_pattern( - retriever=retriever, structure_func=structure_documents_by_file_name, final_source_chain=False + retriever=retriever, + structure_func=structure_documents_by_file_name, + final_source_chain=False, ), ) builder.add_node( @@ -59,7 +68,9 @@ def self_route_question_is_unanswerable(llm_response: str): build_stuff_pattern( prompt_set=prompt_set, output_parser=build_self_route_output_parser( - match_condition=self_route_question_is_unanswerable, max_tokens_to_check=4, final_response_chain=True + match_condition=self_route_question_is_unanswerable, + max_tokens_to_check=4, + final_response_chain=True, ), final_response_chain=False, ), @@ -82,7 +93,10 @@ def self_route_question_is_unanswerable(llm_response: str): 
builder.add_conditional_edges( "p_set_route_name_from_answer", lambda state: state["route_name"], - {ChatRoute.chat_with_docs_map_reduce: "p_clear_documents", ChatRoute.search: END}, + { + ChatRoute.chat_with_docs_map_reduce: "p_clear_documents", + ChatRoute.search: END, + }, ) builder.add_edge("p_clear_documents", END) @@ -97,7 +111,10 @@ def get_chat_graph( # Processes builder.add_node("p_set_chat_route", build_set_route_pattern(route=ChatRoute.chat)) - builder.add_node("p_chat", build_chat_pattern(prompt_set=PromptSet.Chat, final_response_chain=True)) + builder.add_node( + "p_chat", + build_chat_pattern(prompt_set=PromptSet.Chat, final_response_chain=True), + ) # Edges builder.add_edge(START, "p_set_chat_route") @@ -128,7 +145,10 @@ def get_search_graph( final_source_chain=final_sources, ), ) - builder.add_node("p_stuff_docs", build_stuff_pattern(prompt_set=prompt_set, final_response_chain=final_response)) + builder.add_node( + "p_stuff_docs", + build_stuff_pattern(prompt_set=prompt_set, final_response_chain=final_response), + ) # Edges builder.add_edge(START, "p_set_search_route") @@ -142,6 +162,8 @@ def get_search_graph( def get_agentic_search_graph(tools: dict[str, StructuredTool], debug: bool = False) -> CompiledGraph: """Creates a subgraph for agentic RAG.""" + + citations_output_parser = PydanticOutputParser(pydantic_object=LLM_Response) builder = StateGraph(RedboxState) # Tools agent_tool_names = ["_search_documents", "_search_wikipedia"] @@ -153,10 +175,18 @@ def get_agentic_search_graph(tools: dict[str, StructuredTool], debug: bool = Fal "p_search_agent", build_stuff_pattern(prompt_set=PromptSet.SearchAgentic, tools=agent_tools), ) - builder.add_node("p_retrieval_tools", build_tool_pattern(tools=agent_tools, final_source_chain=False)) + builder.add_node( + "p_retrieval_tools", + build_tool_pattern(tools=agent_tools, final_source_chain=False), + ) builder.add_node( "p_stuff_docs_agent", - build_stuff_pattern(prompt_set=PromptSet.Search, final_response_chain=True), + build_stuff_pattern( + prompt_set=PromptSet.Search, + final_response_chain=True, + output_parser=citations_output_parser, + format_instructions=citations_output_parser.get_format_instructions(), + ), ) builder.add_node( "p_give_up_agent", @@ -194,7 +224,11 @@ def get_agentic_search_graph(tools: dict[str, StructuredTool], debug: bool = Fal builder.add_conditional_edges( "d_answer_or_give_up", build_strings_end_text_conditional("answer", "give_up"), - {"answer": "p_stuff_docs_agent", "give_up": "p_give_up_agent", "DEFAULT": "d_x_steps_left_or_less"}, + { + "answer": "p_stuff_docs_agent", + "give_up": "p_give_up_agent", + "DEFAULT": "d_x_steps_left_or_less", + }, ) builder.add_edge("p_stuff_docs_agent", "p_report_sources") builder.add_edge("p_give_up_agent", "p_report_sources") @@ -217,11 +251,16 @@ def get_chat_with_documents_graph( builder.add_node("p_pass_question_to_text", build_passthrough_pattern()) builder.add_node("p_set_chat_docs_route", build_set_route_pattern(route=ChatRoute.chat_with_docs)) builder.add_node( - "p_set_chat_docs_map_reduce_route", build_set_route_pattern(route=ChatRoute.chat_with_docs_map_reduce) + "p_set_chat_docs_map_reduce_route", + build_set_route_pattern(route=ChatRoute.chat_with_docs_map_reduce), + ) + builder.add_node( + "p_summarise_each_document", + build_merge_pattern(prompt_set=PromptSet.ChatwithDocsMapReduce), ) - builder.add_node("p_summarise_each_document", build_merge_pattern(prompt_set=PromptSet.ChatwithDocsMapReduce)) builder.add_node( - "p_summarise_document_by_document", 
build_merge_pattern(prompt_set=PromptSet.ChatwithDocsMapReduce) + "p_summarise_document_by_document", + build_merge_pattern(prompt_set=PromptSet.ChatwithDocsMapReduce), ) builder.add_node( "p_summarise", @@ -233,13 +272,21 @@ def get_chat_with_documents_graph( builder.add_node("p_clear_documents", clear_documents_process) builder.add_node( "p_too_large_error", - build_error_pattern(text="These documents are too large to work with.", route_name=ErrorRoute.files_too_large), + build_error_pattern( + text="These documents are too large to work with.", + route_name=ErrorRoute.files_too_large, + ), + ) + builder.add_node( + "p_answer_or_decide_route", + get_self_route_graph(parameterised_retriever, PromptSet.SelfRoute), ) - builder.add_node("p_answer_or_decide_route", get_self_route_graph(parameterised_retriever, PromptSet.SelfRoute)) builder.add_node( "p_retrieve_all_chunks", build_retrieve_pattern( - retriever=all_chunks_retriever, structure_func=structure_documents_by_file_name, final_source_chain=True + retriever=all_chunks_retriever, + structure_func=structure_documents_by_file_name, + final_source_chain=True, ), ) @@ -291,10 +338,15 @@ def get_chat_with_documents_graph( builder.add_conditional_edges( "p_retrieve_all_chunks", lambda s: s["route_name"], - {ChatRoute.chat_with_docs: "p_summarise", ChatRoute.chat_with_docs_map_reduce: "s_chunk"}, + { + ChatRoute.chat_with_docs: "p_summarise", + ChatRoute.chat_with_docs_map_reduce: "s_chunk", + }, ) builder.add_conditional_edges( - "s_chunk", build_document_chunk_send("p_summarise_each_document"), path_map=["p_summarise_each_document"] + "s_chunk", + build_document_chunk_send("p_summarise_each_document"), + path_map=["p_summarise_each_document"], ) builder.add_edge("p_summarise_each_document", "d_groups_have_multiple_docs") builder.add_conditional_edges( @@ -345,7 +397,10 @@ def get_retrieve_metadata_graph(metadata_retriever: VectorStoreRetriever, debug: # Processes builder.add_node( "p_retrieve_metadata", - build_retrieve_pattern(retriever=metadata_retriever, structure_func=structure_documents_by_file_name), + build_retrieve_pattern( + retriever=metadata_retriever, + structure_func=structure_documents_by_file_name, + ), ) builder.add_node("p_set_metadata", build_set_metadata_pattern()) builder.add_node("p_clear_metadata_documents", clear_documents_process) @@ -375,7 +430,9 @@ def get_root_graph( rag_subgraph = get_search_graph(retriever=parameterised_retriever, debug=debug) agent_subgraph = get_agentic_search_graph(tools=tools, debug=debug) cwd_subgraph = get_chat_with_documents_graph( - all_chunks_retriever=all_chunks_retriever, parameterised_retriever=parameterised_retriever, debug=debug + all_chunks_retriever=all_chunks_retriever, + parameterised_retriever=parameterised_retriever, + debug=debug, ) metadata_subgraph = get_retrieve_metadata_graph(metadata_retriever=metadata_retriever, debug=debug) @@ -396,7 +453,11 @@ def get_root_graph( builder.add_conditional_edges( "d_keyword_exists", build_keyword_detection_conditional(*ROUTABLE_KEYWORDS.keys()), - {ChatRoute.search: "p_search", ChatRoute.gadget: "p_search_agentic", "DEFAULT": "d_docs_selected"}, + { + ChatRoute.search: "p_search", + ChatRoute.gadget: "p_search_agentic", + "DEFAULT": "d_docs_selected", + }, ) builder.add_conditional_edges( "d_docs_selected", diff --git a/redbox-core/redbox/models/chain.py b/redbox-core/redbox/models/chain.py index cc68f02eb..06b3f55e4 100644 --- a/redbox-core/redbox/models/chain.py +++ b/redbox-core/redbox/models/chain.py @@ -1,19 +1,20 @@ -""" -There is 
some repeated definition and non-pydantic style code in here. -These classes are pydantic v1 which is compatible with langchain tools classes, we need -to provide a pydantic v1 definition to work with these. As these models are mostly -used in conjunction with langchain this is the tidiest boxing of pydantic v1 we can do -""" - from datetime import UTC, datetime from enum import StrEnum from functools import reduce -from typing import Annotated, Literal, NotRequired, Required, TypedDict, get_args, get_origin +from typing import ( + Annotated, + Literal, + NotRequired, + Required, + TypedDict, + get_args, + get_origin, +) from uuid import UUID, uuid4 from langchain_core.documents import Document from langchain_core.messages import ToolCall -from langgraph.managed.is_last_step import RemainingSteps +from langgraph.managed.is_last_step import RemainingStepsManager from pydantic import BaseModel, Field from redbox.models import prompts @@ -82,6 +83,44 @@ class AISettings(BaseModel): chat_backend: ChatLLMBackend = ChatLLMBackend() +class Source(BaseModel): + source: str = Field(description="URL or reference to the source") + last_edited: str = "" + document_name: str = "" + highlighted_text: str = "" + page_no: str = Field(description="") + + +class Citation(BaseModel): + text: str + sources: list[Source] + + +class LLM_Response(BaseModel): + markdown_answer: str + citations: list[Citation] + + @classmethod + def model_json_schema(self): + return { + "markdown_answer": "Hello Kitty is a fictional character from Japan.", + "citations": [ + { + "text": "Hello Kitty is a fictional character from Japan.", + "sources": [ + { + "source": "https://en.wikipedia.org/wiki/Hello_Kitty", + "last_edited": "4 October 2024", + "document_name": "Hello Kitty", + "highlighted_text": "Hello Kitty (Japanese: ハロー・キティ, Hepburn: Harō Kiti),[6] also known by her real name Kitty White (キティ・ホワイト, Kiti Howaito),[5] is a fictional character created by Yuko Shimizu", + "page_no": "1", + } + ], + } + ], + } + + class DocumentState(TypedDict): group: dict[UUID, Document] @@ -156,7 +195,7 @@ class LLMCallMetadata(BaseModel): class RequestMetadata(BaseModel): - llm_calls: set[LLMCallMetadata] = Field(default_factory=set) + llm_calls: list[LLMCallMetadata] = Field(default_factory=list) selected_files_total_tokens: int = 0 number_of_selected_files: int = 0 @@ -179,7 +218,10 @@ def output_tokens(self): return tokens_by_model -def metadata_reducer(current: RequestMetadata | None, update: RequestMetadata | list[RequestMetadata] | None): +def metadata_reducer( + current: RequestMetadata | None, + update: RequestMetadata | list[RequestMetadata] | None, +): """Merges two metadata states.""" # If update is actually a list of state updates, run them one by one if isinstance(update, list): @@ -192,7 +234,7 @@ def metadata_reducer(current: RequestMetadata | None, update: RequestMetadata | return current return RequestMetadata( - llm_calls=current.llm_calls | update.llm_calls, + llm_calls=sorted(list(set(current.llm_calls) | set(update.llm_calls)), key=lambda c: c.timestamp), selected_files_total_tokens=update.selected_files_total_tokens or current.selected_files_total_tokens, number_of_selected_files=update.number_of_selected_files or current.number_of_selected_files, ) @@ -242,7 +284,8 @@ class RedboxState(TypedDict): route_name: NotRequired[str | None] tool_calls: Annotated[NotRequired[ToolState], tool_calls_reducer] metadata: Annotated[NotRequired[RequestMetadata], metadata_reducer] - steps_left: RemainingSteps + citations: 
NotRequired[list[Citation] | None] + steps_left: Annotated[NotRequired[int], RemainingStepsManager] class PromptSet(StrEnum): diff --git a/redbox-core/redbox/models/prompts.py b/redbox-core/redbox/models/prompts.py index 4b36254e3..91a57dd81 100644 --- a/redbox-core/redbox/models/prompts.py +++ b/redbox-core/redbox/models/prompts.py @@ -2,6 +2,7 @@ "You are an AI assistant called Redbox tasked with answering questions and providing information objectively." ) + CHAT_WITH_DOCS_SYSTEM_PROMPT = "You are an AI assistant called Redbox tasked with answering questions on user provided documents and providing information objectively." CHAT_WITH_DOCS_REDUCE_SYSTEM_PROMPT = ( @@ -15,13 +16,12 @@ ) RETRIEVAL_SYSTEM_PROMPT = ( - "Given the following conversation and extracted parts of a long document and a question, create a final answer. \n" - "If you don't know the answer, just say that you don't know. Don't try to make up an answer. " - "If a user asks for a particular format to be returned, such as bullet points, then please use that format. " - "If a user asks for bullet points you MUST give bullet points. " - "If the user asks for a specific number or range of bullet points you MUST give that number of bullet points. \n" - "Use **bold** to highlight the most question relevant parts in your response. " - "If dealing dealing with lots of data return it in markdown table format. " + "You are a specialized GPT-4o agent. Your task is to answer user queries with reliable sources.\n" + "**You must provide the sources where you use the information to answer.**\n" + "- If the information is found in the provided documents, state that in the `text` field and provide citation details in the `citation` field.\n" + "- If the information is not found in the provided documents, state that in the `text` field (e.g., 'The response to your question is not available in provided documents, however, this is what I found' and then provide your response.)\n" + "\n" + "{format_arg}" ) AGENTIC_RETRIEVAL_SYSTEM_PROMPT = ( diff --git a/redbox-core/redbox/test/data.py b/redbox-core/redbox/test/data.py index d933a1d63..5f0531899 100644 --- a/redbox-core/redbox/test/data.py +++ b/redbox-core/redbox/test/data.py @@ -65,14 +65,15 @@ class Config: number_of_docs: int tokens_in_all_docs: int chunk_resolution: ChunkResolution = ChunkResolution.largest - expected_llm_response: list[str | AIMessage] = Field(default_factory=list) + llm_responses: list[str | AIMessage] = Field(default_factory=list) + expected_text: str | None = None expected_route: ChatRoute | ErrorRoute | None = None expected_activity_events: Callable[[list[RedboxActivityEvent]], bool] = Field( default=lambda _: True ) # Function to check activity events are as expected s3_keys: list[str] | None = None - @validator("expected_llm_response", pre=True) + @validator("llm_responses", pre=True) @classmethod def coerce_to_aimessage(cls, value: str | AIMessage): coerced: list[AIMessage] = [] @@ -94,10 +95,7 @@ def __init__( # Use separate file_uuids if specified else match the query all_s3_keys = test_data.s3_keys if test_data.s3_keys else query.s3_keys - if ( - test_data.expected_llm_response is not None - and len(test_data.expected_llm_response) < test_data.number_of_docs - ): + if test_data.llm_responses is not None and len(test_data.llm_responses) < test_data.number_of_docs: log.warning( "Number of configured LLM responses might be less than number of docs. For Map-Reduce actions this will give a Generator Error!" 
) diff --git a/redbox-core/redbox/transform.py b/redbox-core/redbox/transform.py index 78ec80935..4793967b2 100644 --- a/redbox-core/redbox/transform.py +++ b/redbox-core/redbox/transform.py @@ -7,7 +7,13 @@ from langchain_core.messages import ToolCall from langchain_core.runnables import RunnableLambda -from redbox.models.chain import DocumentState, LLMCallMetadata, RedboxState, RequestMetadata, ToolState +from redbox.models.chain import ( + DocumentState, + LLMCallMetadata, + RedboxState, + RequestMetadata, + ToolState, +) from redbox.models.graph import RedboxEventType @@ -224,7 +230,10 @@ def process_group(group: list[Document]) -> list[list[Document]]: return sorted_blocks # Step 1: Sort by file_name and then index to prepare for grouping consecutive documents - documents_sorted = sorted(documents, key=lambda d: (d.metadata["original_resource_ref"], d.metadata["index"])) + documents_sorted = sorted( + documents, + key=lambda d: (d.metadata["original_resource_ref"], d.metadata["index"]), + ) # Step 2: Group by file_name and handle consecutive indices grouped_by_file = itertools.groupby(documents_sorted, key=lambda d: d.metadata["original_resource_ref"]) diff --git a/redbox-core/tests/graph/test_app.py b/redbox-core/tests/graph/test_app.py index a5daf1d98..84ed909ea 100644 --- a/redbox-core/tests/graph/test_app.py +++ b/redbox-core/tests/graph/test_app.py @@ -10,7 +10,7 @@ from tiktoken.core import Encoding from redbox import Redbox -from redbox.models.chain import AISettings, RedboxQuery, RedboxState, RequestMetadata, metadata_reducer +from redbox.models.chain import AISettings, LLM_Response, RedboxQuery, RedboxState, RequestMetadata, metadata_reducer from redbox.models.chat import ChatRoute, ErrorRoute from redbox.models.file import ChunkResolution from redbox.models.graph import RedboxActivityEvent @@ -47,19 +47,19 @@ def assert_number_of_events(num_of_events: int): RedboxTestData( number_of_docs=0, tokens_in_all_docs=0, - expected_llm_response=["Testing Response 1"], + llm_responses=["Testing Response 1"], expected_route=ChatRoute.chat, ), RedboxTestData( number_of_docs=1, tokens_in_all_docs=100, - expected_llm_response=["Testing Response 1"], + llm_responses=["Testing Response 1"], expected_route=ChatRoute.chat, ), RedboxTestData( number_of_docs=10, tokens_in_all_docs=1200, - expected_llm_response=["Testing Response 1"], + llm_responses=["Testing Response 1"], expected_route=ChatRoute.chat, ), ], @@ -73,19 +73,19 @@ def assert_number_of_events(num_of_events: int): RedboxTestData( number_of_docs=1, tokens_in_all_docs=1_000, - expected_llm_response=["Testing Response 1"], + llm_responses=["Testing Response 1"], expected_route=ChatRoute.chat_with_docs, ), RedboxTestData( number_of_docs=1, tokens_in_all_docs=50_000, - expected_llm_response=["Testing Response 1"], + llm_responses=["Testing Response 1"], expected_route=ChatRoute.chat_with_docs, ), RedboxTestData( number_of_docs=1, tokens_in_all_docs=80_000, - expected_llm_response=["Testing Response 1"], + llm_responses=["Testing Response 1"], expected_route=ChatRoute.chat_with_docs, ), ], @@ -104,26 +104,26 @@ def assert_number_of_events(num_of_events: int): RedboxTestData( number_of_docs=2, tokens_in_all_docs=40_000, - expected_llm_response=["Testing Response 1"], + llm_responses=["Testing Response 1"], expected_route=ChatRoute.chat_with_docs, ), RedboxTestData( number_of_docs=2, tokens_in_all_docs=80_000, - expected_llm_response=["Testing Response 1"], + llm_responses=["Testing Response 1"], expected_route=ChatRoute.chat_with_docs, ), 
         RedboxTestData(
             number_of_docs=2,
             tokens_in_all_docs=140_000,
-            expected_llm_response=SELF_ROUTE_TO_CHAT + ["Map Step Response"] * 2 + ["Testing Response 1"],
+            llm_responses=SELF_ROUTE_TO_CHAT + ["Map Step Response"] * 2 + ["Testing Response 1"],
             expected_route=ChatRoute.chat_with_docs_map_reduce,
             expected_activity_events=assert_number_of_events(1),
         ),
         RedboxTestData(
             number_of_docs=4,
             tokens_in_all_docs=140_000,
-            expected_llm_response=SELF_ROUTE_TO_CHAT
+            llm_responses=SELF_ROUTE_TO_CHAT
             + ["Map Step Response"] * 4
             + ["Merge Per Document Response"] * 2
             + ["Testing Response 1"],
@@ -141,25 +141,25 @@ def assert_number_of_events(num_of_events: int):
         RedboxTestData(
             number_of_docs=2,
             tokens_in_all_docs=40_000,
-            expected_llm_response=["Testing Response 1"],
+            llm_responses=["Testing Response 1"],
             expected_route=ChatRoute.chat_with_docs,
         ),
         RedboxTestData(
             number_of_docs=2,
             tokens_in_all_docs=80_000,
-            expected_llm_response=["Testing Response 1"],
+            llm_responses=["Testing Response 1"],
             expected_route=ChatRoute.chat_with_docs,
         ),
         RedboxTestData(
             number_of_docs=2,
             tokens_in_all_docs=140_000,
-            expected_llm_response=["Map Step Response"] * 2 + ["Testing Response 1"],
+            llm_responses=["Map Step Response"] * 2 + ["Testing Response 1"],
             expected_route=ChatRoute.chat_with_docs_map_reduce,
         ),
         RedboxTestData(
             number_of_docs=4,
             tokens_in_all_docs=140_000,
-            expected_llm_response=["Map Step Response"] * 4
+            llm_responses=["Map Step Response"] * 4
             + ["Merge Per Document Response"] * 2
             + ["Testing Response 1"],
             expected_route=ChatRoute.chat_with_docs_map_reduce,
@@ -179,9 +179,7 @@ def assert_number_of_events(num_of_events: int):
         RedboxTestData(
             number_of_docs=2,
             tokens_in_all_docs=200_000,
-            expected_llm_response=["Map Step Response"] * 2
-            + ["Merge Per Document Response"]
-            + ["Testing Response 1"],
+            llm_responses=["Map Step Response"] * 2 + ["Merge Per Document Response"] + ["Testing Response 1"],
             expected_route=ChatRoute.chat_with_docs_map_reduce,
         ),
     ],
@@ -199,7 +197,7 @@ def assert_number_of_events(num_of_events: int):
         RedboxTestData(
             number_of_docs=2,
             tokens_in_all_docs=200_000,
-            expected_llm_response=SELF_ROUTE_TO_CHAT
+            llm_responses=SELF_ROUTE_TO_CHAT
             + ["Map Step Response"] * 2
             + ["Merge Per Document Response"]
             + ["Testing Response 1"],
@@ -223,7 +221,7 @@ def assert_number_of_events(num_of_events: int):
             number_of_docs=2,
             tokens_in_all_docs=200_000,
             chunk_resolution=ChunkResolution.normal,
-            expected_llm_response=SELF_ROUTE_TO_SEARCH,  # + ["Condense Question", "Testing Response 1"],
+            llm_responses=SELF_ROUTE_TO_SEARCH,  # + ["Condense Question", "Testing Response 1"],
             expected_route=ChatRoute.search,
             expected_activity_events=assert_number_of_events(1),
         ),
@@ -242,7 +240,7 @@ def assert_number_of_events(num_of_events: int):
         RedboxTestData(
             number_of_docs=10,
             tokens_in_all_docs=2_000_000,
-            expected_llm_response=["These documents are too large to work with."],
+            llm_responses=["These documents are too large to work with."],
             expected_route=ErrorRoute.files_too_large,
         ),
     ],
@@ -260,13 +258,13 @@ def assert_number_of_events(num_of_events: int):
         RedboxTestData(
             number_of_docs=1,
             tokens_in_all_docs=10000,
-            expected_llm_response=["Condense response", "The cake is a lie"],
+            llm_responses=["Condense response", "The cake is a lie"],
             expected_route=ChatRoute.search,
         ),
         RedboxTestData(
             number_of_docs=5,
             tokens_in_all_docs=10000,
-            expected_llm_response=["Condense response", "The cake is a lie"],
+            llm_responses=["Condense response", "The cake is a lie"],
             expected_route=ChatRoute.search,
         ),
     ],
@@ -284,7 +282,7 @@ def assert_number_of_events(num_of_events: int):
         RedboxTestData(
             number_of_docs=1,
             tokens_in_all_docs=10000,
-            expected_llm_response=["Condense response", "The cake is a lie"],
+            llm_responses=["Condense response", "The cake is a lie"],
             expected_route=ChatRoute.search,
             s3_keys=["s3_key"],
         ),
@@ -303,7 +301,7 @@ def assert_number_of_events(num_of_events: int):
         RedboxTestData(
             number_of_docs=1,
             tokens_in_all_docs=10000,
-            expected_llm_response=[
+            llm_responses=[
                 AIMessage(
                     content="",
                     additional_kwargs={
@@ -317,14 +315,15 @@ def assert_number_of_events(num_of_events: int):
                     },
                 ),
                 "answer",
-                "AI is a lie",
+                LLM_Response(markdown_answer="AI is a lie", citations=[]).model_dump_json(),
             ],
+            expected_text="AI is a lie",
             expected_route=ChatRoute.gadget,
         ),
         RedboxTestData(
             number_of_docs=1,
             tokens_in_all_docs=10000,
-            expected_llm_response=[
+            llm_responses=[
                 AIMessage(
                     content="",
                     additional_kwargs={
@@ -357,7 +356,7 @@ def assert_number_of_events(num_of_events: int):
         RedboxTestData(
             number_of_docs=1,
             tokens_in_all_docs=10000,
-            expected_llm_response=[
+            llm_responses=[
                 AIMessage(
                     content="",
                     additional_kwargs={
@@ -371,8 +370,9 @@ def assert_number_of_events(num_of_events: int):
                     },
                 ),
                 "answer",
-                "AI is a lie",
+                LLM_Response(markdown_answer="AI is a lie", citations=[]).model_dump_json(),
             ],
+            expected_text="AI is a lie",
             expected_route=ChatRoute.gadget,
             s3_keys=["s3_key"],
         ),
@@ -391,7 +391,7 @@ def assert_number_of_events(num_of_events: int):
         RedboxTestData(
             number_of_docs=10,
             tokens_in_all_docs=1000,
-            expected_llm_response=["Testing Response 1"],
+            llm_responses=["Testing Response 1"],
             expected_route=ChatRoute.chat,
         ),
     ],
@@ -409,7 +409,7 @@ def assert_number_of_events(num_of_events: int):
         RedboxTestData(
             number_of_docs=1,
             tokens_in_all_docs=50_000,
-            expected_llm_response=["Testing Response 1"],
+            llm_responses=["Testing Response 1"],
             expected_route=ChatRoute.chat_with_docs,
         ),
     ],
@@ -427,7 +427,7 @@ async def test_streaming(test: RedboxChatTestCase, env: Settings, mocker: Mocker
     test_case = copy.deepcopy(test)
 
     # Mock the LLM and relevant tools
-    llm = GenericFakeChatModelWithTools(messages=iter(test_case.test_data.expected_llm_response))
+    llm = GenericFakeChatModelWithTools(messages=iter(test_case.test_data.llm_responses))
 
     @tool
     def _search_documents(query: str) -> dict[str, Any]:
@@ -490,8 +490,8 @@ async def documents_response_handler(documents: list[Document]):
     if not route_name.startswith("error"):
         assert len(token_events) > 1, f"Expected tokens as a stream. Received: {token_events}"
         assert len(metadata_events) == len(
-            test_case.test_data.expected_llm_response
-        ), f"Expected {len(test_case.test_data.expected_llm_response)} metadata events. Received {len(metadata_events)}"
+            test_case.test_data.llm_responses
+        ), f"Expected {len(test_case.test_data.llm_responses)} metadata events. Received {len(metadata_events)}"
 
     assert test_case.test_data.expected_activity_events(
         activity_events
@@ -510,7 +510,7 @@ async def documents_response_handler(documents: list[Document]):
     )
 
     assert (
-        final_state["text"] == llm_response
+        final_state["text"] == (test.test_data.expected_text if test.test_data.expected_text else llm_response)
    ), f"Expected LLM response: '{llm_response}'. Received '{final_state["text"]}'"
Received '{final_state["text"]}'" assert ( final_state.get("route_name") == test_case.test_data.expected_route diff --git a/redbox-core/tests/graph/test_patterns.py b/redbox-core/tests/graph/test_patterns.py index ba3859f8e..b8b025d3c 100644 --- a/redbox-core/tests/graph/test_patterns.py +++ b/redbox-core/tests/graph/test_patterns.py @@ -43,13 +43,13 @@ RedboxTestData( number_of_docs=0, tokens_in_all_docs=0, - expected_llm_response=["Testing Response 1"], + llm_responses=["Testing Response 1"], expected_route=ChatRoute.chat, ), RedboxTestData( number_of_docs=2, tokens_in_all_docs=40_000, - expected_llm_response=["Testing Response 1"], + llm_responses=["Testing Response 1"], expected_route=ChatRoute.chat_with_docs, ), ], @@ -60,7 +60,7 @@ @pytest.mark.parametrize(("test_case"), CHAT_PROMPT_TEST_CASES, ids=[t.test_id for t in CHAT_PROMPT_TEST_CASES]) def test_build_chat_prompt_from_messages_runnable(test_case: RedboxChatTestCase, tokeniser: Encoding): """Tests a given state can be turned into a chat prompt.""" - chat_prompt = build_chat_prompt_from_messages_runnable(PromptSet.Chat, tokeniser) + chat_prompt = build_chat_prompt_from_messages_runnable(PromptSet.Chat, tokeniser=tokeniser) state = RedboxState(request=test_case.query, documents=test_case.docs) response = chat_prompt.invoke(state) @@ -75,13 +75,13 @@ def test_build_chat_prompt_from_messages_runnable(test_case: RedboxChatTestCase, RedboxTestData( number_of_docs=2, tokens_in_all_docs=40_000, - expected_llm_response=[AIMessage(content="Testing Response 1")], + llm_responses=[AIMessage(content="Testing Response 1")], expected_route=ChatRoute.chat_with_docs, ), RedboxTestData( number_of_docs=2, tokens_in_all_docs=40_000, - expected_llm_response=[ + llm_responses=[ AIMessage( content="Tool Response 1", tool_calls=[ToolCall(name="foo", args={"query": "bar"}, id="tool_1")], @@ -97,14 +97,14 @@ def test_build_chat_prompt_from_messages_runnable(test_case: RedboxChatTestCase, @pytest.mark.parametrize(("test_case"), BUILD_LLM_TEST_CASES, ids=[t.test_id for t in BUILD_LLM_TEST_CASES]) def test_build_llm_chain(test_case: RedboxChatTestCase): """Tests a given state can update the data and metadata correctly.""" - llm = GenericFakeChatModel(messages=iter(test_case.test_data.expected_llm_response)) + llm = GenericFakeChatModel(messages=iter(test_case.test_data.llm_responses)) llm_chain = build_llm_chain(PromptSet.Chat, llm) state = RedboxState(request=test_case.query, documents=test_case.docs) final_state = llm_chain.invoke(state) - test_case_content = test_case.test_data.expected_llm_response[-1].content - test_case_tool_calls = tool_calls_to_toolstate(test_case.test_data.expected_llm_response[-1].tool_calls) + test_case_content = test_case.test_data.llm_responses[-1].content + test_case_tool_calls = tool_calls_to_toolstate(test_case.test_data.llm_responses[-1].tool_calls) assert ( final_state["text"] == test_case_content @@ -120,7 +120,7 @@ def test_build_llm_chain(test_case: RedboxChatTestCase): RedboxTestData( number_of_docs=0, tokens_in_all_docs=0, - expected_llm_response=["Testing Response 1"], + llm_responses=["Testing Response 1"], expected_route=ChatRoute.chat, ) ], @@ -131,7 +131,7 @@ def test_build_llm_chain(test_case: RedboxChatTestCase): @pytest.mark.parametrize(("test_case"), CHAT_TEST_CASES, ids=[t.test_id for t in CHAT_TEST_CASES]) def test_build_chat_pattern(test_case: RedboxChatTestCase, mocker: MockerFixture): """Tests a given state["request"] correctly changes state["text"].""" - llm = 
-    llm = GenericFakeChatModel(messages=iter(test_case.test_data.expected_llm_response))
+    llm = GenericFakeChatModel(messages=iter(test_case.test_data.llm_responses))
     state = RedboxState(request=test_case.query, documents=[])
 
     chat = build_chat_pattern(prompt_set=PromptSet.Chat, final_response_chain=True)
@@ -140,7 +140,7 @@ def test_build_chat_pattern(test_case: RedboxChatTestCase, mocker: MockerFixture
     response = chat(state)
     final_state = RedboxState(response)
 
-    test_case_content = test_case.test_data.expected_llm_response[-1].content
+    test_case_content = test_case.test_data.llm_responses[-1].content
 
     assert (
         final_state["text"] == test_case_content
@@ -153,13 +153,13 @@
         RedboxTestData(
             number_of_docs=0,
             tokens_in_all_docs=0,
-            expected_llm_response=["Testing Response 1"],
+            llm_responses=["Testing Response 1"],
             expected_route=ChatRoute.chat,
         ),
         RedboxTestData(
             number_of_docs=2,
             tokens_in_all_docs=40_000,
-            expected_llm_response=["Testing Response 1"],
+            llm_responses=["Testing Response 1"],
             expected_route=ChatRoute.chat_with_docs,
         ),
     ],
@@ -193,19 +193,19 @@ def test_build_set_route_pattern(test_case: RedboxChatTestCase):
         RedboxTestData(
             number_of_docs=2,
             tokens_in_all_docs=40_000,
-            expected_llm_response=["Testing Response 1"],
+            llm_responses=["Testing Response 1"],
             expected_route=ChatRoute.chat_with_docs,
         ),
         RedboxTestData(
             number_of_docs=2,
             tokens_in_all_docs=80_000,
-            expected_llm_response=["Testing Response 1"],
+            llm_responses=["Testing Response 1"],
             expected_route=ChatRoute.chat_with_docs,
         ),
         RedboxTestData(
             number_of_docs=4,
             tokens_in_all_docs=140_000,
-            expected_llm_response=["Map Step Response"] * 4 + ["Testing Response 1"],
+            llm_responses=["Map Step Response"] * 4 + ["Testing Response 1"],
             expected_route=ChatRoute.chat_with_docs,
         ),
     ],
@@ -244,13 +244,13 @@ def test_build_retrieve_pattern(test_case: RedboxChatTestCase, mock_retriever: B
         RedboxTestData(
             number_of_docs=2,
             tokens_in_all_docs=40_000,
-            expected_llm_response=["Testing Response 1"],
+            llm_responses=["Testing Response 1"],
             expected_route=ChatRoute.chat_with_docs,
         ),
         RedboxTestData(
             number_of_docs=4,
             tokens_in_all_docs=40_000,
-            expected_llm_response=["Testing Response 2"],
+            llm_responses=["Testing Response 2"],
             expected_route=ChatRoute.chat_with_docs,
         ),
     ],
@@ -261,7 +261,7 @@ def test_build_retrieve_pattern(test_case: RedboxChatTestCase, mock_retriever: B
 @pytest.mark.parametrize(("test_case"), MERGE_TEST_CASES, ids=[t.test_id for t in MERGE_TEST_CASES])
 def test_build_merge_pattern(test_case: RedboxChatTestCase, mocker: MockerFixture):
     """Tests a given state["request"] and state["documents"] correctly changes state["documents"]."""
-    llm = GenericFakeChatModel(messages=iter(test_case.test_data.expected_llm_response))
+    llm = GenericFakeChatModel(messages=iter(test_case.test_data.llm_responses))
     state = RedboxState(request=test_case.query, documents=structure_documents_by_file_name(test_case.docs))
 
     merge = build_merge_pattern(prompt_set=PromptSet.ChatwithDocsMapReduce, final_response_chain=True)
@@ -273,7 +273,7 @@ def test_build_merge_pattern(test_case: RedboxChatTestCase, mocker: MockerFixtur
     response_documents = [doc for doc in flatten_document_state(final_state.get("documents")) if doc is not None]
     noned_documents = sum(1 for doc in final_state.get("documents", {}).values() for v in doc.values() if v is None)
 
-    test_case_content = test_case.test_data.expected_llm_response[-1].content
+    test_case_content = test_case.test_data.llm_responses[-1].content
 
     assert len(response_documents) == 1
     assert noned_documents == len(test_case.docs) - 1
@@ -294,13 +294,13 @@ def test_build_merge_pattern(test_case: RedboxChatTestCase, mocker: MockerFixtur
         RedboxTestData(
             number_of_docs=2,
             tokens_in_all_docs=40_000,
-            expected_llm_response=["Testing Response 1"],
+            llm_responses=["Testing Response 1"],
             expected_route=ChatRoute.chat_with_docs,
         ),
         RedboxTestData(
             number_of_docs=4,
             tokens_in_all_docs=40_000,
-            expected_llm_response=["Testing Response 2"],
+            llm_responses=["Testing Response 2"],
             expected_route=ChatRoute.chat_with_docs,
         ),
     ],
@@ -311,7 +311,7 @@ def test_build_merge_pattern(test_case: RedboxChatTestCase, mocker: MockerFixtur
 @pytest.mark.parametrize(("test_case"), STUFF_TEST_CASES, ids=[t.test_id for t in STUFF_TEST_CASES])
 def test_build_stuff_pattern(test_case: RedboxChatTestCase, mocker: MockerFixture):
     """Tests a given state["request"] and state["documents"] correctly changes state["text"]."""
-    llm = GenericFakeChatModel(messages=iter(test_case.test_data.expected_llm_response))
+    llm = GenericFakeChatModel(messages=iter(test_case.test_data.llm_responses))
     state = RedboxState(request=test_case.query, documents=structure_documents_by_file_name(test_case.docs))
 
     stuff = build_stuff_pattern(prompt_set=PromptSet.ChatwithDocs, final_response_chain=True)
@@ -320,7 +320,7 @@ def test_build_stuff_pattern(test_case: RedboxChatTestCase, mocker: MockerFixtur
     response = stuff.invoke(state)
     final_state = RedboxState(response)
 
-    test_case_content = test_case.test_data.expected_llm_response[-1].content
+    test_case_content = test_case.test_data.llm_responses[-1].content
 
     assert (
         final_state["text"] == test_case_content
@@ -339,7 +339,7 @@ def test_build_stuff_pattern(test_case: RedboxChatTestCase, mocker: MockerFixtur
         RedboxTestData(
             number_of_docs=2,
             tokens_in_all_docs=2_000,
-            expected_llm_response=["Testing Response 1"],
+            llm_responses=["Testing Response 1"],
             expected_route=ChatRoute.chat_with_docs,
         ),
     ],
diff --git a/redbox-core/tests/graph/test_state.py b/redbox-core/tests/graph/test_state.py
index 790e7412b..7b6f55ea0 100644
--- a/redbox-core/tests/graph/test_state.py
+++ b/redbox-core/tests/graph/test_state.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timezone
+from datetime import UTC, datetime, timedelta, timezone
 from uuid import uuid4
 
 import pytest
@@ -134,31 +134,32 @@ def test_document_reducer(a: DocumentState, b: DocumentState, expected: Document
     assert result == expected, f"Expected: {expected}. Result: {result}"
 
 
+now = datetime.now(UTC)
 GPT_4o_multiple_calls_1 = [
-    LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=0, output_tokens=0),
-    LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=10, output_tokens=10),
-    LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=10, output_tokens=10),
+    LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=0, output_tokens=0, timestamp=now - timedelta(days=10)),
+    LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=10, output_tokens=10, timestamp=now - timedelta(days=9)),
+    LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=10, output_tokens=10, timestamp=now - timedelta(days=8)),
 ]
 
 GPT_4o_multiple_calls_1a = GPT_4o_multiple_calls_1 + [
-    LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=50, output_tokens=50),
-    LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=60, output_tokens=60),
+    LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=50, output_tokens=50, timestamp=now - timedelta(days=7)),
+    LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=60, output_tokens=60, timestamp=now - timedelta(days=6)),
 ]
 
 GPT_4o_multiple_calls_2 = [
-    LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=100, output_tokens=200),
-    LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=0, output_tokens=10),
-    LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=100, output_tokens=210),
+    LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=100, output_tokens=200, timestamp=now - timedelta(days=5)),
+    LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=0, output_tokens=10, timestamp=now - timedelta(days=4)),
+    LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=100, output_tokens=210, timestamp=now - timedelta(days=3)),
 ]
 
 multiple_models_multiple_calls_1 = [
-    LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=100, output_tokens=200),
-    LLMCallMetadata(llm_model_name="gpt-3.5", input_tokens=20, output_tokens=20),
-    LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=100, output_tokens=210),
+    LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=100, output_tokens=200, timestamp=now - timedelta(days=2)),
+    LLMCallMetadata(llm_model_name="gpt-3.5", input_tokens=20, output_tokens=20, timestamp=now - timedelta(days=1)),
+    LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=100, output_tokens=210, timestamp=now - timedelta(hours=10)),
 ]
 
 multiple_models_multiple_calls_1a = multiple_models_multiple_calls_1 + [
-    LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=300, output_tokens=310),
+    LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=300, output_tokens=310, timestamp=now - timedelta(hours=1)),
 ]
@@ -168,7 +169,9 @@ def test_document_reducer(a: DocumentState, b: DocumentState, expected: Document
         (
             RequestMetadata(llm_calls=GPT_4o_multiple_calls_1),
             RequestMetadata(llm_calls=GPT_4o_multiple_calls_2),
-            RequestMetadata(llm_calls=GPT_4o_multiple_calls_1 + GPT_4o_multiple_calls_2),
+            RequestMetadata(
+                llm_calls=sorted(GPT_4o_multiple_calls_1 + GPT_4o_multiple_calls_2, key=lambda c: c.timestamp)
+            ),
         ),
         (
             RequestMetadata(llm_calls=GPT_4o_multiple_calls_1),
@@ -178,7 +181,9 @@ def test_document_reducer(a: DocumentState, b: DocumentState, expected: Document
         (
             RequestMetadata(llm_calls=multiple_models_multiple_calls_1),
             RequestMetadata(llm_calls=GPT_4o_multiple_calls_2),
-            RequestMetadata(llm_calls=multiple_models_multiple_calls_1 + GPT_4o_multiple_calls_2),
+            RequestMetadata(
+                llm_calls=sorted(GPT_4o_multiple_calls_2 + multiple_models_multiple_calls_1, key=lambda c: c.timestamp)
+            ),
         ),
         (
             RequestMetadata(llm_calls=GPT_4o_multiple_calls_1),
diff --git a/redbox-core/tests/test_citations.py b/redbox-core/tests/test_citations.py
new file mode 100644
index 000000000..001cd6548
--- /dev/null
+++ b/redbox-core/tests/test_citations.py
@@ -0,0 +1,28 @@
+from uuid import uuid4
+
+import pytest
+
+from redbox.app import Redbox
+from redbox.models.chain import AISettings, RedboxQuery, RedboxState
+
+
+@pytest.mark.asyncio
+async def test_citation():
+    app = Redbox(debug=False)
+    q = RedboxQuery(
+        question="@gadget Who is Hello Kitty?",
+        s3_keys=[],
+        user_uuid=uuid4(),
+        chat_history=[],
+        ai_settings=AISettings(rag_k=3),
+        permitted_s3_keys=[],
+    )
+
+    x = RedboxState(
+        request=q,
+    )
+
+    response = await app.run(x)
+
+    assert len(response.get("text")) > 0
+    assert isinstance(response.get("citations"), list)
diff --git a/redbox-core/tests/test_tools.py b/redbox-core/tests/test_tools.py
index c13ed4733..abc0329c5 100644
--- a/redbox-core/tests/test_tools.py
+++ b/redbox-core/tests/test_tools.py
@@ -1,6 +1,6 @@
 from typing import Annotated, Any
-from uuid import UUID, uuid4
 from urllib.parse import urlparse
+from uuid import UUID, uuid4
 
 import pytest
 from elasticsearch import Elasticsearch
@@ -10,9 +10,9 @@
 from redbox.graph.nodes.tools import (
     build_search_documents_tool,
+    build_search_wikipedia_tool,
     has_injected_state,
     is_valid_tool,
-    build_search_wikipedia_tool,
 )
 from redbox.models import Settings
 from redbox.models.chain import AISettings, RedboxQuery, RedboxState
@@ -116,7 +116,8 @@ def test_search_documents_tool(
         {
             "query": stored_file_parameterised.query.question,
             "state": RedboxState(
-                request=stored_file_parameterised.query, text=stored_file_parameterised.query.question
+                request=stored_file_parameterised.query,
+                text=stored_file_parameterised.query.question,
             ),
         }
     )