From c9be3ba5f2ae5806332136e0467550f26d32aafc Mon Sep 17 00:00:00 2001 From: Michael Hahn Date: Fri, 22 Nov 2024 10:28:28 -0800 Subject: [PATCH] Support nested generics with partial --- instructor/dsl/partial.py | 66 ++++++++++++++++++++++++++------------- tests/dsl/test_partial.py | 25 +++++++++++++++ 2 files changed, 70 insertions(+), 21 deletions(-) diff --git a/instructor/dsl/partial.py b/instructor/dsl/partial.py index d1acc13c3..268e5cc0d 100644 --- a/instructor/dsl/partial.py +++ b/instructor/dsl/partial.py @@ -10,6 +10,9 @@ from jiter import from_json from pydantic import BaseModel, create_model +from typing import Union +import types +import sys from pydantic.fields import FieldInfo from typing import ( Any, @@ -29,6 +32,12 @@ T_Model = TypeVar("T_Model", bound=BaseModel) +if sys.version_info >= (3, 10): + # types.UnionType is only available in Python 3.10 and above + UNION_ORIGINS = (Union, types.UnionType) +else: + UNION_ORIGINS = (Union,) + class MakeFieldsOptional: pass @@ -38,6 +47,37 @@ class PartialLiteralMixin: pass +def _process_generic_arg( + arg: Any, + make_fields_optional: bool = False, +) -> Any: + arg_origin = get_origin(arg) + if arg_origin is not None: + # Handle any nested generic type (Union, List, Dict, etc.) + nested_args = get_args(arg) + modified_nested_args = tuple( + _process_generic_arg( + t, + make_fields_optional=make_fields_optional, + ) + for t in nested_args + ) + # Special handling for Union types (types.UnionType isn't subscriptable) + if arg_origin in UNION_ORIGINS: + return Union[modified_nested_args] # type: ignore + + return arg_origin[modified_nested_args] + else: + if isinstance(arg, type) and issubclass(arg, BaseModel): + return ( + Partial[arg, MakeFieldsOptional] # type: ignore[valid-type] + if make_fields_optional + else Partial[arg] + ) + else: + return arg + + def _make_field_optional( field: FieldInfo, ) -> tuple[Any, FieldInfo]: @@ -51,20 +91,12 @@ def _make_field_optional( generic_base = get_origin(annotation) generic_args = get_args(annotation) - # Recursively apply Partial to each of the generic arguments modified_args = tuple( - ( - Partial[arg, MakeFieldsOptional] # type: ignore[valid-type] - if isinstance(arg, type) and issubclass(arg, BaseModel) - else arg - ) - for arg in generic_args + _process_generic_arg(arg, make_fields_optional=True) for arg in generic_args ) # Reconstruct the generic type with modified arguments - tmp_field.annotation = ( - Optional[generic_base[modified_args]] if generic_base else None - ) + tmp_field.annotation = generic_base[modified_args] if generic_base else None tmp_field.default = None # If the field is a BaseModel, then recursively convert it's # attributes to optionals. @@ -72,10 +104,10 @@ def _make_field_optional( tmp_field.annotation = Optional[Partial[annotation, MakeFieldsOptional]] # type: ignore[assignment, valid-type] tmp_field.default = {} else: - tmp_field.annotation = Optional[field.annotation] # type: ignore[assignment] + tmp_field.annotation = Optional[field.annotation] tmp_field.default = None - return tmp_field.annotation, tmp_field # type: ignore + return tmp_field.annotation, tmp_field class PartialBase(Generic[T_Model]): @@ -360,15 +392,7 @@ def _wrap_models(field: FieldInfo) -> tuple[object, FieldInfo]: generic_base = get_origin(annotation) generic_args = get_args(annotation) - # Recursively apply Partial to each of the generic arguments - modified_args = tuple( - ( - Partial[arg] - if isinstance(arg, type) and issubclass(arg, BaseModel) - else arg - ) - for arg in generic_args - ) + modified_args = tuple(_process_generic_arg(arg) for arg in generic_args) # Reconstruct the generic type with modified arguments tmp_field.annotation = ( diff --git a/tests/dsl/test_partial.py b/tests/dsl/test_partial.py index 197e1a62c..96ba0fe69 100644 --- a/tests/dsl/test_partial.py +++ b/tests/dsl/test_partial.py @@ -1,5 +1,6 @@ # type: ignore[all] from pydantic import BaseModel, Field +from typing import Optional, Union from instructor.dsl.partial import Partial, PartialLiteralMixin import pytest import instructor @@ -20,6 +21,23 @@ class SamplePartial(BaseModel): b: SampleNestedPartial +class NestedA(BaseModel): + a: str + b: Optional[str] + + +class NestedB(BaseModel): + c: str + d: str + e: list[Union[str, int]] + + +class UnionWithNested(BaseModel): + a: list[Union[NestedA, NestedB]] + b: list[NestedA] + c: NestedB + + def test_partial(): partial = Partial[SamplePartial] assert partial.model_json_schema() == { @@ -166,3 +184,10 @@ class Summary(BaseModel, PartialLiteralMixin): previous_summary = extraction.summary assert updates == 1 + + +def test_union_with_nested(): + partial = Partial[UnionWithNested] + partial.get_partial_model().model_validate_json( + '{"a": [{"b": "b"}, {"d": "d"}], "b": [{"b": "b"}], "c": {"d": "d"}, "e": [1, "a"]}' + )