Skip to content

Commit

Permalink
Formatting and a var name change
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesrichards4 committed Oct 22, 2024
1 parent de493d3 commit d7fb704
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 26 deletions.
10 changes: 6 additions & 4 deletions redbox-core/redbox/chains/runnables.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def _combined(obj):


def build_chat_prompt_from_messages_runnable(
prompt_set: PromptSet, tokeniser: Encoding = None, partial_variables: dict = None,
prompt_set: PromptSet,
tokeniser: Encoding = None,
partial_variables: dict = None,
) -> Runnable:
@chain
def _chat_prompt_from_messages(state: RedboxState) -> Runnable:
Expand Down Expand Up @@ -121,17 +123,17 @@ def build_llm_chain(
"text_and_tools": (
_llm
| {
"responses": _output_parser,
"parsed_response": _output_parser,
"tool_calls": (RunnableLambda(lambda r: r.tool_calls) | tool_calls_to_toolstate),
}
),
"prompt": RunnableLambda(lambda prompt: prompt.to_string()),
}
| {
"text": RunnableLambda(combine_getters(itemgetter("text_and_tools"), itemgetter("responses")))
"text": RunnableLambda(combine_getters(itemgetter("text_and_tools"), itemgetter("parsed_response")))
| (lambda r: r if isinstance(r, str) else r.markdown_answer),
"tool_calls": combine_getters(itemgetter("text_and_tools"), itemgetter("tool_calls")),
"citations": RunnableLambda(combine_getters(itemgetter("text_and_tools"), itemgetter("responses")))
"citations": RunnableLambda(combine_getters(itemgetter("text_and_tools"), itemgetter("parsed_response")))
| (lambda r: [] if isinstance(r, str) else r.citations),
"prompt": itemgetter("prompt"),
}
Expand Down
1 change: 0 additions & 1 deletion redbox-core/redbox/models/chain.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

from datetime import UTC, datetime
from enum import StrEnum
from functools import reduce
Expand Down
5 changes: 1 addition & 4 deletions redbox-core/redbox/test/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,7 @@ def __init__(
# Use separate file_uuids if specified else match the query
all_s3_keys = test_data.s3_keys if test_data.s3_keys else query.s3_keys

if (
test_data.llm_responses is not None
and len(test_data.llm_responses) < test_data.number_of_docs
):
if test_data.llm_responses is not None and len(test_data.llm_responses) < test_data.number_of_docs:
log.warning(
"Number of configured LLM responses might be less than number of docs. For Map-Reduce actions this will give a Generator Error!"
)
Expand Down
4 changes: 1 addition & 3 deletions redbox-core/tests/graph/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,7 @@ def assert_number_of_events(num_of_events: int):
RedboxTestData(
number_of_docs=2,
tokens_in_all_docs=200_000,
llm_responses=["Map Step Response"] * 2
+ ["Merge Per Document Response"]
+ ["Testing Response 1"],
llm_responses=["Map Step Response"] * 2 + ["Merge Per Document Response"] + ["Testing Response 1"],
expected_route=ChatRoute.chat_with_docs_map_reduce,
),
],
Expand Down
33 changes: 19 additions & 14 deletions redbox-core/tests/graph/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,32 +133,33 @@ def test_document_reducer(a: DocumentState, b: DocumentState, expected: Document
result = document_reducer(a, b)
assert result == expected, f"Expected: {expected}. Result: {result}"


now = datetime.now(UTC)
GPT_4o_multiple_calls_1 = [
LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=0, output_tokens=0, timestamp=now-timedelta(days=10)),
LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=10, output_tokens=10, timestamp=now-timedelta(days=9)),
LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=10, output_tokens=10, timestamp=now-timedelta(days=8)),
LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=0, output_tokens=0, timestamp=now - timedelta(days=10)),
LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=10, output_tokens=10, timestamp=now - timedelta(days=9)),
LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=10, output_tokens=10, timestamp=now - timedelta(days=8)),
]

GPT_4o_multiple_calls_1a = GPT_4o_multiple_calls_1 + [
LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=50, output_tokens=50, timestamp=now-timedelta(days=7)),
LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=60, output_tokens=60, timestamp=now-timedelta(days=6)),
LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=50, output_tokens=50, timestamp=now - timedelta(days=7)),
LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=60, output_tokens=60, timestamp=now - timedelta(days=6)),
]

GPT_4o_multiple_calls_2 = [
LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=100, output_tokens=200, timestamp=now-timedelta(days=5)),
LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=0, output_tokens=10, timestamp=now-timedelta(days=4)),
LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=100, output_tokens=210, timestamp=now-timedelta(days=3)),
LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=100, output_tokens=200, timestamp=now - timedelta(days=5)),
LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=0, output_tokens=10, timestamp=now - timedelta(days=4)),
LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=100, output_tokens=210, timestamp=now - timedelta(days=3)),
]

multiple_models_multiple_calls_1 = [
LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=100, output_tokens=200, timestamp=now-timedelta(days=2)),
LLMCallMetadata(llm_model_name="gpt-3.5", input_tokens=20, output_tokens=20, timestamp=now-timedelta(days=1)),
LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=100, output_tokens=210, timestamp=now-timedelta(hours=10)),
LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=100, output_tokens=200, timestamp=now - timedelta(days=2)),
LLMCallMetadata(llm_model_name="gpt-3.5", input_tokens=20, output_tokens=20, timestamp=now - timedelta(days=1)),
LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=100, output_tokens=210, timestamp=now - timedelta(hours=10)),
]

multiple_models_multiple_calls_1a = multiple_models_multiple_calls_1 + [
LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=300, output_tokens=310, timestamp=now-timedelta(hours=1)),
LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=300, output_tokens=310, timestamp=now - timedelta(hours=1)),
]


Expand All @@ -168,7 +169,9 @@ def test_document_reducer(a: DocumentState, b: DocumentState, expected: Document
(
RequestMetadata(llm_calls=GPT_4o_multiple_calls_1),
RequestMetadata(llm_calls=GPT_4o_multiple_calls_2),
RequestMetadata(llm_calls=sorted(GPT_4o_multiple_calls_1 + GPT_4o_multiple_calls_2, key=lambda c: c.timestamp)),
RequestMetadata(
llm_calls=sorted(GPT_4o_multiple_calls_1 + GPT_4o_multiple_calls_2, key=lambda c: c.timestamp)
),
),
(
RequestMetadata(llm_calls=GPT_4o_multiple_calls_1),
Expand All @@ -178,7 +181,9 @@ def test_document_reducer(a: DocumentState, b: DocumentState, expected: Document
(
RequestMetadata(llm_calls=multiple_models_multiple_calls_1),
RequestMetadata(llm_calls=GPT_4o_multiple_calls_2),
RequestMetadata(llm_calls=sorted(GPT_4o_multiple_calls_2 + multiple_models_multiple_calls_1, key=lambda c: c.timestamp)),
RequestMetadata(
llm_calls=sorted(GPT_4o_multiple_calls_2 + multiple_models_multiple_calls_1, key=lambda c: c.timestamp)
),
),
(
RequestMetadata(llm_calls=GPT_4o_multiple_calls_1),
Expand Down

0 comments on commit d7fb704

Please sign in to comment.