Skip to content

Commit

Permalink
'Refactored by Sourcery'
Browse files Browse the repository at this point in the history
  • Loading branch information
Sourcery AI committed Jun 27, 2023
1 parent afa68c9 commit f8c0296
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 41 deletions.
19 changes: 12 additions & 7 deletions polyfactory/collection_extender.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,17 @@ def _extend_type_args(type_args: tuple[Any, ...], number_of_args: int) -> tuple[

@classmethod
def _subclass_for_type(cls, annotation_alias: Any) -> type[CollectionExtender]:
for subclass in cls.__subclasses__():
if any(is_safe_subclass(annotation_alias, t) for t in subclass.__types__):
return subclass
return FallbackExtender
return next(
(
subclass
for subclass in cls.__subclasses__()
if any(
is_safe_subclass(annotation_alias, t)
for t in subclass.__types__
)
),
FallbackExtender,
)

@classmethod
def extend_type_args(
Expand Down Expand Up @@ -68,9 +75,7 @@ class DictExtender(CollectionExtender):

@staticmethod
def _extend_type_args(type_args: tuple[Any, ...], number_of_args: int) -> tuple[Any, ...]:
if not type_args:
return type_args
return type_args * number_of_args
return type_args if not type_args else type_args * number_of_args


class FallbackExtender(CollectionExtender):
Expand Down
15 changes: 3 additions & 12 deletions polyfactory/factories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def _create_pydantic_type_map(cls: type[BaseFactory[Any]]) -> dict[type, Callabl
except ImportError:
mapping = {}

try:
with suppress(ImportError):
# v1 only values - these will raise an exception in v2
# in pydantic v2 these are all aliases for Annotated with a constraint.
# we therefore do not need them in v2
Expand Down Expand Up @@ -153,17 +153,11 @@ def _create_pydantic_type_map(cls: type[BaseFactory[Any]]) -> dict[type, Callabl
}
)

except ImportError:
pass

try:
with suppress(ImportError):
# this might be removed by pydantic 2
from pydantic import color

mapping[color.Color] = cls.__faker__.hex_color # pyright: ignore
except ImportError:
pass

return mapping


Expand Down Expand Up @@ -331,10 +325,7 @@ def _handle_factory_field(cls, field_value: Any, field_build_parameters: Any | N
if isinstance(field_value, Fixture):
return field_value.to_value()

if callable(field_value):
return field_value()

return field_value
return field_value() if callable(field_value) else field_value

@classmethod
def _get_or_create_factory(
Expand Down
24 changes: 11 additions & 13 deletions polyfactory/factories/typed_dict_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,18 @@ def get_model_fields(cls) -> list["FieldMeta"]:
:returns: A list of field MetaData instances.
"""
fields_meta: list["FieldMeta"] = []

model_type_hints = get_type_hints(cls.__model__, include_extras=True)

for field_name, annotation in model_type_hints.items():
fields_meta.append(
FieldMeta.from_type(
annotation=annotation,
random=cls.__random__,
name=field_name,
default=getattr(cls.__model__, field_name, Null),
randomize_collection_length=cls.__randomize_collection_length__,
min_collection_length=cls.__min_collection_length__,
max_collection_length=cls.__max_collection_length__,
)
fields_meta: list["FieldMeta"] = [
FieldMeta.from_type(
annotation=annotation,
random=cls.__random__,
name=field_name,
default=getattr(cls.__model__, field_name, Null),
randomize_collection_length=cls.__randomize_collection_length__,
min_collection_length=cls.__min_collection_length__,
max_collection_length=cls.__max_collection_length__,
)
for field_name, annotation in model_type_hints.items()
]
return fields_meta
10 changes: 5 additions & 5 deletions polyfactory/field_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,15 +170,15 @@ def parse_constraints(cls, metadata: list[Any]) -> "Constraints": # pragma: no
constraints.update(cast("dict[str, Any]", cls.parse_constraints(metadata=inner_metadata)))
elif func := getattr(value, "func", None):
if func is str.islower:
constraints.update({"lower_case": True})
constraints["lower_case"] = True
elif func is str.isupper:
constraints.update({"upper_case": True})
constraints["upper_case"] = True
elif func is str.isascii:
constraints.update({"pattern": "[[:ascii:]]"})
constraints["pattern"] = "[[:ascii:]]"
elif func is str.isdigit:
constraints.update({"pattern": "[[:digit:]]"})
constraints["pattern"] = "[[:digit:]]"
elif is_dataclass(value) and (value_dict := asdict(value)) and ("allowed_schemes" in value_dict):
constraints.update({"url": {k: v for k, v in value_dict.items() if v is not None}})
constraints["url"] = {k: v for k, v in value_dict.items() if v is not None}
else:
constraints.update(
{
Expand Down
6 changes: 2 additions & 4 deletions polyfactory/utils/predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def is_dict_key_or_value_type(annotation: Any) -> "TypeGuard[Any]":
:returns: A typeguard.
"""
return str(annotation) == "~KT" or str(annotation) == "~VT"
return str(annotation) in {"~KT", "~VT"}


def is_union(annotation: Any) -> "TypeGuard[Any | Any]":
Expand Down Expand Up @@ -135,6 +135,4 @@ def get_type_origin(annotation: Any) -> Any:
origin = get_origin(annotation)
if origin in (Annotated, Required, NotRequired):
origin = get_args(annotation)[0]
if mapped_type := TYPE_MAPPING.get(origin): # pyright: ignore
return mapped_type
return origin
return mapped_type if (mapped_type := TYPE_MAPPING.get(origin)) else origin

0 comments on commit f8c0296

Please sign in to comment.