Skip to content

Commit

Permalink
feature/tests-for-graph-node-sends (#1079)
Browse files Browse the repository at this point in the history
* added tests for sends

* added all state kvps

* simplified code
  • Loading branch information
gecBurton authored Oct 7, 2024
1 parent 1063ed3 commit 808c46b
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 19 deletions.
35 changes: 16 additions & 19 deletions redbox-core/redbox/graph/nodes/sends.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,39 +5,36 @@
from redbox.models.chain import RedboxState


def _copy_state(state: RedboxState, **updates) -> RedboxState:
kwargs = dict(state) | updates
return RedboxState(**kwargs)


def build_document_group_send(target: str) -> Callable[[RedboxState], list[Send]]:
def _group_send(state: RedboxState) -> list[Send]:
if state.get("documents") is None:
raise KeyError
"""builds Sends per document-groups"""

def _group_send(state: RedboxState) -> list[Send]:
group_send_states: list[RedboxState] = [
RedboxState(
request=state["request"],
text=state.get("text"),
documents={group_key: state["documents"][group_key]},
route=state.get("route"),
_copy_state(state,
documents={document_group_key: document_group},
)
for group_key in state["documents"]
for document_group_key, document_group in state["documents"].items()
]
return [Send(node=target, arg=state) for state in group_send_states]

return _group_send


def build_document_chunk_send(target: str) -> Callable[[RedboxState], list[Send]]:
def _chunk_send(state: RedboxState) -> list[Send]:
if state.get("documents") is None:
raise KeyError
"""builds Sends per individual document"""

def _chunk_send(state: RedboxState) -> list[Send]:
chunk_send_states: list[RedboxState] = [
RedboxState(
request=state["request"],
text=state.get("text"),
documents={group_key: {document_key: state["documents"][group_key][document_key]}},
route=state.get("route"),
_copy_state(state,
documents={document_group_key: {document_key: document}},
)
for group_key in state["documents"]
for document_key in state["documents"][group_key]
for document_group_key, document_group in state["documents"].items()
for document_key, document in document_group.items()
]
return [Send(node=target, arg=state) for state in chunk_send_states]

Expand Down
Empty file.
66 changes: 66 additions & 0 deletions redbox-core/tests/graph/nodes/test_sends.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from uuid import uuid4

from langchain_core.documents import Document
from langgraph.constants import Send

from redbox.graph.nodes.sends import build_document_group_send, build_document_chunk_send
from redbox.models.chain import RedboxState, RedboxQuery, DocumentState


def test_build_document_group_send():
target = "my-target"
request = RedboxQuery(question="what colour is the sky?", user_uuid=uuid4(), chat_history=[])
documents = DocumentState(
group={uuid4(): Document(page_content="Hello, world!"), uuid4(): Document(page_content="Goodbye, world!")}
)

document_group_send = build_document_group_send("my-target")
state = RedboxState(
request=request,
documents=documents,
text=None,
route_name=None,
)
actual = document_group_send(state)
expected = [Send(node=target, arg=state)]
assert expected == actual


def test_build_document_chunk_send():
target = "my-target"
request = RedboxQuery(question="what colour is the sky?", user_uuid=uuid4(), chat_history=[])

uuid_1 = uuid4()
doc_1 = Document(page_content="Hello, world!")
uuid_2 = uuid4()
doc_2 = Document(page_content="Goodbye, world!")

document_chunk_send = build_document_chunk_send("my-target")
state = RedboxState(
request=request,
documents=DocumentState(group={uuid_1: doc_1, uuid_2: doc_2}),
text=None,
route_name=None,
)
actual = document_chunk_send(state)
expected = [
Send(
node=target,
arg=RedboxState(
request=request,
documents=DocumentState(group={uuid_1: doc_1}),
text=None,
route_name=None,
),
),
Send(
node=target,
arg=RedboxState(
request=request,
documents=DocumentState(group={uuid_2: doc_2}),
text=None,
route_name=None,
),
),
]
assert expected == actual

0 comments on commit 808c46b

Please sign in to comment.