fix llm tests with after #93
RLKRo committed Feb 17, 2025
1 parent 6f64e24 commit cb3cd70
Showing 1 changed file with 52 additions and 49 deletions.
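
The diff migrates the tests away from Context.init and add_request/add_response to the context_factory fixture, with direct index-based writes to ctx.requests/ctx.responses and an explicit current_turn_id. Below is a minimal, hypothetical sketch of that pattern; the fixture and import names are taken from this diff itself, not from a documented public API.

import pytest
from chatsky.core.message import Message
from chatsky.core.node_label import AbsoluteNodeLabel


@pytest.fixture
def example_context(context_factory):
    # context_factory (repo test fixture) replaces the removed Context.init(...) call
    ctx = context_factory(start_label=AbsoluteNodeLabel(flow_name="flow", node_name="node"))
    # Turns are now written by index instead of add_request/add_response
    ctx.requests[1] = Message(text="Request 1")
    ctx.responses[1] = Message(text="Response 1")
    # current_turn_id must point at the latest turn that was written
    ctx.current_turn_id = 1
    return ctx

# Reading messages back is asynchronous, so tests touching them are declared async:
#     request = await ctx.requests[1]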
101 changes: 52 additions & 49 deletions tests/llm/test_llm.py
@@ -12,7 +12,6 @@
from chatsky.llm.filters import IsImportant, FromModel
from chatsky.llm.methods import Contains, LogProb, BaseMethod
from chatsky.core.message import Message
from chatsky.core.context import Context
from chatsky.core.script import Node
from chatsky.core.node_label import AbsoluteNodeLabel

@@ -158,33 +157,34 @@ def pipeline(mock_model):


@pytest.fixture
def filter_context():
ctx = Context.init(AbsoluteNodeLabel(flow_name="flow", node_name="node"))
def filter_context(context_factory):
ctx = context_factory(start_label=AbsoluteNodeLabel(flow_name="flow", node_name="node"))
ctx.framework_data.current_node = Node(misc={"prompt": "1"})
ctx.add_request(
Message(text="Request 1", misc={"important": True}, annotations={"__generated_by_model__": "test_model"})
ctx.requests[1] = Message(
text="Request 1", misc={"important": True}, annotations={"__generated_by_model__": "test_model"}
)
ctx.add_request(
Message(text="Request 2", misc={"important": False}, annotations={"__generated_by_model__": "other_model"})
ctx.requests[2] = Message(
text="Request 2", misc={"important": False}, annotations={"__generated_by_model__": "other_model"}
)
ctx.add_request(
Message(text="Request 3", misc={"important": False}, annotations={"__generated_by_model__": "test_model"})
ctx.requests[3] = Message(
text="Request 3", misc={"important": False}, annotations={"__generated_by_model__": "test_model"}
)
ctx.add_response(
Message(text="Response 1", misc={"important": False}, annotations={"__generated_by_model__": "test_model"})
ctx.responses[1] = Message(
text="Response 1", misc={"important": False}, annotations={"__generated_by_model__": "test_model"}
)
ctx.add_response(
Message(text="Response 2", misc={"important": True}, annotations={"__generated_by_model__": "other_model"})
ctx.responses[2] = Message(
text="Response 2", misc={"important": True}, annotations={"__generated_by_model__": "other_model"}
)
ctx.add_response(
Message(text="Response 3", misc={"important": False}, annotations={"__generated_by_model__": "test_model"})
ctx.responses[3] = Message(
text="Response 3", misc={"important": False}, annotations={"__generated_by_model__": "test_model"}
)
ctx.current_turn_id = 3
return ctx


@pytest.fixture
def context(pipeline):
ctx = Context.init(AbsoluteNodeLabel(flow_name="flow", node_name="node"))
def context(pipeline, context_factory):
ctx = context_factory(start_label=AbsoluteNodeLabel(flow_name="flow", node_name="node"))
ctx.framework_data.pipeline = pipeline
ctx.framework_data.current_node = Node(
misc={
@@ -193,10 +193,11 @@ def context(pipeline):
"prompt_last": Prompt(message=Message("last prompt"), position=1000),
}
)
for i in range(3):
ctx.add_request(f"Request {i}")
ctx.add_response(f"Response {i}")
ctx.add_request("Last request")
for i in range(1, 4):
ctx.requests[i] = f"Request {i}"
ctx.responses[i] = f"Response {i}"
ctx.requests[4] = "Last request"
ctx.current_turn_id = 4
return ctx


@@ -216,17 +217,17 @@ class TestHistory:
[
(
2,
"Mock response with history: ['Request 1', 'Response 1', "
"'Request 2', 'Response 2', 'prompt', 'Last request', 'last prompt']",
"Mock response with history: ['Request 2', 'Response 2', "
"'Request 3', 'Response 3', 'prompt', 'Last request', 'last prompt']",
),
(
0,
"Mock response with history: ['prompt', 'Last request', 'last prompt']",
),
(
4,
"Mock response with history: ['Request 0', 'Response 0', "
"'Request 1', 'Response 1', 'Request 2', 'Response 2', 'prompt', 'Last request', 'last prompt']",
"Mock response with history: ['Request 1', 'Response 1', "
"'Request 2', 'Response 2', 'Request 3', 'Response 3', 'prompt', 'Last request', 'last prompt']",
),
],
)
@@ -241,20 +242,20 @@ async def test_context_to_history(self, context):
ctx=context, length=-1, filter_func=lambda *args: True, llm_model_name="test_model", max_size=100
)
expected = [
HumanMessage(content=[{"type": "text", "text": "Request 0"}]),
AIMessage(content=[{"type": "text", "text": "Response 0"}]),
HumanMessage(content=[{"type": "text", "text": "Request 1"}]),
AIMessage(content=[{"type": "text", "text": "Response 1"}]),
HumanMessage(content=[{"type": "text", "text": "Request 2"}]),
AIMessage(content=[{"type": "text", "text": "Response 2"}]),
HumanMessage(content=[{"type": "text", "text": "Request 3"}]),
AIMessage(content=[{"type": "text", "text": "Response 3"}]),
]
assert res == expected
res = await context_to_history(
ctx=context, length=1, filter_func=lambda *args: True, llm_model_name="test_model", max_size=100
)
expected = [
HumanMessage(content=[{"type": "text", "text": "Request 2"}]),
AIMessage(content=[{"type": "text", "text": "Response 2"}]),
HumanMessage(content=[{"type": "text", "text": "Request 3"}]),
AIMessage(content=[{"type": "text", "text": "Response 3"}]),
]
assert res == expected

@@ -267,12 +268,12 @@ class TestGetLangchainContext:
PositionConfig(),
[
SystemMessage(content=[{"type": "text", "text": "system prompt"}]),
HumanMessage(content=[{"type": "text", "text": "Request 0"}]),
AIMessage(content=[{"type": "text", "text": "Response 0"}]),
HumanMessage(content=[{"type": "text", "text": "Request 1"}]),
AIMessage(content=[{"type": "text", "text": "Response 1"}]),
HumanMessage(content=[{"type": "text", "text": "Request 2"}]),
AIMessage(content=[{"type": "text", "text": "Response 2"}]),
HumanMessage(content=[{"type": "text", "text": "Request 3"}]),
AIMessage(content=[{"type": "text", "text": "Response 3"}]),
HumanMessage(content=[{"type": "text", "text": "prompt"}]),
HumanMessage(content=[{"type": "text", "text": "call prompt"}]),
HumanMessage(content=[{"type": "text", "text": "Last request"}]),
@@ -290,12 +291,12 @@
[
HumanMessage(content=[{"type": "text", "text": "Last request"}]),
HumanMessage(content=[{"type": "text", "text": "prompt"}]),
HumanMessage(content=[{"type": "text", "text": "Request 0"}]),
AIMessage(content=[{"type": "text", "text": "Response 0"}]),
HumanMessage(content=[{"type": "text", "text": "Request 1"}]),
AIMessage(content=[{"type": "text", "text": "Response 1"}]),
HumanMessage(content=[{"type": "text", "text": "Request 2"}]),
AIMessage(content=[{"type": "text", "text": "Response 2"}]),
HumanMessage(content=[{"type": "text", "text": "Request 3"}]),
AIMessage(content=[{"type": "text", "text": "Response 3"}]),
HumanMessage(content=[{"type": "text", "text": "call prompt"}]),
SystemMessage(content=[{"type": "text", "text": "system prompt"}]),
HumanMessage(content=[{"type": "text", "text": "last prompt"}]),
@@ -312,12 +313,12 @@
),
[
SystemMessage(content=[{"type": "text", "text": "system prompt"}]),
HumanMessage(content=[{"type": "text", "text": "Request 0"}]),
AIMessage(content=[{"type": "text", "text": "Response 0"}]),
HumanMessage(content=[{"type": "text", "text": "Request 1"}]),
AIMessage(content=[{"type": "text", "text": "Response 1"}]),
HumanMessage(content=[{"type": "text", "text": "Request 2"}]),
AIMessage(content=[{"type": "text", "text": "Response 2"}]),
HumanMessage(content=[{"type": "text", "text": "Request 3"}]),
AIMessage(content=[{"type": "text", "text": "Response 3"}]),
HumanMessage(content=[{"type": "text", "text": "absolutely not a prompt"}]),
HumanMessage(content=[{"type": "text", "text": "call prompt"}]),
HumanMessage(content=[{"type": "text", "text": "Last request"}]),
@@ -359,26 +360,26 @@ async def test_conditions(self, context):


class TestFilters:
def test_is_important_filter(self, filter_context):
async def test_is_important_filter(self, filter_context):
filter_func = IsImportant()
ctx = filter_context

# Test filtering important messages
assert filter_func(ctx, ctx.requests[1], ctx.responses[1], llm_model_name="test_model")
assert filter_func(ctx, ctx.requests[2], ctx.responses[2], llm_model_name="test_model")
assert not filter_func(ctx, ctx.requests[3], ctx.responses[3], llm_model_name="test_model")
assert filter_func(ctx, await ctx.requests[1], await ctx.responses[1], llm_model_name="test_model")
assert filter_func(ctx, await ctx.requests[2], await ctx.responses[2], llm_model_name="test_model")
assert not filter_func(ctx, await ctx.requests[3], await ctx.responses[3], llm_model_name="test_model")

assert not filter_func(ctx, None, ctx.responses[1], llm_model_name="test_model")
assert filter_func(ctx, ctx.requests[1], None, llm_model_name="test_model")
assert not filter_func(ctx, None, await ctx.responses[1], llm_model_name="test_model")
assert filter_func(ctx, await ctx.requests[1], None, llm_model_name="test_model")

def test_model_filter(self, filter_context):
async def test_model_filter(self, filter_context):
filter_func = FromModel()
ctx = filter_context
# Test filtering messages from a certain model
assert filter_func(ctx, ctx.requests[1], ctx.responses[1], llm_model_name="test_model")
assert not filter_func(ctx, ctx.requests[2], ctx.responses[2], llm_model_name="test_model")
assert filter_func(ctx, ctx.requests[3], ctx.responses[3], llm_model_name="test_model")
assert filter_func(ctx, ctx.requests[2], ctx.responses[3], llm_model_name="test_model")
assert filter_func(ctx, await ctx.requests[1], await ctx.responses[1], llm_model_name="test_model")
assert not filter_func(ctx, await ctx.requests[2], await ctx.responses[2], llm_model_name="test_model")
assert filter_func(ctx, await ctx.requests[3], await ctx.responses[3], llm_model_name="test_model")
assert filter_func(ctx, await ctx.requests[2], await ctx.responses[3], llm_model_name="test_model")


class TestBaseMethod:
Expand Down Expand Up @@ -408,12 +409,13 @@ async def test_logprob_method(self, filter_context, llmresult):
class TestSlots:
async def test_llm_slot(self, pipeline, context):
slot = LLMSlot(caption="test_caption", model="test_model")
context.current_turn_id = 5
# Test empty request
context.add_request("")
context.requests[5] = ""
assert isinstance(await slot.extract_value(context), SlotNotExtracted)

# Test normal request
context.add_request("test request")
context.requests[5] = "test request"
result = await slot.extract_value(context)
assert isinstance(result, str)

@@ -425,7 +427,8 @@ async def test_llm_group_slot(self, pipeline, context):
nested=LLMGroupSlot(model="test_model", city=LLMSlot(caption="Extract person's city")),
)

context.add_request("John is 25 years old and lives in New York")
context.current_turn_id = 5
context.requests[5] = "John is 25 years old and lives in New York"
result = await slot.get_value(context)

assert isinstance(result, ExtractedGroupSlot)
