Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Adjust min/max items to valid lengths for Set[Enum] fields #567

Merged
merged 8 commits into from
Sep 13, 2024
11 changes: 10 additions & 1 deletion polyfactory/value_generators/constrained_collections.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, List, Mapping, TypeVar, cast
from enum import EnumMeta
from typing import TYPE_CHECKING, Any, Callable, List, Literal, Mapping, TypeVar, cast

from polyfactory.exceptions import ParameterException
from polyfactory.field_meta import FieldMeta
Expand Down Expand Up @@ -39,6 +40,14 @@ def handle_constrained_collection(
min_items = abs(min_items if min_items is not None else (max_items or 0))
max_items = abs(max_items if max_items is not None else min_items + 1)

if collection_type in (frozenset, set) or unique_items:
if hasattr(field_meta.annotation, "__origin__") and field_meta.annotation.__origin__ is Literal:
min_items = 1
max_items = 1
guacs marked this conversation as resolved.
Show resolved Hide resolved
elif isinstance(field_meta.annotation, EnumMeta):
min_items = min(min_items, len(field_meta.annotation))
max_items = len(field_meta.annotation)

if max_items < min_items:
msg = "max_items must be larger or equal to min_items"
raise ParameterException(msg)
Expand Down
47 changes: 46 additions & 1 deletion tests/test_collection_length.py
guacs marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import Any, Dict, List, Optional, Set, Tuple
from enum import Enum
from typing import Any, Dict, FrozenSet, List, Literal, Optional, Set, Tuple

import pytest

from pydantic import BaseModel
from pydantic.dataclasses import dataclass

from polyfactory.factories import DataclassFactory
from polyfactory.factories.pydantic_factory import ModelFactory

MIN_MAX_PARAMETERS = ((10, 15), (20, 25), (30, 40), (40, 50))

Expand Down Expand Up @@ -132,3 +135,45 @@ class FooFactory(DataclassFactory[Foo]):

assert len(foo.foo) >= min_val, len(foo.foo)
assert len(foo.foo) <= max_val, len(foo.foo)


@pytest.mark.parametrize("type_", (List, FrozenSet, Set))
def test_collection_length_with_literal(type_: type) -> None:
@dataclass
class MyModel:
animal_collection: type_[Literal["Dog", "Cat", "Monkey"]] # type: ignore

class MyFactory(DataclassFactory):
__model__ = MyModel
__randomize_collection_length__ = True
__min_collection_length__ = 4
__max_collection_length__ = 5

result = MyFactory.build()
if type_ is List:
assert len(result.animal_collection) >= MyFactory.__min_collection_length__
else:
assert len(result.animal_collection) == 1


@pytest.mark.parametrize("type_", (List, FrozenSet, Set))
def test_collection_length_with_enum(type_: type) -> None:
class Animal(str, Enum):
DOG = "Dog"
CAT = "Cat"
MONKEY = "Monkey"

class MyModel(BaseModel):
animal_collection: type_[Animal] # type: ignore

class MyFactory(ModelFactory):
__model__ = MyModel
__randomize_collection_length__ = True
__min_collection_length__ = len(Animal) + 1
__max_collection_length__ = len(Animal) + 2

result = MyFactory.build()
if type_ is List:
assert len(result.animal_collection) >= MyFactory.__min_collection_length__
else:
assert len(result.animal_collection) == len(Animal)
Loading