Commit
Merge pull request #825 from i-dot-ai/feature/redbox-497
Feature/redbox 497
jamesrichards4 authored Jul 22, 2024
2 parents 142a54c + c85d3b5 commit 0a11f61
Showing 10 changed files with 252 additions and 597 deletions.
66 changes: 26 additions & 40 deletions core-api/core_api/routes/chat.py
@@ -1,20 +1,19 @@
 import logging
 import re
-from typing import Annotated, Literal
+from typing import Annotated
 from uuid import UUID
 
 from core_api.auth import get_user_uuid, get_ws_user_uuid
-from core_api.semantic_routes import get_routable_chains, get_semantic_route_layer
+from core_api.semantic_routes import get_routable_chains
 from fastapi import Depends, FastAPI, WebSocket
 from fastapi.encoders import jsonable_encoder
 from langchain_core.runnables import Runnable
+from langchain_core.tools import Tool
 from openai import APIError
-from pydantic import BaseModel
-from semantic_router import RouteLayer
 
 from redbox.api.runnables import map_to_chat_response
-from redbox.models.chain import ChainInput
-from redbox.models.chat import ChatRequest, ChatResponse, SourceDocument
+from redbox.models.chain import ChainInput, ChainChatMessage
+from redbox.models.chat import ChatRequest, ChatResponse, SourceDocument, ClientResponse, ErrorDetail
 from redbox.models.errors import NoDocumentSelected, QuestionLengthError
 from redbox.transform import map_document_to_source_document
 
@@ -32,10 +31,6 @@
     version="0.1.0",
     openapi_tags=[
         {"name": "chat", "description": "Chat interactions with LLM and RAG backend"},
-        {
-            "name": "embedding",
-            "description": "Embedding interactions with SentenceTransformer",
-        },
         {"name": "llm", "description": "LLM information and parameters"},
     ],
     docs_url="/docs",
@@ -44,8 +39,8 @@
 )
 
 
-async def semantic_router_to_chain(
-    chat_request: ChatRequest, user_uuid: UUID, routable_chains: dict[str, Runnable], route_layer: RouteLayer
+async def route_chat(
+    chat_request: ChatRequest, user_uuid: UUID, routable_chains: dict[str, Tool]
 ) -> tuple[Runnable, ChainInput]:
     question = chat_request.message_history[-1].text
 
@@ -59,20 +54,16 @@ def select_chat_chain(chat_request: ChatRequest, routable_chains: dict[str, Runnable]):
 
     # Match keyword
     route_match = re_keyword_pattern.search(question)
-    if route_match:
-        route_name = route_match.group()[1:]
-        selected_chain = routable_chains.get(route_name)
-
-    # Semantic route
-    if selected_chain is None:
-        route_name = route_layer(question).name
-        selected_chain = routable_chains.get(route_name, select_chat_chain(chat_request, routable_chains))
+    route_name = route_match.group()[1:] if route_match else None
+    selected_chain = routable_chains.get(route_name, select_chat_chain(chat_request, routable_chains))
 
     params = ChainInput(
         question=chat_request.message_history[-1].text,
-        file_uuids=[f.uuid for f in chat_request.selected_files],
-        user_uuid=user_uuid,
-        chat_history=chat_request.message_history[:-1],
+        file_uuids=[str(f.uuid) for f in chat_request.selected_files],
+        user_uuid=str(user_uuid),
+        chat_history=[
+            ChainChatMessage(role=message.role, text=message.text) for message in chat_request.message_history[:-1]
+        ],
     )
 
     log.info("Routed to %s", route_name)
@@ -85,30 +76,25 @@ def select_chat_chain(chat_request: ChatRequest, routable_chains: dict[str, Runnable]):
 async def rag_chat(
     chat_request: ChatRequest,
     user_uuid: Annotated[UUID, Depends(get_user_uuid)],
-    routable_chains: Annotated[dict[str, Runnable], Depends(get_routable_chains)],
-    route_layer: Annotated[RouteLayer, Depends(get_semantic_route_layer)],
+    routable_chains: Annotated[dict[str, Tool], Depends(get_routable_chains)],
 ) -> ChatResponse:
     """REST endpoint. Get a LLM response to a question history and file."""
-    selected_chain, params = await semantic_router_to_chain(chat_request, user_uuid, routable_chains, route_layer)
-    return (selected_chain | map_to_chat_response).invoke(params.model_dump())
-
+    selected_chain, params = await route_chat(chat_request, user_uuid, routable_chains)
+    return (selected_chain | map_to_chat_response).invoke(params.dict())
 
-class ErrorDetail(BaseModel):
-    code: str
-    message: str
-
-
-class ClientResponse(BaseModel):
-    # Needs to match CoreChatResponse in django_app/redbox_app/redbox_core/consumers.py
-    resource_type: Literal["text", "documents", "route_name", "end", "error"]
-    data: list[SourceDocument] | str | ErrorDetail | None = None
+@chat_app.get("/tools", tags=["chat"])
+async def available_tools(
+    routable_chains: Annotated[dict[str, Tool], Depends(get_routable_chains)],
+):
+    """REST endpoint. Get a mapping of all tools available via chat."""
+    return [{"name": chat_tool.name, "description": chat_tool.description} for chat_tool in routable_chains.values()]
 
 
 @chat_app.websocket("/rag")
 async def rag_chat_streamed(
     websocket: WebSocket,
-    routable_chains: Annotated[dict[str, Runnable], Depends(get_routable_chains)],
-    route_layer: Annotated[RouteLayer, Depends(get_semantic_route_layer)],
+    routable_chains: Annotated[dict[str, Tool], Depends(get_routable_chains)],
 ):
     """Websocket. Get a LLM response to a question history and file."""
     await websocket.accept()
@@ -118,10 +104,10 @@ async def rag_chat_streamed(
         request = await websocket.receive_text()
         chat_request = ChatRequest.model_validate_json(request)
 
-        selected_chain, params = await semantic_router_to_chain(chat_request, user_uuid, routable_chains, route_layer)
+        selected_chain, params = await route_chat(chat_request, user_uuid, routable_chains)
 
         try:
-            async for event in selected_chain.astream(params.model_dump()):
+            async for event in selected_chain.astream(params.dict()):
                 response: str = event.get("response", "")
                 source_documents: list[SourceDocument] = [
                     map_document_to_source_document(doc) for doc in event.get("source_documents", [])
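The change in route_chat above collapses the old keyword-then-semantic two-step into a single dictionary lookup with a chat fallback. A minimal, self-contained sketch of that behaviour follows; the @keyword regex is an assumption here, since re_keyword_pattern is defined elsewhere in chat.py:

import re

# Assumed shape of the keyword pattern; the real re_keyword_pattern lives elsewhere in chat.py.
re_keyword_pattern = re.compile(r"@(\w+)")

def pick_route(question: str, routable_chains: dict, default_chain):
    route_match = re_keyword_pattern.search(question)
    # group() includes the leading "@", so strip it; a missing keyword falls through to the default.
    route_name = route_match.group()[1:] if route_match else None
    return route_name, routable_chains.get(route_name, default_chain)

print(pick_route("@search find the policy", {"search": "search-chain"}, "chat-chain"))  # ('search', 'search-chain')
print(pick_route("hello there", {"search": "search-chain"}, "chat-chain"))  # (None, 'chat-chain')

One side effect of this shape, mirrored in the sketch: dict.get evaluates its fallback argument eagerly, so select_chat_chain(...) runs on every request even when a keyword route matches.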
146 changes: 32 additions & 114 deletions core-api/core_api/semantic_routes.py
@@ -5,142 +5,60 @@
     build_chat_with_docs_chain,
     build_condense_retrieval_chain,
     build_static_response_chain,
-    build_summary_chain,
 )
-from core_api.dependencies import get_env
 from fastapi import Depends
 from langchain_core.runnables import Runnable
-from semantic_router import Route
-from semantic_router.encoders import AzureOpenAIEncoder, BaseEncoder, OpenAIEncoder
-from semantic_router.layer import RouteLayer
+from langchain_community.tools import Tool
 
 from redbox.models import Settings
 from redbox.models.chat import ChatRoute
+from redbox.models.chain import ChainInput
 
 # === Pre-canned responses for non-LLM routes ===
 INFO_RESPONSE = """
 I am Redbox, an AI focused on helping UK Civil Servants, Political Advisors and
 Ministers triage and summarise information from a wide variety of sources.
 """
 
-ABILITY_RESPONSE = """
-* I can help you search over selected documents and do Q&A on them.
-* I can help you summarise selected documents.
-* I can help you extract information from selected documents.
-* I can return information in a variety of formats, such as bullet points.
-"""
-
-COACH_RESPONSE = """
-I am sorry that didn't work.
-You could try rephrasing your task, i.e if you want to summarise a document please use the term,
-"Summarise the selected document" or "extract all action items from the selected document."
-If you want the results to be returned in a specific format, please specify the format in as much detail as possible.
-"""
-
-# === Set up the semantic router ===
-info = Route(
-    name=ChatRoute.info.value,
-    utterances=[
-        "What is your name?",
-        "Who are you?",
-        "What is Redbox?",
-    ],
-)
-
-ability = Route(
-    name=ChatRoute.ability.value,
-    utterances=[
-        "What can you do?",
-        "What can you do?",
-        "How can you help me?",
-        "What does Redbox do?",
-        "What can Redbox do",
-        "What don't you do",
-        "Please help me",
-        "Please help",
-        "Help me!",
-        "help",
-    ],
-)
+def as_chat_tool(
+    name: str,
+    runnable: Runnable,
+    description: str,
+):
+    return runnable.as_tool(name=name, description=description, args_schema=ChainInput)
 
-coach = Route(
-    name=ChatRoute.coach.value,
-    utterances=[
-        "That is not the answer I wanted",
-        "Rubbish",
-        "No good",
-        "That's not what I wanted",
-        "How can I improve the results?",
-    ],
-)
-
-gratitude = Route(
-    name=ChatRoute.gratitude.value,
-    utterances=[
-        "Thank you",
-        "Thank you ever so much for your help!",
-        "I'm really grateful for your assistance.",
-        "Cheers for the detailed response!",
-        "Thanks a lot, that was very informative.",
-        "Nice one",
-        "Thanks!",
-    ],
-)
 
 __routable_chains = None
-__semantic_route_layer = None
-
-
-def get_semantic_routes():
-    return (info, ability, coach, gratitude)
-
-
-def get_semantic_routing_encoder(env: Annotated[Settings, Depends(get_env)]):
-    """
-    TODO: This is a duplication of the logic for getting the LangChain embedding model used elsewhere
-    We should replace semanticrouter with our own implementation to avoid this
-    """
-    if env.embedding_backend == "azure":
-        return AzureOpenAIEncoder(
-            azure_endpoint=env.azure_openai_endpoint, api_version="2023-05-15", model=env.azure_embedding_model
-        )
-    elif env.embedding_backend == "openai":
-        return OpenAIEncoder(
-            openai_base_url=env.embedding_openai_base_url,
-            openai_api_key=env.openai_api_key,
-            name=env.embedding_openai_model,
-        )
-
-
-def get_semantic_route_layer(
-    routes: Annotated[list[Route], Depends(get_semantic_routes)],
-    encoder: Annotated[BaseEncoder, Depends(get_semantic_routing_encoder)],
-):
-    """
-    Manual singleton creation as lru_cache can't handle the semantic router classes (non hashable)
-    """
-    global __semantic_route_layer  # noqa: PLW0603
-    if not __semantic_route_layer:
-        __semantic_route_layer = RouteLayer(encoder=encoder, routes=routes)
-    return __semantic_route_layer
 
 
 def get_routable_chains(
-    summary_chain: Annotated[Runnable, Depends(build_summary_chain)],
     condense_chain: Annotated[Runnable, Depends(build_condense_retrieval_chain)],
     chat_chain: Annotated[Runnable, Depends(build_chat_chain)],
     chat_with_docs_chain: Annotated[Runnable, Depends(build_chat_with_docs_chain)],
-):
+) -> dict[str, Tool]:
     global __routable_chains  # noqa: PLW0603
     if not __routable_chains:
-        __routable_chains = {
-            ChatRoute.info: build_static_response_chain(INFO_RESPONSE, ChatRoute.info),
-            ChatRoute.ability: build_static_response_chain(ABILITY_RESPONSE, ChatRoute.ability),
-            ChatRoute.coach: build_static_response_chain(COACH_RESPONSE, ChatRoute.coach),
-            ChatRoute.gratitude: build_static_response_chain("You're welcome!", ChatRoute.gratitude),
-            ChatRoute.chat: chat_chain,
-            ChatRoute.chat_with_docs: chat_with_docs_chain,
-            ChatRoute.search: condense_chain,
-            ChatRoute.summarise: summary_chain,
-        }
+        chat_tools = (
+            as_chat_tool(
+                name=ChatRoute.info,
+                runnable=build_static_response_chain(INFO_RESPONSE, ChatRoute.info),
+                description="Give helpful information about Redbox",
+            ),
+            as_chat_tool(
+                name=ChatRoute.chat,
+                runnable=chat_chain,
+                description="Answer questions as a helpful assistant",
+            ),
+            as_chat_tool(
+                name=ChatRoute.chat_with_docs,
+                runnable=chat_with_docs_chain,
+                description="Answer questions as a helpful assistant using the documents provided",
+            ),
+            as_chat_tool(
+                name=ChatRoute.search,
+                runnable=condense_chain,
+                description="Search for an answer to a question in provided documents",
+            ),
+        )
+        __routable_chains = {chat_tool.name: chat_tool for chat_tool in chat_tools}
    return __routable_chains
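With the RouteLayer, encoder, and utterance Routes gone, every route in get_routable_chains is now a LangChain Tool built through Runnable.as_tool, and the name/description pairs are exactly what the new /tools endpoint reports. A rough, self-contained sketch of that conversion, assuming a langchain-core recent enough to provide Runnable.as_tool; EchoInput and the echo runnable are invented stand-ins for the real ChainInput and the built chains:

from langchain_core.runnables import RunnableLambda
from pydantic import BaseModel

class EchoInput(BaseModel):
    # Invented stand-in for the real ChainInput schema.
    question: str

echo = RunnableLambda(lambda x: f"You asked: {x['question']}")
echo_tool = echo.as_tool(name="echo", description="Echo the question back", args_schema=EchoInput)

# The tool keeps its name and description, which is all /tools needs to list it.
print(echo_tool.name, "-", echo_tool.description)
# Invocation validates the input against the schema, then calls the wrapped runnable.
print(echo_tool.invoke({"question": "What is Redbox?"}))

Keying the registry by chat_tool.name keeps the dict used for keyword lookup in route_chat aligned with what /tools lists.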
