Skip to content

Commit

Permalink
add services.py
Browse files Browse the repository at this point in the history
  • Loading branch information
rachaelcodes committed Sep 10, 2024
1 parent bc7dd8a commit f1b0bc8
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 14 deletions.
13 changes: 1 addition & 12 deletions django_app/redbox_app/redbox_core/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,24 +28,13 @@
File,
User,
)
from redbox_app.redbox_core.utils import parse_page_number

OptFileSeq = Sequence[File] | None
logger = logging.getLogger(__name__)
logger.info("WEBSOCKET_SCHEME is: %s", settings.WEBSOCKET_SCHEME)


def parse_page_number(obj: int | list[int] | None) -> list[int]:
if isinstance(obj, int):
return [obj]
if isinstance(obj, list) and all(isinstance(item, int) for item in obj):
return obj
if obj is None:
return []

msg = "expected, int | list[int] | None got %s"
raise ValueError(msg, type(obj))


class ChatConsumer(AsyncWebsocketConsumer):
full_reply: ClassVar = []
citations: ClassVar = []
Expand Down
159 changes: 159 additions & 0 deletions django_app/redbox_app/redbox_core/services.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from collections.abc import Mapping, Sequence
from typing import ClassVar

from django.forms.models import model_to_dict
from django.utils import timezone
from langchain_core.documents import Document

from redbox import Redbox
from redbox.models import Settings
from redbox.models.chain import ChainChatMessage, RedboxQuery, RedboxState
from redbox.models.chat import MetadataDetail
from redbox_app.redbox_core.models import (
AISettings,
Chat,
ChatMessage,
ChatMessageTokenUse,
ChatRoleEnum,
Citation,
File,
User,
)
from redbox_app.redbox_core.utils import parse_page_number


def retrieve_llm_response(selected_files: Sequence[File], session: Chat, user: User):
    """Synchronous task entry point (enqueued via django-q's ``async_task``).

    Builds a ``RedboxDjangoInterface`` and drives its coroutine to completion.
    The previous code called ``post_llm_request(...)`` without awaiting it, so
    the task reported success while the coroutine never ran ("coroutine
    'RedboxDjangoInterface.post_llm_request' was never awaited") and no chat
    message was saved. ``asyncio.run`` runs it to completion from this sync
    worker context.

    Args:
        selected_files: files the user selected for the request.
        session: the Chat the response belongs to.
        user: the requesting User.
    """
    import asyncio  # local import: only needed by this sync->async bridge

    interface = RedboxDjangoInterface()
    asyncio.run(interface.post_llm_request(selected_files, session, user))


class RedboxDjangoInterface:
# TODO: see if you can move some of ChatConsumer here to be DRY
full_reply: ClassVar = []
citations: ClassVar = []
route = None
metadata: MetadataDetail = MetadataDetail()
redbox = Redbox(env=Settings(), debug=True)

async def post_llm_request(self, selected_files: Sequence[File], session: Chat, user: User):
session_messages = ChatMessage.objects.filter(chat=session).order_by("created_at")
message_history: Sequence[Mapping[str, str]] = [message async for message in session_messages]

ai_settings = await self.get_ai_settings(session)

state = RedboxState(
request=RedboxQuery(
question=message_history[-1].text,
s3_keys=[f.unique_name for f in selected_files],
user_uuid=user.id,
chat_history=[
ChainChatMessage(role=message.role, text=message.text) for message in message_history[:-1]
],
ai_settings=ai_settings,
),
)

await self.redbox.run(
state,
response_tokens_callback=self.handle_text,
route_name_callback=self.handle_route,
documents_callback=self.handle_documents,
metadata_tokens_callback=self.handle_metadata,
)

await self.save_message(
session,
"".join(self.full_reply),
ChatRoleEnum.ai,
)

session.awaiting_llm_response = False
session.save()

@staticmethod
def get_ai_settings(chat: Chat) -> AISettings:
ai_settings = model_to_dict(chat.user.ai_settings, exclude=["label"])

match str(chat.chat_backend):
case "claude-3-sonnet":
chat_backend = "anthropic.claude-3-sonnet-20240229-v1:0"
case "claude-3-haiku":
chat_backend = "anthropic.claude-3-haiku-20240307-v1:0"
case _:
chat_backend = str(chat.chat_backend)

ai_settings["chat_backend"] = chat_backend
return AISettings.parse_obj(ai_settings)

@staticmethod
def save_message(
session: Chat,
user_message_text: str,
role: ChatRoleEnum,
sources: Sequence[tuple[File, Document]] | None = None,
selected_files: Sequence[File] | None = None,
metadata: MetadataDetail | None = None,
route: str | None = None,
) -> ChatMessage:
chat_message = ChatMessage(chat=session, text=user_message_text, role=role, route=route)
chat_message.save()
if sources:
for file, citations in sources:
file.last_referenced = timezone.now()
file.save()

for citation in citations:
Citation.objects.create(
chat_message=chat_message,
file=file,
text=citation.page_content,
page_numbers=parse_page_number(citation.metadata.get("page_number")),
)
if selected_files:
chat_message.selected_files.set(selected_files)

if metadata and metadata.input_tokens:
for model, token_count in metadata.input_tokens.items():
ChatMessageTokenUse.objects.create(
chat_message=chat_message,
use_type=ChatMessageTokenUse.UseTypeEnum.INPUT,
model_name=model,
token_count=token_count,
)
if metadata and metadata.output_tokens:
for model, token_count in metadata.output_tokens.items():
ChatMessageTokenUse.objects.create(
chat_message=chat_message,
use_type=ChatMessageTokenUse.UseTypeEnum.OUTPUT,
model_name=model,
token_count=token_count,
)
return chat_message

def handle_text(self, response: str) -> None:
self.full_reply.append(response)

def handle_route(self, response: str) -> None:
self.route = response

def handle_metadata(self, response: dict):
metadata_detail = MetadataDetail.parse_obj(response)
for model, token_count in metadata_detail.input_tokens.items():
self.metadata.input_tokens[model] = self.metadata.input_tokens.get(model, 0) + token_count
for model, token_count in metadata_detail.output_tokens.items():
self.metadata.output_tokens[model] = self.metadata.output_tokens.get(model, 0) + token_count

def handle_documents(self, response: list[Document]):
s3_keys = [doc.metadata["file_name"] for doc in response]
files = File.objects.filter(original_file__in=s3_keys)

for file in files:
self.citations.append(
(
file,
[doc for doc in response if doc.metadata["file_name"] == file.unique_name],
)
)
12 changes: 12 additions & 0 deletions django_app/redbox_app/redbox_core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,15 @@ def get_date_group(on: date) -> str:
if age > 0:
return "Yesterday"
return "Today"


def parse_page_number(obj: int | list[int] | None) -> list[int]:
if isinstance(obj, int):
return [obj]
if isinstance(obj, list) and all(isinstance(item, int) for item in obj):
return obj
if obj is None:
return []

msg = "expected, int | list[int] | None got %s"
raise ValueError(msg, type(obj))
17 changes: 15 additions & 2 deletions django_app/redbox_app/redbox_core/views/chat_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
from django.utils.decorators import method_decorator
from django.views import View
from django.views.decorators.csrf import csrf_protect
from django_q.tasks import async_task
from yarl import URL

from redbox_app.redbox_core.models import AbstractAISettings, Chat, ChatMessage, ChatRoleEnum, File
from redbox_app.redbox_core.services import retrieve_llm_response

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -81,13 +83,15 @@ def post(self, request: HttpRequest, chat_id: uuid.UUID | None = None) -> HttpRe
form_data = getattr(
request, "_post", None
) # request.POST doesn't work here; the channels docs are unclear about the ASGI response format
# TODO: use .loads and schema 👻 to get the form data
# TODO: use .loads and schema to get the form data
if not form_data:
return redirect(reverse("chats"))
user_message = form_data.get("message")

# Get or create Chat session
if session_id := form_data.get("session-id"):
if session_id != chat_id:
return redirect(reverse("chats"))
session = get_object_or_404(Chat, id=session_id)
if session.user != user:
return redirect(reverse("chats"))
Expand All @@ -105,12 +109,21 @@ def post(self, request: HttpRequest, chat_id: uuid.UUID | None = None) -> HttpRe
chat_message.selected_files.set(selected_files)

# Enqueue request to LLM
# (including existing messages)
session.awaiting_llm_response = True
session.save()

async_task(
retrieve_llm_response, selected_files, session, user, task_name=session.name, group="chat_request"
)

# Redirect to chat view
# This is being slowed down to allow time for the LLM response
# while making sure the user request doesn't time out

# TODO: check every 0.5 second for up to 10 seconds;
# if session.awaiting_llm_response has changed to False, then redirect
# after 10 seconds, redirect anyway

return redirect("chats", chat_id=session.id)

@staticmethod
Expand Down

0 comments on commit f1b0bc8

Please sign in to comment.