Skip to content

Commit

Permalink
tests[patch]: improve coverage of structured output tests (#29478)
Browse files Browse the repository at this point in the history
  • Loading branch information
ccurme authored Jan 29, 2025
1 parent c79274c commit 284c935
Showing 1 changed file with 42 additions and 42 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import base64
import json
from typing import List, Optional, cast
from typing import Any, List, Literal, Optional, cast

import httpx
import pytest
Expand Down Expand Up @@ -29,7 +29,9 @@
from langchain_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION


def _get_joke_class() -> type[BaseModel]:
def _get_joke_class(
schema_type: Literal["pydantic", "typeddict", "json_schema"],
) -> Any:
"""
:private:
"""
Expand All @@ -40,7 +42,28 @@ class Joke(BaseModel):
setup: str = Field(description="question to set up a joke")
punchline: str = Field(description="answer to resolve the joke")

return Joke
def validate_joke(result: Any) -> bool:
return isinstance(result, Joke)

class JokeDict(TypedDict):
"""Joke to tell user."""

setup: Annotated[str, ..., "question to set up a joke"]
punchline: Annotated[str, ..., "answer to resolve the joke"]

def validate_joke_dict(result: Any) -> bool:
return all(key in ["setup", "punchline"] for key in result.keys())

if schema_type == "pydantic":
return Joke, validate_joke

elif schema_type == "typeddict":
return JokeDict, validate_joke_dict

elif schema_type == "json_schema":
return Joke.model_json_schema(), validate_joke_dict
else:
raise ValueError("Invalid schema type")


class _MagicFunctionSchema(BaseModel):
Expand Down Expand Up @@ -1151,7 +1174,8 @@ def test_bind_runnables_as_tools(self, model: BaseChatModel) -> None:
assert tool_call["args"].get("answer_style")
assert tool_call["type"] == "tool_call"

def test_structured_output(self, model: BaseChatModel) -> None:
@pytest.mark.parametrize("schema_type", ["pydantic", "typeddict", "json_schema"])
def test_structured_output(self, model: BaseChatModel, schema_type: str) -> None:
"""Test to verify structured output is generated both on invoke and stream.
This test is optional and should be skipped if the model does not support
Expand Down Expand Up @@ -1181,29 +1205,19 @@ def has_tool_calling(self) -> bool:
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")

Joke = _get_joke_class()
# Pydantic class
chat = model.with_structured_output(Joke, **self.structured_output_kwargs)
schema, validation_function = _get_joke_class(schema_type) # type: ignore[arg-type]
chat = model.with_structured_output(schema, **self.structured_output_kwargs)
result = chat.invoke("Tell me a joke about cats.")
assert isinstance(result, Joke)
validation_function(result)

for chunk in chat.stream("Tell me a joke about cats."):
assert isinstance(chunk, Joke)
validation_function(chunk)
assert chunk

# Schema
chat = model.with_structured_output(
Joke.model_json_schema(), **self.structured_output_kwargs
)
result = chat.invoke("Tell me a joke about cats.")
assert isinstance(result, dict)
assert set(result.keys()) == {"setup", "punchline"}

for chunk in chat.stream("Tell me a joke about cats."):
assert isinstance(chunk, dict)
assert isinstance(chunk, dict) # for mypy
assert set(chunk.keys()) == {"setup", "punchline"}

async def test_structured_output_async(self, model: BaseChatModel) -> None:
@pytest.mark.parametrize("schema_type", ["pydantic", "typeddict", "json_schema"])
async def test_structured_output_async(
self, model: BaseChatModel, schema_type: str
) -> None:
"""Test to verify structured output is generated both on invoke and stream.
This test is optional and should be skipped if the model does not support
Expand Down Expand Up @@ -1233,28 +1247,14 @@ def has_tool_calling(self) -> bool:
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")

Joke = _get_joke_class()

# Pydantic class
chat = model.with_structured_output(Joke, **self.structured_output_kwargs)
schema, validation_function = _get_joke_class(schema_type) # type: ignore[arg-type]
chat = model.with_structured_output(schema, **self.structured_output_kwargs)
result = await chat.ainvoke("Tell me a joke about cats.")
assert isinstance(result, Joke)
validation_function(result)

async for chunk in chat.astream("Tell me a joke about cats."):
assert isinstance(chunk, Joke)

# Schema
chat = model.with_structured_output(
Joke.model_json_schema(), **self.structured_output_kwargs
)
result = await chat.ainvoke("Tell me a joke about cats.")
assert isinstance(result, dict)
assert set(result.keys()) == {"setup", "punchline"}

async for chunk in chat.astream("Tell me a joke about cats."):
assert isinstance(chunk, dict)
assert isinstance(chunk, dict) # for mypy
assert set(chunk.keys()) == {"setup", "punchline"}
validation_function(chunk)
assert chunk

@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Test requires pydantic 2.")
def test_structured_output_pydantic_2_v1(self, model: BaseChatModel) -> None:
Expand Down

0 comments on commit 284c935

Please sign in to comment.