Skip to content

Commit

Permalink
tests passing
Browse files Browse the repository at this point in the history
  • Loading branch information
George Burton committed Sep 5, 2024
1 parent 9f0d60a commit a66b876
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 113 deletions.
40 changes: 20 additions & 20 deletions django_app/redbox_app/redbox_core/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,15 @@ async def llm_conversation(self, selected_files: Sequence[File], session: Chat,
),
)

await self.redbox.run(
state,
response_tokens_callback=self.handle_text,
route_name_callback=self.handle_route,
documents_callback=self.handle_documents,
metadata_tokens_callback=self.handle_metadata,
)

try:
await self.redbox.run(
state,
response_tokens_callback=self.handle_text,
route_name_callback=self.handle_route,
documents_callback=self.handle_documents,
metadata_tokens_callback=self.handle_metadata,
)

message = await self.save_message(
session,
"".join(self.full_reply),
Expand Down Expand Up @@ -193,25 +193,25 @@ def get_ai_settings(user: User) -> AISettings:
fields=[field.name for field in user.ai_settings._meta.fields if field.name != "label"], # noqa: SLF001
)

async def handle_text(self, response: ClientResponse) -> str:
await self.send_to_client("text", response.data)
self.full_reply.append(response.data)
async def handle_text(self, response: str) -> str:
await self.send_to_client("text", response)
self.full_reply.append(response)

async def handle_route(self, response: ClientResponse) -> str:
await self.send_to_client("route", response.data)
self.routes.append(response.data)
async def handle_route(self, response: str) -> str:
await self.send_to_client("route", response)
self.route = response

async def handle_metadata(self, response: ClientResponse):
for model, token_count in response.data.input_tokens.items():
async def handle_metadata(self, response: MetadataDetail):
for model, token_count in response.input_tokens.items():
self.metadata.input_tokens[model] = self.metadata.input_tokens.get(model, 0) + token_count
for model, token_count in response.data.output_tokens.items():
for model, token_count in response.output_tokens.items():
self.metadata.output_tokens[model] = self.metadata.output_tokens.get(model, 0) + token_count

async def handle_documents(self, response: ClientResponse) -> Sequence[tuple[File, SourceDocument]]:
s3_keys = [doc.s3_key for doc in response.data]
async def handle_documents(self, response: list[SourceDocument]) -> Sequence[tuple[File, SourceDocument]]:
s3_keys = [doc.s3_key for doc in response]
files = File.objects.filter(original_file__in=s3_keys)

async for file in files:
await self.send_to_client("source", {"url": str(file.url), "original_file_name": file.original_file_name})
for file in files:
self.citations.append((file, [doc for doc in response.data if doc.s3_key == file.unique_name]))
self.citations.append((file, [doc for doc in response if doc.s3_key == file.unique_name]))
Loading

0 comments on commit a66b876

Please sign in to comment.