diff --git a/libs/standard-tests/langchain_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_tests/integration_tests/chat_models.py index d1c3ab3e6e3d3..6fef3158930a7 100644 --- a/libs/standard-tests/langchain_tests/integration_tests/chat_models.py +++ b/libs/standard-tests/langchain_tests/integration_tests/chat_models.py @@ -1,6 +1,6 @@ import base64 import json -from typing import List, Optional, cast +from typing import Any, List, Literal, Optional, cast import httpx import pytest @@ -29,7 +29,9 @@ from langchain_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION -def _get_joke_class() -> type[BaseModel]: +def _get_joke_class( + schema_type: Literal["pydantic", "typeddict", "json_schema"], +) -> Any: """ :private: """ @@ -40,7 +42,28 @@ class Joke(BaseModel): setup: str = Field(description="question to set up a joke") punchline: str = Field(description="answer to resolve the joke") - return Joke + def validate_joke(result: Any) -> bool: + return isinstance(result, Joke) + + class JokeDict(TypedDict): + """Joke to tell user.""" + + setup: Annotated[str, ..., "question to set up a joke"] + punchline: Annotated[str, ..., "answer to resolve the joke"] + + def validate_joke_dict(result: Any) -> bool: + return all(key in ["setup", "punchline"] for key in result.keys()) + + if schema_type == "pydantic": + return Joke, validate_joke + + elif schema_type == "typeddict": + return JokeDict, validate_joke_dict + + elif schema_type == "json_schema": + return Joke.model_json_schema(), validate_joke_dict + else: + raise ValueError("Invalid schema type") class _MagicFunctionSchema(BaseModel): @@ -1151,7 +1174,8 @@ def test_bind_runnables_as_tools(self, model: BaseChatModel) -> None: assert tool_call["args"].get("answer_style") assert tool_call["type"] == "tool_call" - def test_structured_output(self, model: BaseChatModel) -> None: + @pytest.mark.parametrize("schema_type", ["pydantic", "typeddict", "json_schema"]) + def test_structured_output(self, model: BaseChatModel, schema_type: str) -> None: """Test to verify structured output is generated both on invoke and stream. This test is optional and should be skipped if the model does not support @@ -1181,29 +1205,19 @@ def has_tool_calling(self) -> bool: if not self.has_tool_calling: pytest.skip("Test requires tool calling.") - Joke = _get_joke_class() - # Pydantic class - chat = model.with_structured_output(Joke, **self.structured_output_kwargs) + schema, validation_function = _get_joke_class(schema_type) # type: ignore[arg-type] + chat = model.with_structured_output(schema, **self.structured_output_kwargs) result = chat.invoke("Tell me a joke about cats.") - assert isinstance(result, Joke) + validation_function(result) for chunk in chat.stream("Tell me a joke about cats."): - assert isinstance(chunk, Joke) + validation_function(chunk) + assert chunk - # Schema - chat = model.with_structured_output( - Joke.model_json_schema(), **self.structured_output_kwargs - ) - result = chat.invoke("Tell me a joke about cats.") - assert isinstance(result, dict) - assert set(result.keys()) == {"setup", "punchline"} - - for chunk in chat.stream("Tell me a joke about cats."): - assert isinstance(chunk, dict) - assert isinstance(chunk, dict) # for mypy - assert set(chunk.keys()) == {"setup", "punchline"} - - async def test_structured_output_async(self, model: BaseChatModel) -> None: + @pytest.mark.parametrize("schema_type", ["pydantic", "typeddict", "json_schema"]) + async def test_structured_output_async( + self, model: BaseChatModel, schema_type: str + ) -> None: """Test to verify structured output is generated both on invoke and stream. This test is optional and should be skipped if the model does not support @@ -1233,28 +1247,14 @@ def has_tool_calling(self) -> bool: if not self.has_tool_calling: pytest.skip("Test requires tool calling.") - Joke = _get_joke_class() - - # Pydantic class - chat = model.with_structured_output(Joke, **self.structured_output_kwargs) + schema, validation_function = _get_joke_class(schema_type) # type: ignore[arg-type] + chat = model.with_structured_output(schema, **self.structured_output_kwargs) result = await chat.ainvoke("Tell me a joke about cats.") - assert isinstance(result, Joke) + validation_function(result) async for chunk in chat.astream("Tell me a joke about cats."): - assert isinstance(chunk, Joke) - - # Schema - chat = model.with_structured_output( - Joke.model_json_schema(), **self.structured_output_kwargs - ) - result = await chat.ainvoke("Tell me a joke about cats.") - assert isinstance(result, dict) - assert set(result.keys()) == {"setup", "punchline"} - - async for chunk in chat.astream("Tell me a joke about cats."): - assert isinstance(chunk, dict) - assert isinstance(chunk, dict) # for mypy - assert set(chunk.keys()) == {"setup", "punchline"} + validation_function(chunk) + assert chunk @pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Test requires pydantic 2.") def test_structured_output_pydantic_2_v1(self, model: BaseChatModel) -> None: