Skip to content

Commit

Permalink
Move Annotated and Union handling to their own functions
Browse files Browse the repository at this point in the history
  • Loading branch information
mvanderlee committed Jun 18, 2024
1 parent 7eb9281 commit fae55c3
Showing 1 changed file with 43 additions and 15 deletions.
58 changes: 43 additions & 15 deletions marshmallow_dataclass/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,21 +664,6 @@ def _field_for_generic_type(
# Override base_schema.TYPE_MAPPING to change the class used for generic types below
type_mapping = base_schema.TYPE_MAPPING if base_schema else {}

if origin is Annotated:
marshmallow_annotations = [
arg
for arg in arguments[1:]
if (inspect.isclass(arg) and issubclass(arg, marshmallow.fields.Field))
or isinstance(arg, marshmallow.fields.Field)
]
if marshmallow_annotations:
field = marshmallow_annotations[-1]
# Got a field instance, return as is. User must know what they're doing
if isinstance(field, marshmallow.fields.Field):
return field

return field(**metadata)

if origin in (list, List):
child_type = _field_for_schema(arguments[0], base_schema=base_schema)
list_type = cast(
Expand Down Expand Up @@ -728,6 +713,41 @@ def _field_for_generic_type(
**metadata,
)

return None


def _field_for_annotated_type(
typ: type,
**metadata: Any,
) -> Optional[marshmallow.fields.Field]:
"""
If the type is an Annotated interface, resolve the arguments and construct the appropriate Field.
"""
origin = typing_extensions.get_origin(typ)
arguments = typing_extensions.get_args(typ)
if origin and origin is Annotated:
marshmallow_annotations = [
arg
for arg in arguments[1:]
if (inspect.isclass(arg) and issubclass(arg, marshmallow.fields.Field))
or isinstance(arg, marshmallow.fields.Field)
]
if marshmallow_annotations:
field = marshmallow_annotations[-1]
# Got a field instance, return as is. User must know what they're doing
if isinstance(field, marshmallow.fields.Field):
return field

return field(**metadata)
return None


def _field_for_union_type(
typ: type,
base_schema: Optional[Type[marshmallow.Schema]],
**metadata: Any,
) -> Optional[marshmallow.fields.Field]:
arguments = typing_extensions.get_args(typ)
if typing_inspect.is_union_type(typ):
if typing_inspect.is_optional_type(typ):
metadata["allow_none"] = metadata.get("allow_none", True)
Expand Down Expand Up @@ -887,6 +907,14 @@ def _field_for_schema(
subtyp = Any
return _field_for_schema(subtyp, default, metadata, base_schema)

annotated_field = _field_for_annotated_type(typ, **metadata)
if annotated_field:
return annotated_field

union_field = _field_for_union_type(typ, base_schema, **metadata)
if union_field:
return union_field

# Generic types
generic_field = _field_for_generic_type(typ, base_schema, **metadata)
if generic_field:
Expand Down

0 comments on commit fae55c3

Please sign in to comment.