Skip to content

Commit

Permalink
Support nested generics with partial
Browse files Browse the repository at this point in the history
  • Loading branch information
mwildehahn committed Nov 22, 2024
1 parent 0c6de0e commit c9be3ba
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 21 deletions.
66 changes: 45 additions & 21 deletions instructor/dsl/partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

from jiter import from_json
from pydantic import BaseModel, create_model
from typing import Union
import types
import sys
from pydantic.fields import FieldInfo
from typing import (
Any,
Expand All @@ -29,6 +32,12 @@

T_Model = TypeVar("T_Model", bound=BaseModel)

if sys.version_info >= (3, 10):
# types.UnionType is only available in Python 3.10 and above
UNION_ORIGINS = (Union, types.UnionType)
else:
UNION_ORIGINS = (Union,)


class MakeFieldsOptional:
pass
Expand All @@ -38,6 +47,37 @@ class PartialLiteralMixin:
pass


def _process_generic_arg(
arg: Any,
make_fields_optional: bool = False,
) -> Any:
arg_origin = get_origin(arg)
if arg_origin is not None:
# Handle any nested generic type (Union, List, Dict, etc.)
nested_args = get_args(arg)
modified_nested_args = tuple(
_process_generic_arg(
t,
make_fields_optional=make_fields_optional,
)
for t in nested_args
)
# Special handling for Union types (types.UnionType isn't subscriptable)
if arg_origin in UNION_ORIGINS:
return Union[modified_nested_args] # type: ignore

return arg_origin[modified_nested_args]
else:
if isinstance(arg, type) and issubclass(arg, BaseModel):
return (
Partial[arg, MakeFieldsOptional] # type: ignore[valid-type]
if make_fields_optional
else Partial[arg]
)
else:
return arg


def _make_field_optional(
field: FieldInfo,
) -> tuple[Any, FieldInfo]:
Expand All @@ -51,31 +91,23 @@ def _make_field_optional(
generic_base = get_origin(annotation)
generic_args = get_args(annotation)

# Recursively apply Partial to each of the generic arguments
modified_args = tuple(
(
Partial[arg, MakeFieldsOptional] # type: ignore[valid-type]
if isinstance(arg, type) and issubclass(arg, BaseModel)
else arg
)
for arg in generic_args
_process_generic_arg(arg, make_fields_optional=True) for arg in generic_args
)

# Reconstruct the generic type with modified arguments
tmp_field.annotation = (
Optional[generic_base[modified_args]] if generic_base else None
)
tmp_field.annotation = generic_base[modified_args] if generic_base else None
tmp_field.default = None
# If the field is a BaseModel, then recursively convert it's
# attributes to optionals.
elif isinstance(annotation, type) and issubclass(annotation, BaseModel):
tmp_field.annotation = Optional[Partial[annotation, MakeFieldsOptional]] # type: ignore[assignment, valid-type]
tmp_field.default = {}
else:
tmp_field.annotation = Optional[field.annotation] # type: ignore[assignment]
tmp_field.annotation = Optional[field.annotation]

Check failure on line 107 in instructor/dsl/partial.py

View workflow job for this annotation

GitHub Actions / Pyright (ubuntu-latest, 3.11)

Expected class but received "type[Any] | None"   "None" is not a class (reportGeneralTypeIssues)
tmp_field.default = None

return tmp_field.annotation, tmp_field # type: ignore
return tmp_field.annotation, tmp_field

Check failure on line 110 in instructor/dsl/partial.py

View workflow job for this annotation

GitHub Actions / Pyright (ubuntu-latest, 3.11)

Type of "annotation" is partially unknown   Type of "annotation" is "type[Any] | type[Partial[BaseModel]] | type[None] | Unknown | None" (reportUnknownMemberType)

Check failure on line 110 in instructor/dsl/partial.py

View workflow job for this annotation

GitHub Actions / Pyright (ubuntu-latest, 3.11)

Return type, "tuple[type[Any] | type[Partial[BaseModel]] | type[None] | Unknown | None, FieldInfo]", is partially unknown (reportUnknownVariableType)


class PartialBase(Generic[T_Model]):
Expand Down Expand Up @@ -360,15 +392,7 @@ def _wrap_models(field: FieldInfo) -> tuple[object, FieldInfo]:
generic_base = get_origin(annotation)
generic_args = get_args(annotation)

# Recursively apply Partial to each of the generic arguments
modified_args = tuple(
(
Partial[arg]
if isinstance(arg, type) and issubclass(arg, BaseModel)
else arg
)
for arg in generic_args
)
modified_args = tuple(_process_generic_arg(arg) for arg in generic_args)

# Reconstruct the generic type with modified arguments
tmp_field.annotation = (
Expand Down
25 changes: 25 additions & 0 deletions tests/dsl/test_partial.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# type: ignore[all]
from pydantic import BaseModel, Field
from typing import Optional, Union
from instructor.dsl.partial import Partial, PartialLiteralMixin
import pytest
import instructor
Expand All @@ -20,6 +21,23 @@ class SamplePartial(BaseModel):
b: SampleNestedPartial


class NestedA(BaseModel):
a: str
b: Optional[str]


class NestedB(BaseModel):
c: str
d: str
e: list[Union[str, int]]


class UnionWithNested(BaseModel):
a: list[Union[NestedA, NestedB]]
b: list[NestedA]
c: NestedB


def test_partial():
partial = Partial[SamplePartial]
assert partial.model_json_schema() == {
Expand Down Expand Up @@ -166,3 +184,10 @@ class Summary(BaseModel, PartialLiteralMixin):
previous_summary = extraction.summary

assert updates == 1


def test_union_with_nested():
partial = Partial[UnionWithNested]
partial.get_partial_model().model_validate_json(
'{"a": [{"b": "b"}, {"d": "d"}], "b": [{"b": "b"}], "c": {"d": "d"}, "e": [1, "a"]}'
)

0 comments on commit c9be3ba

Please sign in to comment.