Skip to content

Commit

Permalink
Making the token count check do both max_tokens and context_window checks to ensure an error is thrown even when max_tokens is below the context window
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesrichards4 committed Sep 2, 2024
1 parent 1438a69 commit b639b8d
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 20 deletions.
8 changes: 4 additions & 4 deletions django_app/redbox_app/redbox_core/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,12 @@ async def handle_route(self, response: ClientResponse, show_route: bool) -> str:

async def handle_metadata(self, current_metadata: MetadataDetail, metadata_event: MetadataDetail):
result = current_metadata.model_copy(deep=True)
for model,token_count in metadata_event.input_tokens.items():
result.input_tokens[model] = current_metadata.input_tokens.get(model, 0) + token_count
for model,token_count in metadata_event.output_tokens.items():
for model, token_count in metadata_event.input_tokens.items():
result.input_tokens[model] = current_metadata.input_tokens.get(model, 0) + token_count
for model, token_count in metadata_event.output_tokens.items():
result.output_tokens[model] = current_metadata.output_tokens.get(model, 0) + token_count
return result

async def handle_error(self, response: ClientResponse) -> str:
match response.data.code:
case "no-document-selected":
Expand Down
25 changes: 25 additions & 0 deletions redbox-core/redbox/graph/edges.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import re
from typing import Literal

from langchain_core.runnables import Runnable

Expand All @@ -24,6 +25,30 @@ def calculate_token_budget(state: RedboxState, system_prompt: str, question_prom
return ai_settings.context_window_size - ai_settings.llm_max_tokens - len_system_prompt - len_question_prompt


def build_total_tokens_request_handler_conditional(prompt_set: PromptSet) -> Runnable:
    """Build a conditional edge that routes a request by its total document token usage.

    The returned callable sums the token counts of every document on the state
    and classifies the request: over the absolute document-token cap
    ("max_exceeded"), over the remaining context-window budget
    ("context_exceeded"), or within both limits ("pass").
    """

    def _total_tokens_request_handler_conditional(
        state: RedboxState,
    ) -> Literal["max_exceeded", "context_exceeded", "pass"]:
        system_prompt, question_prompt = get_prompts(state, prompt_set)
        # Tokens still available in the context window after prompts and the
        # model's own output allowance are accounted for.
        budget_left_in_context = calculate_token_budget(state, system_prompt, question_prompt)
        document_token_cap = state["request"].ai_settings.max_document_tokens

        documents = flatten_document_state(state["documents"])
        tokens_requested = sum(doc.metadata["token_count"] for doc in documents)

        # Check the hard cap first so it wins even when it is the smaller limit.
        if tokens_requested > document_token_cap:
            return "max_exceeded"
        if tokens_requested > budget_left_in_context:
            return "context_exceeded"
        return "pass"

    return _total_tokens_request_handler_conditional


def build_documents_bigger_than_context_conditional(prompt_set: PromptSet) -> Runnable:
"""Uses a set of prompts to build the correct conditional for exceeding the context window."""

Expand Down
24 changes: 8 additions & 16 deletions redbox-core/redbox/graph/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

from redbox.graph.edges import (
build_documents_bigger_than_context_conditional,
build_total_tokens_request_handler_conditional,
multiple_docs_in_group_conditional,
build_keyword_detection_conditional,
documents_bigger_than_n_conditional,
documents_selected_conditional,
)
from redbox.graph.nodes.processes import (
Expand Down Expand Up @@ -114,8 +114,7 @@ def get_chat_with_documents_graph(
)

# Decisions
builder.add_node("d_all_docs_bigger_than_context", empty_process)
builder.add_node("d_all_docs_bigger_than_n", empty_process)
builder.add_node("d_request_handler_from_total_tokens", empty_process)
builder.add_node("d_single_doc_summaries_bigger_than_context", empty_process)
builder.add_node("d_doc_summaries_bigger_than_context", empty_process)
builder.add_node("d_groups_have_multiple_docs", empty_process)
Expand All @@ -128,21 +127,14 @@ def get_chat_with_documents_graph(
# Edges
builder.add_edge(START, "p_pass_question_to_text")
builder.add_edge("p_pass_question_to_text", "p_retrieve_docs")
builder.add_edge("p_retrieve_docs", "d_all_docs_bigger_than_context")
builder.add_edge("p_retrieve_docs", "d_request_handler_from_total_tokens")
builder.add_conditional_edges(
"d_all_docs_bigger_than_context",
build_documents_bigger_than_context_conditional(PromptSet.ChatwithDocsMapReduce),
{
True: "d_all_docs_bigger_than_n",
False: "p_set_chat_docs_route",
},
)
builder.add_conditional_edges(
"d_all_docs_bigger_than_n",
documents_bigger_than_n_conditional,
"d_request_handler_from_total_tokens",
build_total_tokens_request_handler_conditional(PromptSet.ChatwithDocsMapReduce),
{
True: "p_too_large_error",
False: "p_set_chat_docs_large_route",
"max_exceeded": "p_too_large_error",
"context_exceeded": "p_set_chat_docs_large_route",
"pass": "p_set_chat_docs_route",
},
)
builder.add_edge("p_set_chat_docs_route", "p_summarise")
Expand Down

0 comments on commit b639b8d

Please sign in to comment.