Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Adjust min/max items to valid lengths for Set[Enum] fields #567

Merged
merged 8 commits into from
Sep 13, 2024
6 changes: 6 additions & 0 deletions polyfactory/value_generators/constrained_collections.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from enum import EnumMeta
from typing import TYPE_CHECKING, Any, Callable, List, Mapping, TypeVar, cast

from polyfactory.exceptions import ParameterException
Expand Down Expand Up @@ -39,6 +40,11 @@ def handle_constrained_collection(
min_items = abs(min_items if min_items is not None else (max_items or 0))
max_items = abs(max_items if max_items is not None else min_items + 1)

if isinstance(field_meta.annotation, EnumMeta):
adrianeboyd marked this conversation as resolved.
Show resolved Hide resolved
adrianeboyd marked this conversation as resolved.
Show resolved Hide resolved
max_items = len(field_meta.annotation)
if min_items > max_items:
min_items = max_items
adrianeboyd marked this conversation as resolved.
Show resolved Hide resolved

if max_items < min_items:
msg = "max_items must be larger or equal to min_items"
raise ParameterException(msg)
Expand Down
19 changes: 19 additions & 0 deletions tests/test_complex_types.py
adrianeboyd marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,25 @@ class MyFactory(ModelFactory):
assert result.animal_list


def test_complex_typing_with_enum_set() -> None:
class Animal(str, Enum):
DOG = "Dog"
CAT = "Cat"
MONKEY = "Monkey"

class MyModel(BaseModel):
animal_list: Set[Animal]

class MyFactory(ModelFactory):
__model__ = MyModel
__randomize_collection_length__ = True
__min_collection_length__ = len(Animal) + 1
__min_collection_length__ = len(Animal) + 2

result = MyFactory.build()
assert len(result.animal_list) == len(Animal)


def test_union_literal() -> None:
class MyModel(BaseModel):
x: Union[int, Literal["a", "b", "c"]]
Expand Down
Loading