diff --git a/polyfactory/value_generators/constrained_collections.py b/polyfactory/value_generators/constrained_collections.py index 65614e04..0cbaba10 100644 --- a/polyfactory/value_generators/constrained_collections.py +++ b/polyfactory/value_generators/constrained_collections.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, List, Mapping, TypeVar, cast +from enum import EnumMeta +from typing import TYPE_CHECKING, Any, Callable, List, Literal, Mapping, TypeVar, cast from polyfactory.exceptions import ParameterException from polyfactory.field_meta import FieldMeta @@ -43,6 +44,16 @@ def handle_constrained_collection( msg = "max_items must be larger or equal to min_items" raise ParameterException(msg) + if collection_type in (frozenset, set) or unique_items: + max_field_values = max_items + if hasattr(field_meta.annotation, "__origin__") and field_meta.annotation.__origin__ is Literal: + if field_meta.children is not None: + max_field_values = len(field_meta.children) + elif isinstance(field_meta.annotation, EnumMeta): + max_field_values = len(field_meta.annotation) + min_items = min(min_items, max_field_values) + max_items = min(max_items, max_field_values) + collection: set[T] | list[T] = set() if (collection_type in (frozenset, set) or unique_items) else [] try: diff --git a/tests/test_collection_length.py b/tests/test_collection_length.py index 9b19476b..8b9b3a58 100644 --- a/tests/test_collection_length.py +++ b/tests/test_collection_length.py @@ -1,10 +1,13 @@ -from typing import Any, Dict, List, Optional, Set, Tuple +from enum import Enum +from typing import Any, Dict, FrozenSet, List, Literal, Optional, Set, Tuple, get_args import pytest +from pydantic import BaseModel from pydantic.dataclasses import dataclass from polyfactory.factories import DataclassFactory +from polyfactory.factories.pydantic_factory import ModelFactory MIN_MAX_PARAMETERS = ((10, 15), (20, 25), (30, 40), (40, 50)) @@ -132,3 +135,50 @@ class FooFactory(DataclassFactory[Foo]): assert len(foo.foo) >= min_val, len(foo.foo) assert len(foo.foo) <= max_val, len(foo.foo) + + +@pytest.mark.parametrize("type_", (List, FrozenSet, Set)) +@pytest.mark.parametrize("min_items", (0, 2, 4)) +@pytest.mark.parametrize("max_inc", (0, 1, 4)) +def test_collection_length_with_literal(type_: type, min_items: int, max_inc: int) -> None: + max_items = min_items + max_inc + literal_type = Literal["Dog", "Cat", "Monkey"] + + @dataclass + class MyModel: + animal_collection: type_[literal_type] # type: ignore + + class MyFactory(DataclassFactory): + __model__ = MyModel + __randomize_collection_length__ = True + __min_collection_length__ = min_items + __max_collection_length__ = max_items + + result = MyFactory.build() + assert len(result.animal_collection) >= min(min_items, len(get_args(literal_type))) + assert len(result.animal_collection) <= max_items + + +@pytest.mark.parametrize("type_", (List, FrozenSet, Set)) +@pytest.mark.parametrize("min_items", (0, 2, 4)) +@pytest.mark.parametrize("max_inc", (0, 1, 4)) +def test_collection_length_with_enum(type_: type, min_items: int, max_inc: int) -> None: + max_items = min_items + max_inc + + class Animal(str, Enum): + DOG = "Dog" + CAT = "Cat" + MONKEY = "Monkey" + + class MyModel(BaseModel): + animal_collection: type_[Animal] # type: ignore + + class MyFactory(ModelFactory): + __model__ = MyModel + __randomize_collection_length__ = True + __min_collection_length__ = min_items + __max_collection_length__ = max_items + + result = MyFactory.build() + assert len(result.animal_collection) >= min(min_items, len(Animal)) + assert len(result.animal_collection) <= max_items