Feature/add settings to chat model #1027

Merged: 12 commits, Sep 9, 2024
django_app/redbox_app/redbox_core/admin.py (16 changes: 15 additions & 1 deletion)
@@ -143,7 +143,21 @@ def export_as_csv(self, request, queryset: QuerySet): # noqa:ARG002
return response

export_as_csv.short_description = "Export Selected"
fields = ["name", "user"]
fieldsets = [
    (
        None,
        {
            "fields": ["name", "user"],
        },
    ),
    (
        "AI Settings",
        {
            "classes": ["collapse"],
            "fields": ["chat_backend", "temperature"],
        },
    ),
]
inlines = [ChatMessageInline]
list_display = ["name", "user", "created_at"]
list_filter = ["user"]
django_app/redbox_app/redbox_core/consumers.py (20 changes: 15 additions & 5 deletions)
@@ -64,12 +64,20 @@ async def receive(self, text_data=None, bytes_data=None):
logger.debug("received %s from browser", data)
user_message_text: str = data.get("message", "")
selected_file_uuids: Sequence[UUID] = [UUID(u) for u in data.get("selectedFiles", [])]
user: User = self.scope.get("user", None)
user: User = self.scope.get("user")
chat_backend = data.get("llm")
temperature = data.get("temperature")
Collaborator author comment: this will default to None, i.e. the frontend does not need to set this.
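A minimal sketch of the behaviour behind this comment (illustrative payload, not the exact frontend message): dict.get returns None for a missing key, so a browser message that omits temperature leaves it unset, and the fallback added to Chat.save later in this diff fills it from the user's AISettings.

# Illustrative only: a payload without a "temperature" key yields None via dict.get.
data = {"message": "hello", "selectedFiles": [], "llm": "gpt-4o"}

temperature = data.get("temperature")
assert temperature is None  # stored as-is; Chat.save() backfills it from user.ai_settings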


if session_id := data.get("sessionId"):
    session = await Chat.objects.aget(id=session_id)
    logger.info("updating: chat_backend=%s -> ai_settings=%s", session.chat_backend, chat_backend)
    session.chat_backend = chat_backend
    session.temperature = temperature
    await session.asave()
else:
    session = await Chat.objects.acreate(name=user_message_text[: settings.CHAT_TITLE_LENGTH], user=user)
    session = await Chat.objects.acreate(
        name=user_message_text[: settings.CHAT_TITLE_LENGTH], user=user, chat_backend=chat_backend
    )

# save user message
selected_files = File.objects.filter(id__in=selected_file_uuids, user=user)
@@ -85,7 +93,7 @@ async def llm_conversation(self, selected_files: Sequence[File], session: Chat,
session_messages = ChatMessage.objects.filter(chat=session).order_by("created_at")
message_history: Sequence[Mapping[str, str]] = [message async for message in session_messages]

ai_settings = await self.get_ai_settings(user)
ai_settings = await self.get_ai_settings(session)
state = RedboxState(
    request=RedboxQuery(
        question=message_history[-1].text,
@@ -182,8 +190,10 @@ def save_message(

@staticmethod
@database_sync_to_async
def get_ai_settings(user: User) -> dict:
    return model_to_dict(user.ai_settings, exclude=["label"])
def get_ai_settings(chat: Chat) -> dict:
    ai_settings = model_to_dict(chat.user.ai_settings, exclude=["label"])
    ai_settings["chat_backend"] = chat.chat_backend
    return ai_settings

async def handle_text(self, response: str) -> str:
await self.send_to_client("text", response)
@@ -0,0 +1,63 @@
# Generated by Django 5.1.1 on 2024-09-07 11:53

from django.db import migrations, models



def back_populate_ai_settings_on_chat(apps, schema_editor):
    Chat = apps.get_model("redbox_core", "Chat")
    for chat in Chat.objects.all():
        chat.chat_backend = chat.user.ai_settings.chat_backend
        chat.temperature = 0
        chat.save()

    AISettings = apps.get_model("redbox_core", "AISettings")
    for ai_settings in AISettings.objects.all():
        ai_settings.temperature = 0
        ai_settings.save()


class Migration(migrations.Migration):

    dependencies = [
        ('redbox_core', '0041_alter_aisettings_chat_backend'),
    ]

    operations = [
        migrations.AddField(
            model_name='chat',
            name='chat_backend',
            field=models.CharField(blank=True, choices=[('gpt-35-turbo-16k', 'gpt-35-turbo-16k'), ('gpt-4-turbo-2024-04-09', 'gpt-4-turbo-2024-04-09'), ('gpt-4o', 'gpt-4o'), ('anthropic.claude-3-sonnet-20240229-v1:0', 'claude-3-sonnet'), ('anthropic.claude-3-haiku-20240307-v1:0', 'claude-3-haiku')], default='gpt-4o', help_text='LLM to use in chat', max_length=64, null=True),
        ),
        migrations.AddField(
            model_name='aisettings',
            name='temperature',
            field=models.FloatField(blank=True, default=0, help_text='temperature for LLM', null=True),
        ),
        migrations.AddField(
            model_name='chat',
            name='temperature',
            field=models.FloatField(blank=True, default=0, help_text='temperature for LLM', null=True),
        ),
        migrations.RunPython(back_populate_ai_settings_on_chat, migrations.RunPython.noop),
        migrations.AlterField(
            model_name='chat',
            name='chat_backend',
            field=models.CharField(
                choices=[('gpt-35-turbo-16k', 'gpt-35-turbo-16k'), ('gpt-4-turbo-2024-04-09', 'gpt-4-turbo-2024-04-09'),
                         ('gpt-4o', 'gpt-4o'), ('anthropic.claude-3-sonnet-20240229-v1:0', 'claude-3-sonnet'),
                         ('anthropic.claude-3-haiku-20240307-v1:0', 'claude-3-haiku')], default='gpt-4o',
                help_text='LLM to use in chat', max_length=64),
        ),
        migrations.AlterField(
            model_name='aisettings',
            name='temperature',
            field=models.FloatField(default=0, help_text='temperature for LLM'),
        ),
        migrations.AlterField(
            model_name='chat',
            name='temperature',
            field=models.FloatField(default=0, help_text='temperature for LLM'),
        ),
    ]
django_app/redbox_app/redbox_core/models.py (24 changes: 19 additions & 5 deletions)
@@ -52,14 +52,24 @@ def sanitise_string(string: str | None) -> str | None:
return string.replace("\x00", "\ufffd") if string else string


class AISettings(UUIDPrimaryKeyBase, TimeStampedModel):
class AbstractAISettings(models.Model):
    class ChatBackend(models.TextChoices):
        GPT_35_TURBO = "gpt-35-turbo-16k", _("gpt-35-turbo-16k")
        GPT_4_TURBO = "gpt-4-turbo-2024-04-09", _("gpt-4-turbo-2024-04-09")
        GPT_4_OMNI = "gpt-4o", _("gpt-4o")
        CLAUDE_3_SONNET = "anthropic.claude-3-sonnet-20240229-v1:0", _("claude-3-sonnet")
        CLAUDE_3_HAIKU = "anthropic.claude-3-haiku-20240307-v1:0", _("claude-3-haiku")

    chat_backend = models.CharField(
        max_length=64, choices=ChatBackend, help_text="LLM to use in chat", default=ChatBackend.GPT_4_OMNI
    )
    temperature = models.FloatField(default=0, help_text="temperature for LLM")

    class Meta:
        abstract = True


class AISettings(UUIDPrimaryKeyBase, TimeStampedModel, AbstractAISettings):
label = models.CharField(max_length=50, unique=True)
max_document_tokens = models.PositiveIntegerField(default=1_000_000, null=True, blank=True)
context_window_size = models.PositiveIntegerField(default=128_000)
@@ -85,9 +95,6 @@ class ChatBackend(models.TextChoices):
match_boost = models.PositiveIntegerField(default=1)
knn_boost = models.PositiveIntegerField(default=1)
similarity_threshold = models.PositiveIntegerField(default=0)
chat_backend = models.CharField(
max_length=64, choices=ChatBackend, help_text="LLM to use in chat", default=ChatBackend.GPT_4_OMNI
)

def __str__(self) -> str:
return str(self.label)
@@ -404,7 +411,7 @@ def get_ordered_by_citation_priority(cls, chat_message_id: uuid.UUID) -> Sequenc
)


class Chat(UUIDPrimaryKeyBase, TimeStampedModel):
class Chat(UUIDPrimaryKeyBase, TimeStampedModel, AbstractAISettings):
name = models.TextField(max_length=1024, null=False, blank=False)
user = models.ForeignKey(User, on_delete=models.CASCADE)

@@ -414,6 +421,13 @@ def __str__(self) -> str: # pragma: no cover
@override
def save(self, force_insert=False, force_update=False, using=None, update_fields=None):
    self.name = sanitise_string(self.name)

    if self.chat_backend is None:
        self.chat_backend = self.user.ai_settings.chat_backend

    if self.temperature is None:
        self.temperature = self.user.ai_settings.temperature

    super().save(force_insert, force_update, using, update_fields)

@classmethod
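A hedged usage sketch of the fallback added to Chat.save above, mirroring the consumer passing None when the browser omits the llm and temperature fields; user.ai_settings is assumed to exist as it does in consumers.py, and the chat name is hypothetical.

# Sketch, not a test from this PR: explicit None triggers the fallback in save().
# user: an existing User with an ai_settings relation, as used in consumers.py.
chat = Chat(name="budget questions", user=user, chat_backend=None, temperature=None)
chat.save()

assert chat.chat_backend == user.ai_settings.chat_backend
assert chat.temperature == user.ai_settings.temperature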
django_app/redbox_app/redbox_core/views/chat_views.py (8 changes: 7 additions & 1 deletion)
@@ -16,7 +16,7 @@
from django.views import View
from yarl import URL

from redbox_app.redbox_core.models import Chat, ChatMessage, ChatRoleEnum, File
from redbox_app.redbox_core.models import AbstractAISettings, Chat, ChatMessage, ChatRoleEnum, File

logger = logging.getLogger(__name__)

@@ -40,6 +40,8 @@ def get(self, request: HttpRequest, chat_id: uuid.UUID | None = None) -> HttpRes
self.decorate_selected_files(completed_files, messages)
chat_grouped_by_date_group = groupby(chat, attrgetter("date_group"))

chat_backend = current_chat.chat_backend if current_chat else None

context = {
    "chat_id": chat_id,
    "messages": messages,
@@ -50,6 +52,10 @@ def get(self, request: HttpRequest, chat_id: uuid.UUID | None = None) -> HttpRes
"completed_files": completed_files,
"processing_files": processing_files,
"chat_title_length": settings.CHAT_TITLE_LENGTH,
"llm_options": [
{"name": llm, "default": llm == chat_backend, "selected": llm == chat_backend}
for _, llm in AbstractAISettings.ChatBackend.choices
],
}

return render(
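Worked example (illustrative, not from the PR): for a chat whose chat_backend is gpt-4o, the llm_options comprehension above produces a list shaped like the hard-coded one being removed from chats.html below; the names are the ChatBackend labels defined in models.py.

# Assumed output of the llm_options comprehension when chat_backend == "gpt-4o".
llm_options = [
    {"name": "gpt-35-turbo-16k", "default": False, "selected": False},
    {"name": "gpt-4-turbo-2024-04-09", "default": False, "selected": False},
    {"name": "gpt-4o", "default": True, "selected": True},
    {"name": "claude-3-sonnet", "default": False, "selected": False},
    {"name": "claude-3-haiku", "default": False, "selected": False},
]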
django_app/redbox_app/templates/chats.html (6 changes: 0 additions & 6 deletions)
@@ -65,12 +65,6 @@ <h3 class="govuk-fieldset__heading">Documents to use</h3>
</fieldset>
</document-selector>

{# Temporary until this comes from the view data #}
{% set llm_options = [
{"name": "gpt-4o", "default": True, "selected": True},
{"name": "gpt-4-turbo-2024-04-09"},
{"name": "gpt-35-turbo-16k"}
] %}
<div class="iai-panel govuk-!-margin-top-5 govuk-!-padding-top-3">
<label class="govuk-body-s govuk-!-font-weight-bold" for="llm-selector">Model</label>
<select id="llm-selector" name="llm" class="govuk-select govuk-!-margin-top-1">
django_app/tests/test_consumers.py (4 changes: 3 additions & 1 deletion)
@@ -428,7 +428,9 @@ async def test_chat_consumer_get_ai_settings(
connected, _ = await communicator.connect()
assert connected

ai_settings = await ChatConsumer.get_ai_settings(alice)
chat = Chat(name=" a chat", user=alice)

ai_settings = await ChatConsumer.get_ai_settings(chat)

assert ai_settings["chat_map_question_prompt"] == CHAT_MAP_QUESTION_PROMPT
with pytest.raises(KeyError):
django_app/tests/test_migrations.py (23 changes: 23 additions & 0 deletions)
@@ -240,3 +240,26 @@ def test_0032_user_new_business_unit(migrator):
NewUser = new_state.apps.get_model("redbox_core", "User") # noqa: N806
user = NewUser.objects.get(pk=user.pk)
assert user.business_unit == "Prime Minister's Office"


@pytest.mark.django_db()
def test_0042_chat_chat_backend_chat_chat_map_question_prompt_and_more(migrator):
old_state = migrator.apply_initial_migration(("redbox_core", "0041_alter_aisettings_chat_backend"))

User = old_state.apps.get_model("redbox_core", "User")
user = User.objects.create(email="someone@example.com")

Chat = old_state.apps.get_model("redbox_core", "Chat")
chat = Chat.objects.create(name="my chat", user=user)

assert not hasattr(chat, "chat_backend")

new_state = migrator.apply_tested_migration(
("redbox_core", "0042_chat_chat_backend_chat_chat_map_question_prompt_and_more"),
)

new_chat_model = new_state.apps.get_model("redbox_core", "Chat")
new_chat = new_chat_model.objects.get(id=chat.id)

assert new_chat.chat_backend == chat.user.ai_settings.chat_backend
assert new_chat.chat_backend is not None