diff --git a/.gitignore b/.gitignore index cbfc4da..02da417 100644 --- a/.gitignore +++ b/.gitignore @@ -95,6 +95,9 @@ venv.bak/ # Rope project settings .ropeproject +# VSCode project settings +.vscode + # mkdocs documentation /site diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0aa3b11..b1f1ea6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,8 @@ repos: rev: v3.3.1 hooks: - id: pyupgrade - args: ["--py36-plus"] + # I've kept it on py3.7 so that it doesn't replace `Dict` with `dict` + args: ["--py37-plus"] - repo: https://github.com/python/black rev: 23.1.0 hooks: @@ -19,7 +20,7 @@ repos: rev: v1.1.1 hooks: - id: mypy - additional_dependencies: [marshmallow-enum,typeguard,marshmallow] + additional_dependencies: [typeguard,marshmallow] args: [--show-error-codes] - repo: https://github.com/asottile/blacken-docs rev: 1.13.0 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 2ec0fb4..7aad941 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -21,3 +21,4 @@ Every commit is checked with pre-commit hooks for : - type safety with [mypy](http://mypy-lang.org/) - test conformance by running [tests](./tests) with [pytest](https://docs.pytest.org/en/latest/) - You can run `pytest` from the command line. + \ No newline at end of file diff --git a/README.md b/README.md index d862f59..9fe6858 100644 --- a/README.md +++ b/README.md @@ -242,7 +242,47 @@ class Sample: See [marshmallow's documentation about extending `Schema`](https://marshmallow.readthedocs.io/en/stable/extending.html). -### Custom NewType declarations +### Custom type aliases + +This library allows you to specify [customized marshmallow fields](https://marshmallow.readthedocs.io/en/stable/custom_fields.html#creating-a-field-class) using python's Annoted type [PEP-593](https://peps.python.org/pep-0593/). + +```python +from typing import Annotated +import marshmallow.fields as mf +import marshmallow.validate as mv + +IPv4 = Annotated[str, mf.String(validate=mv.Regexp(r"^([0-9]{1,3}\\.){3}[0-9]{1,3}$"))] +``` + +You can also pass a marshmallow field class. + +```python +from typing import Annotated +import marshmallow +from marshmallow_dataclass import NewType + +Email = Annotated[str, marshmallow.fields.Email] +``` + +For convenience, some custom types are provided: + +```python +from marshmallow_dataclass.typing import Email, Url +``` + +When using Python 3.8, you must import `Annotated` from the typing_extensions package + +```python +# Version agnostic import code: +if sys.version_info >= (3, 9): + from typing import Annotated +else: + from typing_extensions import Annotated +``` + +### Custom NewType declarations [__deprecated__] + +> NewType is deprecated in favor or type aliases using Annotated, as described above. This library exports a `NewType` function to create types that generate [customized marshmallow fields](https://marshmallow.readthedocs.io/en/stable/custom_fields.html#creating-a-field-class). @@ -266,12 +306,6 @@ from marshmallow_dataclass import NewType Email = NewType("Email", str, field=marshmallow.fields.Email) ``` -For convenience, some custom types are provided: - -```python -from marshmallow_dataclass.typing import Email, Url -``` - Note: if you are using `mypy`, you will notice that `mypy` throws an error if a variable defined with `NewType` is used in a type annotation. To resolve this, add the `marshmallow_dataclass.mypy` plugin to your `mypy` configuration, e.g.: diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 1fe9813..f82b13a 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -34,6 +34,7 @@ class User: }) Schema: ClassVar[Type[Schema]] = Schema # For the type checker """ + import collections.abc import dataclasses import inspect @@ -47,11 +48,13 @@ class User: Any, Callable, Dict, + FrozenSet, Generic, List, Mapping, NewType as typing_NewType, Optional, + Sequence, Set, Tuple, Type, @@ -60,24 +63,23 @@ class User: cast, get_type_hints, overload, - Sequence, - FrozenSet, ) import marshmallow +import typing_extensions import typing_inspect from marshmallow_dataclass.lazy_class_attribute import lazy_class_attribute +if sys.version_info >= (3, 9): + from typing import Annotated +else: + from typing_extensions import Annotated if sys.version_info >= (3, 11): from typing import dataclass_transform -elif sys.version_info >= (3, 7): - from typing_extensions import dataclass_transform else: - # @dataclass_transform() only helps us with mypy>=1.1 which is only available for python>=3.7 - def dataclass_transform(**kwargs): - return lambda cls: cls + from typing_extensions import dataclass_transform __all__ = ["dataclass", "add_schema", "class_schema", "field_for_schema", "NewType"] @@ -511,7 +513,15 @@ def _internal_class_schema( base_schema: Optional[Type[marshmallow.Schema]] = None, ) -> Type[marshmallow.Schema]: schema_ctx = _schema_ctx_stack.top - schema_ctx.seen_classes[clazz] = clazz.__name__ + + if typing_extensions.get_origin(clazz) is Annotated and sys.version_info < (3, 10): + # https://github.com/python/cpython/blob/3.10/Lib/typing.py#L977 + class_name = clazz._name or clazz.__origin__.__name__ # type: ignore[attr-defined] + else: + class_name = clazz.__name__ + + schema_ctx.seen_classes[clazz] = class_name + try: # noinspection PyDataclass fields: Tuple[dataclasses.Field, ...] = dataclasses.fields(clazz) @@ -546,9 +556,18 @@ def _internal_class_schema( include_non_init = getattr(getattr(clazz, "Meta", None), "include_non_init", False) # Update the schema members to contain marshmallow fields instead of dataclass fields - type_hints = get_type_hints( - clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns - ) + + if sys.version_info >= (3, 9): + type_hints = get_type_hints( + clazz, + globalns=schema_ctx.globalns, + localns=schema_ctx.localns, + include_extras=True, + ) + else: + type_hints = get_type_hints( + clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns + ) attributes.update( ( field.name, @@ -639,8 +658,8 @@ def _field_for_generic_type( """ If the type is a generic interface, resolve the arguments and construct the appropriate Field. """ - origin = typing_inspect.get_origin(typ) - arguments = typing_inspect.get_args(typ, True) + origin = typing_extensions.get_origin(typ) + arguments = typing_extensions.get_args(typ) if origin: # Override base_schema.TYPE_MAPPING to change the class used for generic types below type_mapping = base_schema.TYPE_MAPPING if base_schema else {} @@ -694,6 +713,46 @@ def _field_for_generic_type( **metadata, ) + return None + + +def _field_for_annotated_type( + typ: type, + **metadata: Any, +) -> Optional[marshmallow.fields.Field]: + """ + If the type is an Annotated interface, resolve the arguments and construct the appropriate Field. + """ + origin = typing_extensions.get_origin(typ) + arguments = typing_extensions.get_args(typ) + if origin and origin is Annotated: + marshmallow_annotations = [ + arg + for arg in arguments[1:] + if (inspect.isclass(arg) and issubclass(arg, marshmallow.fields.Field)) + or isinstance(arg, marshmallow.fields.Field) + ] + if marshmallow_annotations: + if len(marshmallow_annotations) > 1: + warnings.warn( + "Multiple marshmallow Field annotations found. Using the last one." + ) + + field = marshmallow_annotations[-1] + # Got a field instance, return as is. User must know what they're doing + if isinstance(field, marshmallow.fields.Field): + return field + + return field(**metadata) + return None + + +def _field_for_union_type( + typ: type, + base_schema: Optional[Type[marshmallow.Schema]], + **metadata: Any, +) -> Optional[marshmallow.fields.Field]: + arguments = typing_extensions.get_args(typ) if typing_inspect.is_union_type(typ): if typing_inspect.is_optional_type(typ): metadata["allow_none"] = metadata.get("allow_none", True) @@ -806,6 +865,7 @@ def _field_for_schema( metadata.setdefault("allow_none", True) return marshmallow.fields.Raw(**metadata) + # i.e.: Literal['abc'] if typing_inspect.is_literal_type(typ): arguments = typing_inspect.get_args(typ) return marshmallow.fields.Raw( @@ -817,6 +877,7 @@ def _field_for_schema( **metadata, ) + # i.e.: Final[str] = 'abc' if typing_inspect.is_final_type(typ): arguments = typing_inspect.get_args(typ) if arguments: @@ -851,6 +912,14 @@ def _field_for_schema( subtyp = Any return _field_for_schema(subtyp, default, metadata, base_schema) + annotated_field = _field_for_annotated_type(typ, **metadata) + if annotated_field: + return annotated_field + + union_field = _field_for_union_type(typ, base_schema, **metadata) + if union_field: + return union_field + # Generic types generic_field = _field_for_generic_type(typ, base_schema, **metadata) if generic_field: @@ -869,14 +938,8 @@ def _field_for_schema( ) # enumerations - if issubclass(typ, Enum): - try: - return marshmallow.fields.Enum(typ, **metadata) - except AttributeError: - # Remove this once support for python 3.6 is dropped. - import marshmallow_enum - - return marshmallow_enum.EnumField(typ, **metadata) + if inspect.isclass(typ) and issubclass(typ, Enum): + return marshmallow.fields.Enum(typ, **metadata) # Nested marshmallow dataclass # it would be just a class name instead of actual schema util the schema is not ready yet @@ -939,7 +1002,8 @@ def NewType( field: Optional[Type[marshmallow.fields.Field]] = None, **kwargs, ) -> Callable[[_U], _U]: - """NewType creates simple unique types + """DEPRECATED: Use typing.Annotated instead. + NewType creates simple unique types to which you can attach custom marshmallow attributes. All the keyword arguments passed to this function will be transmitted to the marshmallow field constructor. diff --git a/marshmallow_dataclass/typing.py b/marshmallow_dataclass/typing.py index 01291eb..4db2f15 100644 --- a/marshmallow_dataclass/typing.py +++ b/marshmallow_dataclass/typing.py @@ -1,8 +1,14 @@ +import sys + import marshmallow.fields -from . import NewType -Url = NewType("Url", str, field=marshmallow.fields.Url) -Email = NewType("Email", str, field=marshmallow.fields.Email) +if sys.version_info >= (3, 9): + from typing import Annotated +else: + from typing_extensions import Annotated + +Url = Annotated[str, marshmallow.fields.Url] +Email = Annotated[str, marshmallow.fields.Email] # Aliases URL = Url diff --git a/pyproject.toml b/pyproject.toml index 4112bd1..49e683c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,3 +6,9 @@ target-version = ['py36', 'py37', 'py38', 'py39', 'py310', 'py310'] filterwarnings = [ "error:::marshmallow_dataclass|test", ] + +[tool.coverage.report] +exclude_also = [ + '^\s*\.\.\.\s*$', + '^\s*pass\s*$', +] \ No newline at end of file diff --git a/setup.py b/setup.py index 7cb2659..9ab3a24 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -from setuptools import setup, find_packages +from setuptools import find_packages, setup VERSION = "9.0.0" diff --git a/tests/test_annotated.py b/tests/test_annotated.py new file mode 100644 index 0000000..e9105a6 --- /dev/null +++ b/tests/test_annotated.py @@ -0,0 +1,37 @@ +import sys +import unittest +from typing import Optional + +import marshmallow +import marshmallow.fields + +from marshmallow_dataclass import dataclass + +if sys.version_info >= (3, 9): + from typing import Annotated +else: + from typing_extensions import Annotated + + +class TestAnnotatedField(unittest.TestCase): + def test_annotated_field(self): + @dataclass + class AnnotatedValue: + value: Annotated[str, marshmallow.fields.Email] + default_string: Annotated[ + Optional[str], marshmallow.fields.String(load_default="Default String") + ] = None + + schema = AnnotatedValue.Schema() + + self.assertEqual( + schema.load({"value": "test@test.com"}), + AnnotatedValue(value="test@test.com", default_string="Default String"), + ) + self.assertEqual( + schema.load({"value": "test@test.com", "default_string": "override"}), + AnnotatedValue(value="test@test.com", default_string="override"), + ) + + with self.assertRaises(marshmallow.exceptions.ValidationError): + schema.load({"value": "notavalidemail"})