Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tests[patch]: improve coverage of structured output tests #29478

Merged
merged 1 commit into from
Jan 29, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import base64
import json
from typing import List, Optional, cast
from typing import Any, List, Literal, Optional, cast

import httpx
import pytest
Expand Down Expand Up @@ -29,7 +29,9 @@
from langchain_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION


def _get_joke_class() -> type[BaseModel]:
def _get_joke_class(
schema_type: Literal["pydantic", "typeddict", "json_schema"],
) -> Any:
"""
:private:
"""
Expand All @@ -40,7 +42,28 @@ class Joke(BaseModel):
setup: str = Field(description="question to set up a joke")
punchline: str = Field(description="answer to resolve the joke")

return Joke
def validate_joke(result: Any) -> bool:
return isinstance(result, Joke)

class JokeDict(TypedDict):
"""Joke to tell user."""

setup: Annotated[str, ..., "question to set up a joke"]
punchline: Annotated[str, ..., "answer to resolve the joke"]

def validate_joke_dict(result: Any) -> bool:
return all(key in ["setup", "punchline"] for key in result.keys())

if schema_type == "pydantic":
return Joke, validate_joke

elif schema_type == "typeddict":
return JokeDict, validate_joke_dict

elif schema_type == "json_schema":
return Joke.model_json_schema(), validate_joke_dict
else:
raise ValueError("Invalid schema type")


class _MagicFunctionSchema(BaseModel):
Expand Down Expand Up @@ -1151,7 +1174,8 @@ def test_bind_runnables_as_tools(self, model: BaseChatModel) -> None:
assert tool_call["args"].get("answer_style")
assert tool_call["type"] == "tool_call"

def test_structured_output(self, model: BaseChatModel) -> None:
@pytest.mark.parametrize("schema_type", ["pydantic", "typeddict", "json_schema"])
def test_structured_output(self, model: BaseChatModel, schema_type: str) -> None:
"""Test to verify structured output is generated both on invoke and stream.

This test is optional and should be skipped if the model does not support
Expand Down Expand Up @@ -1181,29 +1205,19 @@ def has_tool_calling(self) -> bool:
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")

Joke = _get_joke_class()
# Pydantic class
chat = model.with_structured_output(Joke, **self.structured_output_kwargs)
schema, validation_function = _get_joke_class(schema_type) # type: ignore[arg-type]
chat = model.with_structured_output(schema, **self.structured_output_kwargs)
result = chat.invoke("Tell me a joke about cats.")
assert isinstance(result, Joke)
validation_function(result)

for chunk in chat.stream("Tell me a joke about cats."):
assert isinstance(chunk, Joke)
validation_function(chunk)
assert chunk

# Schema
chat = model.with_structured_output(
Joke.model_json_schema(), **self.structured_output_kwargs
)
result = chat.invoke("Tell me a joke about cats.")
assert isinstance(result, dict)
assert set(result.keys()) == {"setup", "punchline"}

for chunk in chat.stream("Tell me a joke about cats."):
assert isinstance(chunk, dict)
assert isinstance(chunk, dict) # for mypy
assert set(chunk.keys()) == {"setup", "punchline"}

async def test_structured_output_async(self, model: BaseChatModel) -> None:
@pytest.mark.parametrize("schema_type", ["pydantic", "typeddict", "json_schema"])
async def test_structured_output_async(
self, model: BaseChatModel, schema_type: str
) -> None:
"""Test to verify structured output is generated both on invoke and stream.

This test is optional and should be skipped if the model does not support
Expand Down Expand Up @@ -1233,28 +1247,14 @@ def has_tool_calling(self) -> bool:
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")

Joke = _get_joke_class()

# Pydantic class
chat = model.with_structured_output(Joke, **self.structured_output_kwargs)
schema, validation_function = _get_joke_class(schema_type) # type: ignore[arg-type]
chat = model.with_structured_output(schema, **self.structured_output_kwargs)
result = await chat.ainvoke("Tell me a joke about cats.")
assert isinstance(result, Joke)
validation_function(result)

async for chunk in chat.astream("Tell me a joke about cats."):
assert isinstance(chunk, Joke)

# Schema
chat = model.with_structured_output(
Joke.model_json_schema(), **self.structured_output_kwargs
)
result = await chat.ainvoke("Tell me a joke about cats.")
assert isinstance(result, dict)
assert set(result.keys()) == {"setup", "punchline"}

async for chunk in chat.astream("Tell me a joke about cats."):
assert isinstance(chunk, dict)
assert isinstance(chunk, dict) # for mypy
assert set(chunk.keys()) == {"setup", "punchline"}
validation_function(chunk)
assert chunk

@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Test requires pydantic 2.")
def test_structured_output_pydantic_2_v1(self, model: BaseChatModel) -> None:
Expand Down
Loading