Skip to content

Commit

Permalink
⚗️ experiment(converters): develop private API for minimal dataclass …
Browse files Browse the repository at this point in the history
…with a converter
  • Loading branch information
nstarman committed Jan 6, 2025
1 parent feebd0e commit aefa264
Show file tree
Hide file tree
Showing 7 changed files with 311 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ repos:
rev: "v1.13.0"
hooks:
- id: mypy
files: src|tests
files: src
args: []
additional_dependencies:
- pytest
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ ignore = [

[tool.ruff.lint.isort]
combine-as-imports = true
extra-standard-library = ["typing_extensions"]


[tool.pylint]
Expand Down
177 changes: 174 additions & 3 deletions src/dataclassish/_src/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,31 @@
# ruff:noqa: N801
# pylint: disable=C0103

__all__ = ["AbstractConverter", "Optional", "Unless"]
__all__ = [
# Converters
"AbstractConverter",
"Optional",
"Unless",
# Minimal dataclass implementation
"dataclass",
"field",
]

import dataclasses
import functools
import inspect
from abc import ABCMeta, abstractmethod
from collections.abc import Callable
from typing import Any, Generic, TypeVar, cast, overload
from collections.abc import Callable, Hashable, Mapping
from typing import (
Any,
ClassVar,
Generic,
Protocol,
TypeVar,
cast,
overload,
)
from typing_extensions import dataclass_transform

ArgT = TypeVar("ArgT") # Input type
RetT = TypeVar("RetT") # Return type
Expand Down Expand Up @@ -139,3 +158,155 @@ def __call__(self, value: ArgT | PassThroughTs, /) -> RetT | PassThroughTs:
if isinstance(value, self.unconverted_types)
else self.converter(cast(ArgT, value))
)


#####################################################################
# Minimal implementation of a dataclass supporting converters.

_CT = TypeVar("_CT")


def field(
*,
converter: Callable[[Any], Any] | None = None,
metadata: Mapping[Hashable, Any] | None = None,
**kwargs: Any,
) -> Any:
"""Dataclass field with a converter argument.
Parameters
----------
converter : callable, optional
A callable that converts the value of the field. This is added to the
metadata of the field.
metadata : Mapping[Hashable, Any], optional
Additional metadata to add to the field.
See `dataclasses.field` for more information.
**kwargs : Any
Additional keyword arguments to pass to `dataclasses.field`.
"""
if converter is not None:
# Check the converter
if not callable(converter):
msg = f"converter must be callable, got {converter!r}" # type: ignore[unreachable]
raise TypeError(msg)

# Convert the metadata to a mutable dict if it is not None.
metadata = dict(metadata) if metadata is not None else {}

if "converter" in metadata:
msg = "cannot specify 'converter' in metadata and as a keyword argument."
raise ValueError(msg)

# Add the converter to the metadata
metadata["converter"] = converter

return dataclasses.field(metadata=metadata, **kwargs)


class DataclassInstance(Protocol):
    """Structural type for objects whose class defines ``__dataclass_fields__``."""

    # Class-level mapping from str keys to `dataclasses.Field` objects.
    __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]]


def _process_dataclass(cls: type[_CT], **kwargs: Any) -> type[_CT]:
    """Make ``cls`` a dataclass whose ``__init__`` applies field converters.

    The class is first processed by `dataclasses.dataclass`. Its ``__init__``
    is then wrapped so that, for every field whose metadata has a non-`None`
    ``"converter"`` entry, the bound argument (or the field's default) is
    passed through that converter before the original ``__init__`` runs.

    Parameters
    ----------
    cls : type
        The class to transform into a dataclass.
    **kwargs : Any
        Keyword arguments forwarded to `dataclasses.dataclass`.

    Returns
    -------
    type
        The class returned by `dataclasses.dataclass`. Its ``__init__``
        accepts an extra keyword-only argument ``_skip_convert``; when true
        the arguments are handed to the original ``__init__`` unconverted.
        If the class has no ``__init__`` of its own (``init=False`` without a
        hand-written ``__init__``) it is returned unwrapped.

    Notes
    -----
    The ``self``-less signature of the original ``__init__`` is stored on it
    as ``_obj_signature_`` (and copied to the wrapper by `functools.wraps`).

    """
    # Make the dataclass from the class.
    # This does all the usual dataclass stuff.
    dcls: type[_CT] = dataclasses.dataclass(cls, **kwargs)

    # With ``init=False`` and no hand-written ``__init__`` the method is
    # inherited: either ``object.__init__`` (a slot wrapper that cannot carry
    # attributes) or a parent's ``__init__``. There is nothing of this class
    # to wrap, so leave it alone.
    if "__init__" not in vars(dcls):
        return dcls

    # Capture the original __init__ in the closure. Looking it up through
    # ``self.__init__`` at call time breaks as soon as a subclass overrides
    # ``__init__`` and reaches this wrapper via ``super().__init__(...)``.
    original_init = dcls.__init__

    # Compute the signature of the __init__ method, without 'self'.
    sig = inspect.signature(original_init)
    sig = sig.replace(parameters=list(sig.parameters.values())[1:])
    # Store the signature on the __init__ method (Not assigning to __signature__
    # because that should have `self`).
    original_init._obj_signature_ = sig  # type: ignore[attr-defined]

    # (name, converter, default_factory) of every field that has a converter.
    # Computed once here instead of on every instantiation.
    convertible = tuple(
        (f.name, f.metadata["converter"], f.default_factory)
        for f in dataclasses.fields(cast(type[DataclassInstance], dcls))
        if f.metadata.get("converter") is not None
    )

    # Ensure that the __init__ method does conversion
    @functools.wraps(original_init)  # give it the same signature
    def init(
        self: DataclassInstance, *args: Any, _skip_convert: bool = False, **kwds: Any
    ) -> None:
        # Fast path: no conversion
        if _skip_convert:
            original_init(self, *args, **kwds)
            return

        # Bind the arguments to the signature, remembering which ones the
        # caller actually supplied before the defaults are filled in.
        ba = sig.bind_partial(*args, **kwds)
        provided = frozenset(ba.arguments)
        ba.apply_defaults()  # so defaults are eligible for conversion

        for name, converter, default_factory in convertible:
            if name not in ba.arguments:  # mandatory field not provided?!
                continue  # defer the error to the dataclass __init__

            if name not in provided and default_factory is not dataclasses.MISSING:
                # The signature default of a factory-backed field is only a
                # placeholder sentinel; convert the factory's value instead.
                value = default_factory()
            else:
                value = ba.arguments[name]
            ba.arguments[name] = converter(value)

        # Call the original dataclass __init__ method
        original_init(self, *ba.args, **ba.kwargs)

    dcls.__init__ = init  # type: ignore[assignment, method-assign]

    return dcls


@overload
def dataclass(cls: type[_CT], /, **kwargs: Any) -> type[_CT]: ...


@overload
def dataclass(**kwargs: Any) -> Callable[[type[_CT]], type[_CT]]: ...


@dataclass_transform(field_specifiers=(dataclasses.Field, dataclasses.field, field))
def dataclass(
    cls: type[_CT] | None = None, /, **kwargs: Any
) -> type[_CT] | Callable[[type[_CT]], type[_CT]]:
    """Create a dataclass whose fields may declare converters.

    Usable both bare (``@dataclass``) and with arguments
    (``@dataclass(frozen=True)``). See the `dataclasses` module for general
    information about dataclasses.

    Parameters
    ----------
    cls : type | None, optional
        The class to transform into a dataclass. If `None`, a decorator
        (a partial of the processing function) is returned instead.
    **kwargs : Any
        Additional keyword arguments to pass to `dataclasses.dataclass`.

    Returns
    -------
    type | Callable[[type], type]
        The processed class when ``cls`` is given, otherwise a decorator that
        processes the class it receives.

    Examples
    --------
    >>> from dataclassish.converters import Optional
    >>> from dataclassish._src.converters import dataclass, field

    >>> @dataclass
    ... class MyClass:
    ...     attr: int | None = field(default=2.0, converter=Optional(int))

    The converter is applied to the default value:

    >>> MyClass().attr
    2

    The converter is applied to the input value:

    >>> MyClass(None).attr is None
    True
    >>> MyClass(1).attr
    1

    And will work for any input value that the converter can handle, e.g.
    ``int(str)``:

    >>> MyClass("3").attr
    3

    """
    # Bind the options once; the result is either returned as the decorator
    # or applied to the class straight away.
    decorate = functools.partial(_process_dataclass, **kwargs)
    return decorate if cls is None else decorate(cls)
2 changes: 1 addition & 1 deletion src/dataclassish/_src/flag_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from collections.abc import Callable, Iterable
from dataclasses import Field
from typing import Any, cast
from typing_extensions import Never

from plum import dispatch
from typing_extensions import Never

from .api import (
asdict,
Expand Down
1 change: 0 additions & 1 deletion src/dataclassish/_src/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from dataclasses import dataclass
from typing import Any, ClassVar, Generic, Protocol, TypeVar, runtime_checkable

from typing_extensions import Self


Expand Down
1 change: 1 addition & 0 deletions src/dataclassish/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
__all__ = ["AbstractConverter", "Optional", "Unless"]

from ._src.converters import AbstractConverter, Optional, Unless
# TODO: make dataclass & field public
133 changes: 133 additions & 0 deletions tests/test_converters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
"""Test `dataclass.converters`."""

import pytest

import dataclassish
from dataclassish._src.converters import dataclass, field


def test_abstractconverter_is_abstract():
    """`AbstractConverter` cannot be instantiated directly."""
    abstract_cls = dataclassish.converters.AbstractConverter

    # Instantiating the abstract base class must fail.
    with pytest.raises(TypeError):
        abstract_cls()


def test_optional_object():
    """`Optional(int)` passes `None` through and converts everything else."""
    to_int = dataclassish.converters.Optional(int)

    # `None` is returned untouched.
    assert to_int(None) is None

    # Any other input goes through `int`.
    for raw in (1, 1.0, "1"):
        assert to_int(raw) == 1


def test_unless_object():
    """`Unless(int, float)` leaves ints alone and converts the rest to float."""
    to_float_unless_int = dataclassish.converters.Unless(int, float)

    # Integers pass through
    assert isinstance(to_float_unless_int(1), int)
    assert to_float_unless_int(1) == 1

    # Everything else is converted to float
    assert isinstance(to_float_unless_int(1.0), float)
    for raw in (1.0, "1", "1.0"):
        assert to_float_unless_int(raw) == 1.0


def test_field_not_public():
    """The public `dataclassish.converters` module does not expose `field`."""
    missing = object()
    found = getattr(dataclassish.converters, "field", missing)
    assert found is missing


def test_field():
    """Test `field`: metadata storage, validation and pass-through cases."""
    converter = dataclassish.converters.Optional(int)

    # Normal usage: the converter ends up in the field's metadata.
    f = field(converter=converter)
    assert f.metadata["converter"] is converter
    assert f.metadata["converter"](None) is None
    assert f.metadata["converter"](1.0) == 1

    # Non-callable converter
    with pytest.raises(TypeError, match="converter must be callable"):
        _ = field(converter=1)

    # converter also in metadata. `match` is a regular expression, so the
    # trailing period is escaped to match it literally rather than as "any
    # character".
    with pytest.raises(
        ValueError,
        match=r"cannot specify 'converter' in metadata and as a keyword argument\.",
    ):
        _ = field(converter=converter, metadata={"converter": 1})

    # Converter is None
    f = field(converter=None)
    assert "converter" not in f.metadata

    # Converter is None, metadata is not None: the metadata entry is kept.
    f = field(converter=None, metadata={"converter": converter})
    assert f.metadata["converter"] is converter
    assert f.metadata["converter"](None) is None
    assert f.metadata["converter"](1.0) == 1


def test_dataclass_with_converter():
    """A converter field converts both its default and explicit inputs."""

    @dataclass
    class MyClass:
        attr: int | None = field(
            default=2.0, converter=dataclassish.converters.Optional(int)
        )

    # The default (2.0) goes through the converter.
    assert MyClass().attr == 2

    # `None` is passed through by `Optional`.
    assert MyClass(None).attr is None

    # Explicit inputs are converted with `int`.
    for raw, expected in ((1, 1), ("3", 3)):
        assert MyClass(raw).attr == expected


def test_dataclass_skip_convert():
    """`_skip_convert=True` stores the argument without converting it."""

    @dataclass
    class MyClass:
        attr: int | None = field(
            default=2.0, converter=dataclassish.converters.Optional(int)
        )

    # The string would become 3 with conversion; here it must stay a string.
    unconverted = MyClass("3", _skip_convert=True)
    assert unconverted.attr == "3"


def test_dataclass_with_field_descriptor():
    """A converter field works next to a field assigned a descriptor."""

    class Descriptor:
        def __get__(self, instance, owner):
            return 42

    @dataclass
    class MyClass:
        attr: int = field(default=0, converter=int)
        descriptor: int = Descriptor()

    # First with the defaults, then with an explicit value needing conversion.
    for init_kwargs, expected_attr in (({}, 0), ({"attr": 1.0}, 1)):
        obj = MyClass(**init_kwargs)
        assert obj.attr == expected_attr
        assert obj.descriptor == 42

0 comments on commit aefa264

Please sign in to comment.