diff --git a/django_app/redbox_app/redbox_core/consumers.py b/django_app/redbox_app/redbox_core/consumers.py index c13b53de9..dc4f1e045 100644 --- a/django_app/redbox_app/redbox_core/consumers.py +++ b/django_app/redbox_app/redbox_core/consumers.py @@ -64,12 +64,18 @@ async def receive(self, text_data=None, bytes_data=None): logger.debug("received %s from browser", data) user_message_text: str = data.get("message", "") selected_file_uuids: Sequence[UUID] = [UUID(u) for u in data.get("selectedFiles", [])] - user: User = self.scope.get("user", None) + user: User = self.scope.get("user") + chat_backend = self.scope.get("llm") if session_id := data.get("sessionId"): session = await Chat.objects.aget(id=session_id) + if chat_backend and session.chat_backend != chat_backend: + session.chat_backend = chat_backend + await session.asave() else: - session = await Chat.objects.acreate(name=user_message_text[: settings.CHAT_TITLE_LENGTH], user=user) + session = await Chat.objects.acreate( + name=user_message_text[: settings.CHAT_TITLE_LENGTH], user=user, chat_backend=chat_backend + ) # save user message selected_files = File.objects.filter(id__in=selected_file_uuids, user=user) diff --git a/django_app/redbox_app/redbox_core/views/chat_views.py b/django_app/redbox_app/redbox_core/views/chat_views.py index 64e2eb781..1e0eff57e 100644 --- a/django_app/redbox_app/redbox_core/views/chat_views.py +++ b/django_app/redbox_app/redbox_core/views/chat_views.py @@ -16,7 +16,7 @@ from django.views import View from yarl import URL -from redbox_app.redbox_core.models import Chat, ChatMessage, ChatRoleEnum, File +from redbox_app.redbox_core.models import AbstractAISettings, Chat, ChatMessage, ChatRoleEnum, File logger = logging.getLogger(__name__) @@ -40,6 +40,8 @@ def get(self, request: HttpRequest, chat_id: uuid.UUID | None = None) -> HttpRes self.decorate_selected_files(completed_files, messages) chat_grouped_by_date_group = groupby(chat, attrgetter("date_group")) + chat_backend = current_chat.chat_backend if current_chat else None + context = { "chat_id": chat_id, "messages": messages, @@ -50,6 +52,10 @@ def get(self, request: HttpRequest, chat_id: uuid.UUID | None = None) -> HttpRes "completed_files": completed_files, "processing_files": processing_files, "chat_title_length": settings.CHAT_TITLE_LENGTH, + "llm_options": [ + {"name": llm, "default": llm == chat_backend, "selected": llm == chat_backend} + for _, llm in AbstractAISettings.ChatBackend.choices + ], } return render( diff --git a/django_app/redbox_app/templates/chats.html b/django_app/redbox_app/templates/chats.html index 3c9f0ee7a..ed932cc2a 100644 --- a/django_app/redbox_app/templates/chats.html +++ b/django_app/redbox_app/templates/chats.html @@ -65,12 +65,6 @@

Documents to use

- {# Temporary until this comes from the view data #} - {% set llm_options = [ - {"name": "gpt-4o", "default": True, "selected": True}, - {"name": "gpt-4-turbo-2024-04-09"}, - {"name": "gpt-35-turbo-16k"} - ] %}