From 07c9c18d8aad7a0f4c0da3e2611867d81eae79e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Mon, 30 Dec 2024 22:20:27 +0100 Subject: [PATCH 01/22] Add CONTEXT context variable --- src/marshmallow/__init__.py | 7 ++++- src/marshmallow/schema.py | 54 ++++++++++++++++++++++++++----------- tests/test_schema.py | 21 +++++++++++++++ 3 files changed, 65 insertions(+), 17 deletions(-) diff --git a/src/marshmallow/__init__.py b/src/marshmallow/__init__.py index ce666e1ed..b4ff0bdbf 100644 --- a/src/marshmallow/__init__.py +++ b/src/marshmallow/__init__.py @@ -14,7 +14,11 @@ validates_schema, ) from marshmallow.exceptions import ValidationError -from marshmallow.schema import Schema, SchemaOpts +from marshmallow.schema import ( + CONTEXT, + Schema, + SchemaOpts, +) from marshmallow.utils import EXCLUDE, INCLUDE, RAISE, missing from . import fields @@ -63,6 +67,7 @@ def __getattr__(name: str) -> typing.Any: __all__ = [ + "CONTEXT", "EXCLUDE", "INCLUDE", "RAISE", diff --git a/src/marshmallow/schema.py b/src/marshmallow/schema.py index aea9ecb5a..83bdc7b6b 100644 --- a/src/marshmallow/schema.py +++ b/src/marshmallow/schema.py @@ -2,6 +2,8 @@ from __future__ import annotations +import contextlib +import contextvars import copy import datetime as dt import decimal @@ -38,6 +40,19 @@ validate_unknown_parameter_value, ) +CONTEXT = contextvars.ContextVar("context", default=None) + + +class Context(contextlib.AbstractContextManager): + def __init__(self, context): + self.context = context + + def __enter__(self): + self.token = CONTEXT.set(self.context) + + def __exit__(self, *args, **kwargs): + CONTEXT.reset(self.token) + def _get_fields(attrs): """Get fields from a class @@ -501,13 +516,16 @@ def _serialize(self, obj: typing.Any, *, many: bool = False): ret[key] = value return ret - def dump(self, obj: typing.Any, *, many: bool | None = None): + def dump( + self, obj: typing.Any, *, many: bool | None = None, context: typing.Any = None + ): """Serialize an object to native Python data types according to this Schema's fields. :param obj: The object to serialize. :param many: Whether to serialize `obj` as a collection. If `None`, the value for `self.many` is used. + :param context: Optional context used when serializing. :return: Serialized data .. versionadded:: 1.0.0 @@ -518,20 +536,21 @@ def dump(self, obj: typing.Any, *, many: bool | None = None): .. versionchanged:: 3.0.0rc9 Validation no longer occurs upon serialization. """ - many = self.many if many is None else bool(many) - if self._hooks[PRE_DUMP]: - processed_obj = self._invoke_dump_processors( - PRE_DUMP, obj, many=many, original_data=obj - ) - else: - processed_obj = obj + with Context(context) if context is not None else contextlib.nullcontext(): + many = self.many if many is None else bool(many) + if self._hooks[PRE_DUMP]: + processed_obj = self._invoke_dump_processors( + PRE_DUMP, obj, many=many, original_data=obj + ) + else: + processed_obj = obj - result = self._serialize(processed_obj, many=many) + result = self._serialize(processed_obj, many=many) - if self._hooks[POST_DUMP]: - result = self._invoke_dump_processors( - POST_DUMP, result, many=many, original_data=obj - ) + if self._hooks[POST_DUMP]: + result = self._invoke_dump_processors( + POST_DUMP, result, many=many, original_data=obj + ) return result @@ -676,6 +695,7 @@ def load( many: bool | None = None, partial: bool | types.StrSequenceOrSet | None = None, unknown: str | None = None, + context: typing.Any = None, ): """Deserialize a data structure to an object defined by this Schema's fields. @@ -689,6 +709,7 @@ def load( :param unknown: Whether to exclude, include, or raise an error for unknown fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`. If `None`, the value for `self.unknown` is used. + :param context: Optional context used when deserializing. :return: Deserialized data .. versionadded:: 1.0.0 @@ -697,9 +718,10 @@ def load( A :exc:`ValidationError ` is raised if invalid data are passed. """ - return self._do_load( - data, many=many, partial=partial, unknown=unknown, postprocess=True - ) + with Context(context) if context is not None else contextlib.nullcontext(): + return self._do_load( + data, many=many, partial=partial, unknown=unknown, postprocess=True + ) def loads( self, diff --git a/tests/test_schema.py b/tests/test_schema.py index 321d00a55..604be3c89 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -7,6 +7,7 @@ import simplejson as json from marshmallow import ( + CONTEXT, EXCLUDE, INCLUDE, RAISE, @@ -2165,6 +2166,26 @@ def get_is_owner(self, user): class TestContext: + def test_context_load_dump(self): + class ContextField(fields.Integer): + def _serialize(self, value, attr, obj, **kwargs): + value *= CONTEXT.get({}).get("factor", 1) + return super()._serialize(value, attr, obj, **kwargs) + + def _deserialize(self, value, attr, data, **kwargs): + val = super()._deserialize(value, attr, data, **kwargs) + return val * CONTEXT.get({}).get("factor", 1) + + class ContextSchema(Schema): + ctx_fld = ContextField() + + ctx_schema = ContextSchema() + + assert ctx_schema.load({"ctx_fld": 1}) == {"ctx_fld": 1} + assert ctx_schema.load({"ctx_fld": 1}, context={"factor": 2}) == {"ctx_fld": 2} + assert ctx_schema.dump({"ctx_fld": 1}) == {"ctx_fld": 1} + assert ctx_schema.dump({"ctx_fld": 1}, context={"factor": 2}) == {"ctx_fld": 2} + def test_context_method(self): owner = User("Joe") blog = Blog(title="Joe Blog", user=owner) From 4f4e0488daee3f2a263ce3b7951f5c10c3da1c94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Mon, 30 Dec 2024 22:47:32 +0100 Subject: [PATCH 02/22] Expose Context context manager --- src/marshmallow/__init__.py | 2 ++ tests/test_schema.py | 10 ++++++++++ 2 files changed, 12 insertions(+) diff --git a/src/marshmallow/__init__.py b/src/marshmallow/__init__.py index b4ff0bdbf..f5ea026a2 100644 --- a/src/marshmallow/__init__.py +++ b/src/marshmallow/__init__.py @@ -16,6 +16,7 @@ from marshmallow.exceptions import ValidationError from marshmallow.schema import ( CONTEXT, + Context, Schema, SchemaOpts, ) @@ -71,6 +72,7 @@ def __getattr__(name: str) -> typing.Any: "EXCLUDE", "INCLUDE", "RAISE", + "Context", "Schema", "SchemaOpts", "fields", diff --git a/tests/test_schema.py b/tests/test_schema.py index 604be3c89..d47990f07 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -11,6 +11,7 @@ EXCLUDE, INCLUDE, RAISE, + Context, Schema, class_registry, fields, @@ -2185,6 +2186,15 @@ class ContextSchema(Schema): assert ctx_schema.load({"ctx_fld": 1}, context={"factor": 2}) == {"ctx_fld": 2} assert ctx_schema.dump({"ctx_fld": 1}) == {"ctx_fld": 1} assert ctx_schema.dump({"ctx_fld": 1}, context={"factor": 2}) == {"ctx_fld": 2} + with Context({"factor": 3}): + assert ctx_schema.load({"ctx_fld": 1}) == {"ctx_fld": 3} + assert ctx_schema.load({"ctx_fld": 1}, context={"factor": 4}) == { + "ctx_fld": 4 + } + assert ctx_schema.dump({"ctx_fld": 1}) == {"ctx_fld": 3} + assert ctx_schema.dump({"ctx_fld": 1}, context={"factor": 4}) == { + "ctx_fld": 4 + } def test_context_method(self): owner = User("Joe") From a5003b81e27a0594a68b02cb33355c85f51cd3cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Mon, 30 Dec 2024 23:21:30 +0100 Subject: [PATCH 03/22] Remove Schema.context and Field.context --- src/marshmallow/fields.py | 29 ++-------- src/marshmallow/schema.py | 4 -- tests/test_deserialization.py | 18 ++++-- tests/test_schema.py | 101 +++++++++++++--------------------- tests/test_serialization.py | 9 +-- 5 files changed, 60 insertions(+), 101 deletions(-) diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index c14c89cc0..a848c2e0e 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -395,13 +395,6 @@ def _deserialize( """ return value - # Properties - - @property - def context(self): - """The context dictionary for the parent :class:`Schema`.""" - return self.parent.context - class Raw(Field): """Field that applies no formatting.""" @@ -498,8 +491,6 @@ def schema(self): Renamed from `serializer` to `schema`. """ if not self._schema: - # Inherit context from parent. - context = getattr(self.parent, "context", {}) if callable(self.nested) and not isinstance(self.nested, type): nested = self.nested() else: @@ -512,7 +503,6 @@ def schema(self): if isinstance(nested, SchemaABC): self._schema = copy.copy(nested) - self._schema.context.update(context) # Respect only and exclude passed from parent and re-initialize fields set_class = self._schema.set_class if self.only is not None: @@ -539,7 +529,6 @@ def schema(self): many=self.many, only=self.only, exclude=self.exclude, - context=context, load_only=self._nested_normalized_option("load_only"), dump_only=self._nested_normalized_option("dump_only"), ) @@ -1909,14 +1898,12 @@ class Function(Field): :param serialize: A callable from which to retrieve the value. The function must take a single argument ``obj`` which is the object - to be serialized. It can also optionally take a ``context`` argument, - which is a dictionary of context variables passed to the serializer. + to be serialized. If no callable is provided then the ```load_only``` flag will be set to True. :param deserialize: A callable from which to retrieve the value. The function must take a single argument ``value`` which is the value - to be deserialized. It can also optionally take a ``context`` argument, - which is a dictionary of context variables passed to the deserializer. + to be deserialized. If no callable is provided then ```value``` will be passed through unchanged. @@ -1951,21 +1938,13 @@ def __init__( self.deserialize_func = deserialize and utils.callable_or_raise(deserialize) def _serialize(self, value, attr, obj, **kwargs): - return self._call_or_raise(self.serialize_func, obj, attr) + return self.serialize_func(obj) def _deserialize(self, value, attr, data, **kwargs): if self.deserialize_func: - return self._call_or_raise(self.deserialize_func, value, attr) + return self.deserialize_func(value) return value - def _call_or_raise(self, func, value, attr): - if len(utils.get_func_args(func)) > 1: - if self.parent.context is None: - msg = f"No context available for Function field {attr!r}" - raise ValidationError(msg) - return func(value, self.parent.context) - return func(value) - class Constant(Field): """A field that (de)serializes to a preset constant. If you only want the diff --git a/src/marshmallow/schema.py b/src/marshmallow/schema.py index 83bdc7b6b..9303c1811 100644 --- a/src/marshmallow/schema.py +++ b/src/marshmallow/schema.py @@ -263,8 +263,6 @@ class AlbumSchema(Schema): delimiters. :param many: Should be set to `True` if ``obj`` is a collection so that the object will be serialized to a list. - :param context: Optional context passed to :class:`fields.Method` and - :class:`fields.Function` fields. :param load_only: Fields to skip during serialization (write-only fields) :param dump_only: Fields to skip during deserialization (read-only fields) :param partial: Whether to ignore missing fields and not require @@ -361,7 +359,6 @@ def __init__( only: types.StrSequenceOrSet | None = None, exclude: types.StrSequenceOrSet = (), many: bool | None = None, - context: dict | None = None, load_only: types.StrSequenceOrSet = (), dump_only: types.StrSequenceOrSet = (), partial: bool | types.StrSequenceOrSet | None = None, @@ -388,7 +385,6 @@ def __init__( if unknown is None else validate_unknown_parameter_value(unknown) ) - self.context = context or {} self._normalize_nested_options() #: Dictionary mapping field_names -> :class:`Field` objects self.fields = {} # type: dict[str, ma_fields.Field] diff --git a/tests/test_deserialization.py b/tests/test_deserialization.py index 5c1e6d0de..1e4769a3d 100644 --- a/tests/test_deserialization.py +++ b/tests/test_deserialization.py @@ -7,7 +7,16 @@ import pytest -from marshmallow import EXCLUDE, INCLUDE, RAISE, Schema, fields, validate +from marshmallow import ( + CONTEXT, + EXCLUDE, + INCLUDE, + RAISE, + Context, + Schema, + fields, + validate, +) from marshmallow.exceptions import ValidationError from marshmallow.validate import Equal from tests.base import ( @@ -1000,10 +1009,11 @@ class Parent(Schema): field = fields.Function( lambda x: None, - deserialize=lambda val, context: val.upper() + context["key"], + deserialize=lambda val: val.upper() + CONTEXT.get()["key"], ) - field.parent = Parent(context={"key": "BAR"}) - assert field.deserialize("foo") == "FOOBAR" + field.parent = Parent() + with Context({"key": "BAR"}): + assert field.deserialize("foo") == "FOOBAR" def test_function_field_passed_deserialize_only_is_load_only(self): field = fields.Function(deserialize=lambda val: val.upper()) diff --git a/tests/test_schema.py b/tests/test_schema.py index d47990f07..868f7f364 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1,6 +1,7 @@ import datetime as dt import math import random +import typing from collections import OrderedDict, namedtuple import pytest @@ -355,15 +356,13 @@ class NestedSchema(Schema): bar = fields.Str() def on_bind_field(self, field_name, field_obj): - field_obj.metadata["fname"] = self.context["fname"] + assert field_obj.parent is self + field_obj.metadata["fname"] = field_name foo = fields.Nested(NestedSchema) - schema1 = MySchema(context={"fname": "foobar"}) - schema2 = MySchema(context={"fname": "quxquux"}) - - assert schema1.fields["foo"].schema.fields["bar"].metadata["fname"] == "foobar" - assert schema2.fields["foo"].schema.fields["bar"].metadata["fname"] == "quxquux" + schema = MySchema() + assert schema.fields["foo"].schema.fields["bar"].metadata["fname"] == "bar" class TestValidate: @@ -2160,10 +2159,12 @@ class ValidatingSchema(Schema): class UserContextSchema(Schema): is_owner = fields.Method("get_is_owner") - is_collab = fields.Function(lambda user, ctx: user in ctx["blog"]) + is_collab = fields.Function( + lambda user: user in typing.cast(dict, CONTEXT.get())["blog"] + ) def get_is_owner(self, user): - return self.context["blog"].user.name == user.name + return CONTEXT.get()["blog"].user.name == user.name class TestContext: @@ -2199,41 +2200,26 @@ class ContextSchema(Schema): def test_context_method(self): owner = User("Joe") blog = Blog(title="Joe Blog", user=owner) - context = {"blog": blog} serializer = UserContextSchema() - serializer.context = context - data = serializer.dump(owner) - assert data["is_owner"] is True - nonowner = User("Fred") - data = serializer.dump(nonowner) - assert data["is_owner"] is False + with Context({"blog": blog}): + data = serializer.dump(owner) + assert data["is_owner"] is True + nonowner = User("Fred") + data = serializer.dump(nonowner) + assert data["is_owner"] is False def test_context_method_function(self): owner = User("Fred") blog = Blog("Killer Queen", user=owner) collab = User("Brian") blog.collaborators.append(collab) - context = {"blog": blog} - serializer = UserContextSchema() - serializer.context = context - data = serializer.dump(collab) - assert data["is_collab"] is True - noncollab = User("Foo") - data = serializer.dump(noncollab) - assert data["is_collab"] is False - - def test_function_field_raises_error_when_context_not_available(self): - # only has a function field - class UserFunctionContextSchema(Schema): - is_collab = fields.Function(lambda user, ctx: user in ctx["blog"]) - - owner = User("Joe") - serializer = UserFunctionContextSchema() - # no context - serializer.context = None - msg = "No context available for Function field {!r}".format("is_collab") - with pytest.raises(ValidationError, match=msg): - serializer.dump(owner) + with Context({"blog": blog}): + serializer = UserContextSchema() + data = serializer.dump(collab) + assert data["is_collab"] is True + noncollab = User("Foo") + data = serializer.dump(noncollab) + assert data["is_collab"] is False def test_function_field_handles_bound_serializer(self): class SerializeA: @@ -2248,32 +2234,21 @@ class UserFunctionContextSchema(Schema): owner = User("Joe") serializer = UserFunctionContextSchema() - # no context - serializer.context = None data = serializer.dump(owner) assert data["is_collab"] == "value" - def test_fields_context(self): - class CSchema(Schema): - name = fields.String() - - ser = CSchema() - ser.context["foo"] = 42 - - assert ser.fields["name"].context == {"foo": 42} - def test_nested_fields_inherit_context(self): class InnerSchema(Schema): - likes_bikes = fields.Function(lambda obj, ctx: "bikes" in ctx["info"]) + likes_bikes = fields.Function(lambda obj: "bikes" in CONTEXT.get()["info"]) class CSchema(Schema): inner = fields.Nested(InnerSchema) ser = CSchema() - ser.context["info"] = "i like bikes" - obj = {"inner": {}} - result = ser.dump(obj) - assert result["inner"]["likes_bikes"] is True + with Context({"info": "i like bikes"}): + obj = {"inner": {}} + result = ser.dump(obj) + assert result["inner"]["likes_bikes"] is True # Regression test for https://github.com/marshmallow-code/marshmallow/issues/820 def test_nested_list_fields_inherit_context(self): @@ -2282,19 +2257,17 @@ class InnerSchema(Schema): @validates("foo") def validate_foo(self, value): - if "foo_context" not in self.context: + if "foo_context" not in CONTEXT.get(): raise ValidationError("Missing context") class OuterSchema(Schema): bars = fields.List(fields.Nested(InnerSchema())) inner = InnerSchema() - inner.context["foo_context"] = "foo" - assert inner.load({"foo": 42}) + assert inner.load({"foo": 42}, context={"foo_context": "foo"}) outer = OuterSchema() - outer.context["foo_context"] = "foo" - assert outer.load({"bars": [{"foo": 42}]}) + assert outer.load({"bars": [{"foo": 42}]}, context={"foo_context": "foo"}) # Regression test for https://github.com/marshmallow-code/marshmallow/issues/820 def test_nested_dict_fields_inherit_context(self): @@ -2303,19 +2276,19 @@ class InnerSchema(Schema): @validates("foo") def validate_foo(self, value): - if "foo_context" not in self.context: + if "foo_context" not in CONTEXT.get(): raise ValidationError("Missing context") class OuterSchema(Schema): bars = fields.Dict(values=fields.Nested(InnerSchema())) inner = InnerSchema() - inner.context["foo_context"] = "foo" - assert inner.load({"foo": 42}) + assert inner.load({"foo": 42}, context={"foo_context": "foo"}) outer = OuterSchema() - outer.context["foo_context"] = "foo" - assert outer.load({"bars": {"test": {"foo": 42}}}) + assert outer.load( + {"bars": {"test": {"foo": 42}}}, context={"foo_context": "foo"} + ) # Regression test for https://github.com/marshmallow-code/marshmallow/issues/1404 def test_nested_field_with_unpicklable_object_in_context(self): @@ -2327,11 +2300,11 @@ class InnerSchema(Schema): foo = fields.Field() class OuterSchema(Schema): - inner = fields.Nested(InnerSchema(context={"unp": Unpicklable()})) + inner = fields.Nested(InnerSchema()) outer = OuterSchema() obj = {"inner": {"foo": 42}} - assert outer.dump(obj) + assert outer.dump(obj, context={"unp": Unpicklable()}) def test_serializer_can_specify_nested_object_as_attribute(blog): diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 60db524a8..8e3f4c40d 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -10,7 +10,7 @@ import pytest -from marshmallow import Schema, fields +from marshmallow import CONTEXT, Context, Schema, fields from marshmallow import missing as missing_ from tests.base import ALL_FIELDS, DateEnum, GenderEnum, HairColorEnum, User, central @@ -108,10 +108,11 @@ class Parent(Schema): pass field = fields.Function( - serialize=lambda obj, context: obj.name.upper() + context["key"] + serialize=lambda obj: obj.name.upper() + CONTEXT.get()["key"] ) - field.parent = Parent(context={"key": "BAR"}) - assert "FOOBAR" == field.serialize("key", user) + field.parent = Parent() + with Context({"key": "BAR"}): + assert "FOOBAR" == field.serialize("key", user) def test_function_field_passed_uncallable_object(self): with pytest.raises(TypeError): From 7bb901ff8aa136bdc588ab1a3c103dc26db3de59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Tue, 31 Dec 2024 12:50:00 +0100 Subject: [PATCH 04/22] Expose context as field/schema property --- src/marshmallow/__init__.py | 9 ++------ src/marshmallow/context.py | 17 ++++++++++++++ src/marshmallow/fields.py | 23 +++++++++++++++---- src/marshmallow/schema.py | 19 +++++----------- tests/test_deserialization.py | 3 +-- tests/test_schema.py | 42 ++++++++++++++++++++++++++--------- tests/test_serialization.py | 4 ++-- 7 files changed, 77 insertions(+), 40 deletions(-) create mode 100644 src/marshmallow/context.py diff --git a/src/marshmallow/__init__.py b/src/marshmallow/__init__.py index f5ea026a2..6bdce07d0 100644 --- a/src/marshmallow/__init__.py +++ b/src/marshmallow/__init__.py @@ -5,6 +5,7 @@ from packaging.version import Version +from marshmallow.context import Context from marshmallow.decorators import ( post_dump, post_load, @@ -14,12 +15,7 @@ validates_schema, ) from marshmallow.exceptions import ValidationError -from marshmallow.schema import ( - CONTEXT, - Context, - Schema, - SchemaOpts, -) +from marshmallow.schema import Schema, SchemaOpts from marshmallow.utils import EXCLUDE, INCLUDE, RAISE, missing from . import fields @@ -68,7 +64,6 @@ def __getattr__(name: str) -> typing.Any: __all__ = [ - "CONTEXT", "EXCLUDE", "INCLUDE", "RAISE", diff --git a/src/marshmallow/context.py b/src/marshmallow/context.py new file mode 100644 index 000000000..6c1408bc1 --- /dev/null +++ b/src/marshmallow/context.py @@ -0,0 +1,17 @@ +"""Objects related to serializtion/deserialization context""" + +import contextlib +import contextvars + +CONTEXT = contextvars.ContextVar("context", default=None) + + +class Context(contextlib.AbstractContextManager): + def __init__(self, context): + self.context = context + + def __enter__(self): + self.token = CONTEXT.set(self.context) + + def __exit__(self, *args, **kwargs): + CONTEXT.reset(self.token) diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index a848c2e0e..04d55db2e 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -24,6 +24,7 @@ from marshmallow import class_registry, types, utils, validate from marshmallow.base import FieldABC, SchemaABC +from marshmallow.context import CONTEXT from marshmallow.exceptions import ( FieldInstanceResolutionError, StringNotCollectionError, @@ -230,6 +231,10 @@ def __repr__(self) -> str: def __deepcopy__(self, memo): return copy.copy(self) + @property + def context(self) -> typing.Any: + return CONTEXT.get() + def get_value(self, obj, attr, accessor=None, default=missing_): """Return the value for a given key from an object. @@ -1898,12 +1903,14 @@ class Function(Field): :param serialize: A callable from which to retrieve the value. The function must take a single argument ``obj`` which is the object - to be serialized. + to be serialized. It can also optionally take a ``context`` argument, + which is a dictionary of context variables passed to the serializer. If no callable is provided then the ```load_only``` flag will be set to True. :param deserialize: A callable from which to retrieve the value. The function must take a single argument ``value`` which is the value - to be deserialized. + to be deserialized. It can also optionally take a ``context`` argument, + which is a dictionary of context variables passed to the deserializer. If no callable is provided then ```value``` will be passed through unchanged. @@ -1938,13 +1945,21 @@ def __init__( self.deserialize_func = deserialize and utils.callable_or_raise(deserialize) def _serialize(self, value, attr, obj, **kwargs): - return self.serialize_func(obj) + return self._call_or_raise(self.serialize_func, obj, attr) def _deserialize(self, value, attr, data, **kwargs): if self.deserialize_func: - return self.deserialize_func(value) + return self._call_or_raise(self.deserialize_func, value, attr) return value + def _call_or_raise(self, func, value, attr): + if len(utils.get_func_args(func)) > 1: + if CONTEXT.get() is None: + msg = f"No context available for Function field {attr!r}" + raise ValidationError(msg) + return func(value, self.parent.context) + return func(value) + class Constant(Field): """A field that (de)serializes to a preset constant. If you only want the diff --git a/src/marshmallow/schema.py b/src/marshmallow/schema.py index 9303c1811..fff6e4b81 100644 --- a/src/marshmallow/schema.py +++ b/src/marshmallow/schema.py @@ -3,7 +3,6 @@ from __future__ import annotations import contextlib -import contextvars import copy import datetime as dt import decimal @@ -17,6 +16,7 @@ from marshmallow import base, class_registry, types from marshmallow import fields as ma_fields +from marshmallow.context import CONTEXT, Context from marshmallow.decorators import ( POST_DUMP, POST_LOAD, @@ -40,19 +40,6 @@ validate_unknown_parameter_value, ) -CONTEXT = contextvars.ContextVar("context", default=None) - - -class Context(contextlib.AbstractContextManager): - def __init__(self, context): - self.context = context - - def __enter__(self): - self.token = CONTEXT.set(self.context) - - def __exit__(self, *args, **kwargs): - CONTEXT.reset(self.token) - def _get_fields(attrs): """Get fields from a class @@ -408,6 +395,10 @@ def dict_class(self) -> type[dict]: else: return dict + @property + def context(self) -> typing.Any: + return CONTEXT.get() + @classmethod def from_dict( cls, diff --git a/tests/test_deserialization.py b/tests/test_deserialization.py index 1e4769a3d..ca9b26f03 100644 --- a/tests/test_deserialization.py +++ b/tests/test_deserialization.py @@ -8,7 +8,6 @@ import pytest from marshmallow import ( - CONTEXT, EXCLUDE, INCLUDE, RAISE, @@ -1009,7 +1008,7 @@ class Parent(Schema): field = fields.Function( lambda x: None, - deserialize=lambda val: val.upper() + CONTEXT.get()["key"], + deserialize=lambda val, context: val.upper() + context["key"], ) field.parent = Parent() with Context({"key": "BAR"}): diff --git a/tests/test_schema.py b/tests/test_schema.py index 868f7f364..a0050b734 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1,14 +1,12 @@ import datetime as dt import math import random -import typing from collections import OrderedDict, namedtuple import pytest import simplejson as json from marshmallow import ( - CONTEXT, EXCLUDE, INCLUDE, RAISE, @@ -2159,24 +2157,35 @@ class ValidatingSchema(Schema): class UserContextSchema(Schema): is_owner = fields.Method("get_is_owner") - is_collab = fields.Function( - lambda user: user in typing.cast(dict, CONTEXT.get())["blog"] - ) + is_collab = fields.Function(lambda user, ctx: user in ctx["blog"]) def get_is_owner(self, user): - return CONTEXT.get()["blog"].user.name == user.name + return self.context["blog"].user.name == user.name class TestContext: + def test_field_schema_context_properties(self): + class CSchema(Schema): + name = fields.String() + + ser = CSchema() + + with Context({"foo": 42}): + assert ser.context == {"foo": 42} + assert ser.fields["name"].context == {"foo": 42} + def test_context_load_dump(self): class ContextField(fields.Integer): def _serialize(self, value, attr, obj, **kwargs): - value *= CONTEXT.get({}).get("factor", 1) + if self.context is not None: + value *= self.context["factor"] return super()._serialize(value, attr, obj, **kwargs) def _deserialize(self, value, attr, data, **kwargs): val = super()._deserialize(value, attr, data, **kwargs) - return val * CONTEXT.get({}).get("factor", 1) + if self.context is not None: + val *= self.context["factor"] + return val class ContextSchema(Schema): ctx_fld = ContextField() @@ -2221,6 +2230,17 @@ def test_context_method_function(self): data = serializer.dump(noncollab) assert data["is_collab"] is False + def test_function_field_raises_error_when_context_not_available(self): + # only has a function field + class UserFunctionContextSchema(Schema): + is_collab = fields.Function(lambda user, ctx: user in ctx["blog"]) + + owner = User("Joe") + serializer = UserFunctionContextSchema() + msg = "No context available for Function field {!r}".format("is_collab") + with pytest.raises(ValidationError, match=msg): + serializer.dump(owner) + def test_function_field_handles_bound_serializer(self): class SerializeA: def __call__(self, value): @@ -2239,7 +2259,7 @@ class UserFunctionContextSchema(Schema): def test_nested_fields_inherit_context(self): class InnerSchema(Schema): - likes_bikes = fields.Function(lambda obj: "bikes" in CONTEXT.get()["info"]) + likes_bikes = fields.Function(lambda obj, ctx: "bikes" in ctx["info"]) class CSchema(Schema): inner = fields.Nested(InnerSchema) @@ -2257,7 +2277,7 @@ class InnerSchema(Schema): @validates("foo") def validate_foo(self, value): - if "foo_context" not in CONTEXT.get(): + if "foo_context" not in self.context: raise ValidationError("Missing context") class OuterSchema(Schema): @@ -2276,7 +2296,7 @@ class InnerSchema(Schema): @validates("foo") def validate_foo(self, value): - if "foo_context" not in CONTEXT.get(): + if "foo_context" not in self.context: raise ValidationError("Missing context") class OuterSchema(Schema): diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 8e3f4c40d..978073c36 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -10,7 +10,7 @@ import pytest -from marshmallow import CONTEXT, Context, Schema, fields +from marshmallow import Context, Schema, fields from marshmallow import missing as missing_ from tests.base import ALL_FIELDS, DateEnum, GenderEnum, HairColorEnum, User, central @@ -108,7 +108,7 @@ class Parent(Schema): pass field = fields.Function( - serialize=lambda obj: obj.name.upper() + CONTEXT.get()["key"] + serialize=lambda obj, context: obj.name.upper() + context["key"] ) field.parent = Parent() with Context({"key": "BAR"}): From a454a88e8ce9ff563371be46918f597de2e05454 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Wed, 1 Jan 2025 22:02:24 +0100 Subject: [PATCH 05/22] Make current_context a Context class attribute --- src/marshmallow/context.py | 14 +++++++--- src/marshmallow/fields.py | 10 +++---- src/marshmallow/schema.py | 45 +++++++++++-------------------- tests/test_schema.py | 55 ++++++++++++++------------------------ 4 files changed, 49 insertions(+), 75 deletions(-) diff --git a/src/marshmallow/context.py b/src/marshmallow/context.py index 6c1408bc1..5eab172a0 100644 --- a/src/marshmallow/context.py +++ b/src/marshmallow/context.py @@ -3,15 +3,21 @@ import contextlib import contextvars -CONTEXT = contextvars.ContextVar("context", default=None) - class Context(contextlib.AbstractContextManager): + _current_context: contextvars.ContextVar = contextvars.ContextVar( + "context", default=None + ) + def __init__(self, context): self.context = context def __enter__(self): - self.token = CONTEXT.set(self.context) + self.token = self._current_context.set(self.context) def __exit__(self, *args, **kwargs): - CONTEXT.reset(self.token) + self._current_context.reset(self.token) + + @classmethod + def get(cls): + return cls._current_context.get() diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index 04d55db2e..e0c2adec1 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -24,7 +24,7 @@ from marshmallow import class_registry, types, utils, validate from marshmallow.base import FieldABC, SchemaABC -from marshmallow.context import CONTEXT +from marshmallow.context import Context from marshmallow.exceptions import ( FieldInstanceResolutionError, StringNotCollectionError, @@ -231,10 +231,6 @@ def __repr__(self) -> str: def __deepcopy__(self, memo): return copy.copy(self) - @property - def context(self) -> typing.Any: - return CONTEXT.get() - def get_value(self, obj, attr, accessor=None, default=missing_): """Return the value for a given key from an object. @@ -1954,10 +1950,10 @@ def _deserialize(self, value, attr, data, **kwargs): def _call_or_raise(self, func, value, attr): if len(utils.get_func_args(func)) > 1: - if CONTEXT.get() is None: + if (context := Context.get()) is None: msg = f"No context available for Function field {attr!r}" raise ValidationError(msg) - return func(value, self.parent.context) + return func(value, context) return func(value) diff --git a/src/marshmallow/schema.py b/src/marshmallow/schema.py index fff6e4b81..9b27410b0 100644 --- a/src/marshmallow/schema.py +++ b/src/marshmallow/schema.py @@ -2,7 +2,6 @@ from __future__ import annotations -import contextlib import copy import datetime as dt import decimal @@ -16,7 +15,6 @@ from marshmallow import base, class_registry, types from marshmallow import fields as ma_fields -from marshmallow.context import CONTEXT, Context from marshmallow.decorators import ( POST_DUMP, POST_LOAD, @@ -395,10 +393,6 @@ def dict_class(self) -> type[dict]: else: return dict - @property - def context(self) -> typing.Any: - return CONTEXT.get() - @classmethod def from_dict( cls, @@ -503,16 +497,13 @@ def _serialize(self, obj: typing.Any, *, many: bool = False): ret[key] = value return ret - def dump( - self, obj: typing.Any, *, many: bool | None = None, context: typing.Any = None - ): + def dump(self, obj: typing.Any, *, many: bool | None = None): """Serialize an object to native Python data types according to this Schema's fields. :param obj: The object to serialize. :param many: Whether to serialize `obj` as a collection. If `None`, the value for `self.many` is used. - :param context: Optional context used when serializing. :return: Serialized data .. versionadded:: 1.0.0 @@ -523,21 +514,20 @@ def dump( .. versionchanged:: 3.0.0rc9 Validation no longer occurs upon serialization. """ - with Context(context) if context is not None else contextlib.nullcontext(): - many = self.many if many is None else bool(many) - if self._hooks[PRE_DUMP]: - processed_obj = self._invoke_dump_processors( - PRE_DUMP, obj, many=many, original_data=obj - ) - else: - processed_obj = obj + many = self.many if many is None else bool(many) + if self._hooks[PRE_DUMP]: + processed_obj = self._invoke_dump_processors( + PRE_DUMP, obj, many=many, original_data=obj + ) + else: + processed_obj = obj - result = self._serialize(processed_obj, many=many) + result = self._serialize(processed_obj, many=many) - if self._hooks[POST_DUMP]: - result = self._invoke_dump_processors( - POST_DUMP, result, many=many, original_data=obj - ) + if self._hooks[POST_DUMP]: + result = self._invoke_dump_processors( + POST_DUMP, result, many=many, original_data=obj + ) return result @@ -682,7 +672,6 @@ def load( many: bool | None = None, partial: bool | types.StrSequenceOrSet | None = None, unknown: str | None = None, - context: typing.Any = None, ): """Deserialize a data structure to an object defined by this Schema's fields. @@ -696,7 +685,6 @@ def load( :param unknown: Whether to exclude, include, or raise an error for unknown fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`. If `None`, the value for `self.unknown` is used. - :param context: Optional context used when deserializing. :return: Deserialized data .. versionadded:: 1.0.0 @@ -705,10 +693,9 @@ def load( A :exc:`ValidationError ` is raised if invalid data are passed. """ - with Context(context) if context is not None else contextlib.nullcontext(): - return self._do_load( - data, many=many, partial=partial, unknown=unknown, postprocess=True - ) + return self._do_load( + data, many=many, partial=partial, unknown=unknown, postprocess=True + ) def loads( self, diff --git a/tests/test_schema.py b/tests/test_schema.py index a0050b734..c48190ef6 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -2160,31 +2160,21 @@ class UserContextSchema(Schema): is_collab = fields.Function(lambda user, ctx: user in ctx["blog"]) def get_is_owner(self, user): - return self.context["blog"].user.name == user.name + return Context.get()["blog"].user.name == user.name class TestContext: - def test_field_schema_context_properties(self): - class CSchema(Schema): - name = fields.String() - - ser = CSchema() - - with Context({"foo": 42}): - assert ser.context == {"foo": 42} - assert ser.fields["name"].context == {"foo": 42} - def test_context_load_dump(self): class ContextField(fields.Integer): def _serialize(self, value, attr, obj, **kwargs): - if self.context is not None: - value *= self.context["factor"] + if (context := Context.get()) is not None: + value *= context.get("factor", 1) return super()._serialize(value, attr, obj, **kwargs) def _deserialize(self, value, attr, data, **kwargs): val = super()._deserialize(value, attr, data, **kwargs) - if self.context is not None: - val *= self.context["factor"] + if (context := Context.get()) is not None: + val *= context.get("factor", 1) return val class ContextSchema(Schema): @@ -2193,18 +2183,10 @@ class ContextSchema(Schema): ctx_schema = ContextSchema() assert ctx_schema.load({"ctx_fld": 1}) == {"ctx_fld": 1} - assert ctx_schema.load({"ctx_fld": 1}, context={"factor": 2}) == {"ctx_fld": 2} assert ctx_schema.dump({"ctx_fld": 1}) == {"ctx_fld": 1} - assert ctx_schema.dump({"ctx_fld": 1}, context={"factor": 2}) == {"ctx_fld": 2} - with Context({"factor": 3}): - assert ctx_schema.load({"ctx_fld": 1}) == {"ctx_fld": 3} - assert ctx_schema.load({"ctx_fld": 1}, context={"factor": 4}) == { - "ctx_fld": 4 - } - assert ctx_schema.dump({"ctx_fld": 1}) == {"ctx_fld": 3} - assert ctx_schema.dump({"ctx_fld": 1}, context={"factor": 4}) == { - "ctx_fld": 4 - } + with Context({"factor": 2}): + assert ctx_schema.load({"ctx_fld": 1}) == {"ctx_fld": 2} + assert ctx_schema.dump({"ctx_fld": 1}) == {"ctx_fld": 2} def test_context_method(self): owner = User("Joe") @@ -2277,17 +2259,19 @@ class InnerSchema(Schema): @validates("foo") def validate_foo(self, value): - if "foo_context" not in self.context: + if "foo_context" not in Context.get(): raise ValidationError("Missing context") class OuterSchema(Schema): bars = fields.List(fields.Nested(InnerSchema())) inner = InnerSchema() - assert inner.load({"foo": 42}, context={"foo_context": "foo"}) + with Context({"foo_context": "foo"}): + assert inner.load({"foo": 42}) outer = OuterSchema() - assert outer.load({"bars": [{"foo": 42}]}, context={"foo_context": "foo"}) + with Context({"foo_context": "foo"}): + assert outer.load({"bars": [{"foo": 42}]}) # Regression test for https://github.com/marshmallow-code/marshmallow/issues/820 def test_nested_dict_fields_inherit_context(self): @@ -2296,19 +2280,19 @@ class InnerSchema(Schema): @validates("foo") def validate_foo(self, value): - if "foo_context" not in self.context: + if "foo_context" not in Context.get(): raise ValidationError("Missing context") class OuterSchema(Schema): bars = fields.Dict(values=fields.Nested(InnerSchema())) inner = InnerSchema() - assert inner.load({"foo": 42}, context={"foo_context": "foo"}) + with Context({"foo_context": "foo"}): + assert inner.load({"foo": 42}) outer = OuterSchema() - assert outer.load( - {"bars": {"test": {"foo": 42}}}, context={"foo_context": "foo"} - ) + with Context({"foo_context": "foo"}): + assert outer.load({"bars": {"test": {"foo": 42}}}) # Regression test for https://github.com/marshmallow-code/marshmallow/issues/1404 def test_nested_field_with_unpicklable_object_in_context(self): @@ -2324,7 +2308,8 @@ class OuterSchema(Schema): outer = OuterSchema() obj = {"inner": {"foo": 42}} - assert outer.dump(obj, context={"unp": Unpicklable()}) + with Context({"unp": Unpicklable()}): + assert outer.dump(obj) def test_serializer_can_specify_nested_object_as_attribute(blog): From 0f285ede97a63f78538a75a0b2f2c4923ed7c47d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Wed, 1 Jan 2025 22:06:13 +0100 Subject: [PATCH 06/22] Don't provide None as default context --- src/marshmallow/context.py | 4 +--- src/marshmallow/fields.py | 5 +---- tests/test_schema.py | 23 ++++++++++------------- 3 files changed, 12 insertions(+), 20 deletions(-) diff --git a/src/marshmallow/context.py b/src/marshmallow/context.py index 5eab172a0..a21a027c3 100644 --- a/src/marshmallow/context.py +++ b/src/marshmallow/context.py @@ -5,9 +5,7 @@ class Context(contextlib.AbstractContextManager): - _current_context: contextvars.ContextVar = contextvars.ContextVar( - "context", default=None - ) + _current_context: contextvars.ContextVar = contextvars.ContextVar("context") def __init__(self, context): self.context = context diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index e0c2adec1..61410831b 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -1950,10 +1950,7 @@ def _deserialize(self, value, attr, data, **kwargs): def _call_or_raise(self, func, value, attr): if len(utils.get_func_args(func)) > 1: - if (context := Context.get()) is None: - msg = f"No context available for Function field {attr!r}" - raise ValidationError(msg) - return func(value, context) + return func(value, Context.get()) return func(value) diff --git a/tests/test_schema.py b/tests/test_schema.py index c48190ef6..d97cfa831 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -2167,13 +2167,21 @@ class TestContext: def test_context_load_dump(self): class ContextField(fields.Integer): def _serialize(self, value, attr, obj, **kwargs): - if (context := Context.get()) is not None: + try: + context = Context.get() + except LookupError: + pass + else: value *= context.get("factor", 1) return super()._serialize(value, attr, obj, **kwargs) def _deserialize(self, value, attr, data, **kwargs): val = super()._deserialize(value, attr, data, **kwargs) - if (context := Context.get()) is not None: + try: + context = Context.get() + except LookupError: + pass + else: val *= context.get("factor", 1) return val @@ -2212,17 +2220,6 @@ def test_context_method_function(self): data = serializer.dump(noncollab) assert data["is_collab"] is False - def test_function_field_raises_error_when_context_not_available(self): - # only has a function field - class UserFunctionContextSchema(Schema): - is_collab = fields.Function(lambda user, ctx: user in ctx["blog"]) - - owner = User("Joe") - serializer = UserFunctionContextSchema() - msg = "No context available for Function field {!r}".format("is_collab") - with pytest.raises(ValidationError, match=msg): - serializer.dump(owner) - def test_function_field_handles_bound_serializer(self): class SerializeA: def __call__(self, value): From feffc2e805a448b5729d4c82f412ecab365da277 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Wed, 1 Jan 2025 22:08:11 +0100 Subject: [PATCH 07/22] Fix Function field docstring about context --- src/marshmallow/fields.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index 61410831b..30eb86914 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -1900,13 +1900,13 @@ class Function(Field): :param serialize: A callable from which to retrieve the value. The function must take a single argument ``obj`` which is the object to be serialized. It can also optionally take a ``context`` argument, - which is a dictionary of context variables passed to the serializer. + which is the value of the current context returned by Context.get(). If no callable is provided then the ```load_only``` flag will be set to True. :param deserialize: A callable from which to retrieve the value. The function must take a single argument ``value`` which is the value to be deserialized. It can also optionally take a ``context`` argument, - which is a dictionary of context variables passed to the deserializer. + which is the value of the current context returned by Context.get(). If no callable is provided then ```value``` will be passed through unchanged. From 6fac57d41cc86fd960cbd92daca9c32181a21aed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Wed, 1 Jan 2025 23:21:39 +0100 Subject: [PATCH 08/22] Allow passing a default to Context.get --- src/marshmallow/context.py | 4 +++- tests/test_schema.py | 12 ++---------- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/src/marshmallow/context.py b/src/marshmallow/context.py index a21a027c3..45c03e49f 100644 --- a/src/marshmallow/context.py +++ b/src/marshmallow/context.py @@ -17,5 +17,7 @@ def __exit__(self, *args, **kwargs): self._current_context.reset(self.token) @classmethod - def get(cls): + def get(cls, default=...): + if default is not ...: + return cls._current_context.get(default) return cls._current_context.get() diff --git a/tests/test_schema.py b/tests/test_schema.py index d97cfa831..d6241c724 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -2167,21 +2167,13 @@ class TestContext: def test_context_load_dump(self): class ContextField(fields.Integer): def _serialize(self, value, attr, obj, **kwargs): - try: - context = Context.get() - except LookupError: - pass - else: + if (context := Context.get(None)) is not None: value *= context.get("factor", 1) return super()._serialize(value, attr, obj, **kwargs) def _deserialize(self, value, attr, data, **kwargs): val = super()._deserialize(value, attr, data, **kwargs) - try: - context = Context.get() - except LookupError: - pass - else: + if (context := Context.get(None)) is not None: val *= context.get("factor", 1) return val From 1a4eec7a548453c27203a0dc455986340f2fc305 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Thu, 2 Jan 2025 00:04:03 +0100 Subject: [PATCH 09/22] Never pass context to functions in Function field --- src/marshmallow/fields.py | 16 ++++------------ tests/test_deserialization.py | 2 +- tests/test_schema.py | 4 ++-- tests/test_serialization.py | 2 +- 4 files changed, 8 insertions(+), 16 deletions(-) diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index 30eb86914..a848c2e0e 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -24,7 +24,6 @@ from marshmallow import class_registry, types, utils, validate from marshmallow.base import FieldABC, SchemaABC -from marshmallow.context import Context from marshmallow.exceptions import ( FieldInstanceResolutionError, StringNotCollectionError, @@ -1899,14 +1898,12 @@ class Function(Field): :param serialize: A callable from which to retrieve the value. The function must take a single argument ``obj`` which is the object - to be serialized. It can also optionally take a ``context`` argument, - which is the value of the current context returned by Context.get(). + to be serialized. If no callable is provided then the ```load_only``` flag will be set to True. :param deserialize: A callable from which to retrieve the value. The function must take a single argument ``value`` which is the value - to be deserialized. It can also optionally take a ``context`` argument, - which is the value of the current context returned by Context.get(). + to be deserialized. If no callable is provided then ```value``` will be passed through unchanged. @@ -1941,18 +1938,13 @@ def __init__( self.deserialize_func = deserialize and utils.callable_or_raise(deserialize) def _serialize(self, value, attr, obj, **kwargs): - return self._call_or_raise(self.serialize_func, obj, attr) + return self.serialize_func(obj) def _deserialize(self, value, attr, data, **kwargs): if self.deserialize_func: - return self._call_or_raise(self.deserialize_func, value, attr) + return self.deserialize_func(value) return value - def _call_or_raise(self, func, value, attr): - if len(utils.get_func_args(func)) > 1: - return func(value, Context.get()) - return func(value) - class Constant(Field): """A field that (de)serializes to a preset constant. If you only want the diff --git a/tests/test_deserialization.py b/tests/test_deserialization.py index ca9b26f03..c294fcbde 100644 --- a/tests/test_deserialization.py +++ b/tests/test_deserialization.py @@ -1008,7 +1008,7 @@ class Parent(Schema): field = fields.Function( lambda x: None, - deserialize=lambda val, context: val.upper() + context["key"], + deserialize=lambda val: val.upper() + Context.get()["key"], ) field.parent = Parent() with Context({"key": "BAR"}): diff --git a/tests/test_schema.py b/tests/test_schema.py index d6241c724..5b2083149 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -2157,7 +2157,7 @@ class ValidatingSchema(Schema): class UserContextSchema(Schema): is_owner = fields.Method("get_is_owner") - is_collab = fields.Function(lambda user, ctx: user in ctx["blog"]) + is_collab = fields.Function(lambda user: user in Context.get()["blog"]) def get_is_owner(self, user): return Context.get()["blog"].user.name == user.name @@ -2230,7 +2230,7 @@ class UserFunctionContextSchema(Schema): def test_nested_fields_inherit_context(self): class InnerSchema(Schema): - likes_bikes = fields.Function(lambda obj, ctx: "bikes" in ctx["info"]) + likes_bikes = fields.Function(lambda obj: "bikes" in Context.get()["info"]) class CSchema(Schema): inner = fields.Nested(InnerSchema) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 978073c36..bea0308f2 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -108,7 +108,7 @@ class Parent(Schema): pass field = fields.Function( - serialize=lambda obj, context: obj.name.upper() + context["key"] + serialize=lambda obj: obj.name.upper() + Context.get()["key"] ) field.parent = Parent() with Context({"key": "BAR"}): From 4447c07d89b7060c593bebb3a206688eb1a11511 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Thu, 2 Jan 2025 19:12:35 +0100 Subject: [PATCH 10/22] Remove utils.get_func_args --- src/marshmallow/utils.py | 16 ---------------- tests/test_utils.py | 17 ----------------- 2 files changed, 33 deletions(-) diff --git a/src/marshmallow/utils.py b/src/marshmallow/utils.py index b4202343d..8de415608 100644 --- a/src/marshmallow/utils.py +++ b/src/marshmallow/utils.py @@ -3,7 +3,6 @@ from __future__ import annotations import datetime as dt -import functools import inspect import typing from collections.abc import Mapping @@ -242,21 +241,6 @@ def _signature(func: typing.Callable) -> list[str]: return list(inspect.signature(func).parameters.keys()) -def get_func_args(func: typing.Callable) -> list[str]: - """Given a callable, return a list of argument names. Handles - `functools.partial` objects and class-based callables. - - .. versionchanged:: 3.0.0a1 - Do not return bound arguments, eg. ``self``. - """ - if inspect.isfunction(func) or inspect.ismethod(func): - return _signature(func) - if isinstance(func, functools.partial): - return _signature(func.func) - # Callable class - return _signature(func) - - def resolve_field_instance(cls_or_instance): """Return a Schema instance from a Schema class or instance. diff --git a/tests/test_utils.py b/tests/test_utils.py index 9c4981523..84b812907 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,6 @@ import datetime as dt from collections import namedtuple from copy import copy, deepcopy -from functools import partial import pytest @@ -200,22 +199,6 @@ def test_from_timestamp_with_overflow_value(): utils.from_timestamp(value) -def test_get_func_args(): - def f1(foo, bar): - pass - - f2 = partial(f1, "baz") - - class F3: - def __call__(self, foo, bar): - pass - - f3 = F3() - - for func in [f1, f2, f3]: - assert utils.get_func_args(func) == ["foo", "bar"] - - # Regression test for https://github.com/marshmallow-code/marshmallow/issues/540 def test_function_field_using_type_annotation(): def get_split_words(value: str): # noqa From 72de755adc1b508e90f53eaabab508f28c1ce479 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Thu, 2 Jan 2025 19:16:19 +0100 Subject: [PATCH 11/22] Make _CURRENT_CONTEXT a module-level attribute --- src/marshmallow/context.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/marshmallow/context.py b/src/marshmallow/context.py index 45c03e49f..cb98b5307 100644 --- a/src/marshmallow/context.py +++ b/src/marshmallow/context.py @@ -3,21 +3,21 @@ import contextlib import contextvars +_CURRENT_CONTEXT: contextvars.ContextVar = contextvars.ContextVar("context") -class Context(contextlib.AbstractContextManager): - _current_context: contextvars.ContextVar = contextvars.ContextVar("context") +class Context(contextlib.AbstractContextManager): def __init__(self, context): self.context = context def __enter__(self): - self.token = self._current_context.set(self.context) + self.token = _CURRENT_CONTEXT.set(self.context) def __exit__(self, *args, **kwargs): - self._current_context.reset(self.token) + _CURRENT_CONTEXT.reset(self.token) @classmethod def get(cls, default=...): if default is not ...: - return cls._current_context.get(default) - return cls._current_context.get() + return _CURRENT_CONTEXT.get(default) + return _CURRENT_CONTEXT.get() From 17bd038024bd413cf783c5a52e02662a0edb81b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Thu, 2 Jan 2025 19:51:29 +0100 Subject: [PATCH 12/22] Move Context into experimental --- src/marshmallow/__init__.py | 2 - src/marshmallow/experimental/__init__.py | 5 + src/marshmallow/{ => experimental}/context.py | 0 tests/test_context.py | 173 ++++++++++++++++++ tests/test_deserialization.py | 13 -- tests/test_schema.py | 147 --------------- tests/test_serialization.py | 13 +- 7 files changed, 179 insertions(+), 174 deletions(-) create mode 100644 src/marshmallow/experimental/__init__.py rename src/marshmallow/{ => experimental}/context.py (100%) create mode 100644 tests/test_context.py diff --git a/src/marshmallow/__init__.py b/src/marshmallow/__init__.py index 6bdce07d0..ce666e1ed 100644 --- a/src/marshmallow/__init__.py +++ b/src/marshmallow/__init__.py @@ -5,7 +5,6 @@ from packaging.version import Version -from marshmallow.context import Context from marshmallow.decorators import ( post_dump, post_load, @@ -67,7 +66,6 @@ def __getattr__(name: str) -> typing.Any: "EXCLUDE", "INCLUDE", "RAISE", - "Context", "Schema", "SchemaOpts", "fields", diff --git a/src/marshmallow/experimental/__init__.py b/src/marshmallow/experimental/__init__.py new file mode 100644 index 000000000..b8f6f65ba --- /dev/null +++ b/src/marshmallow/experimental/__init__.py @@ -0,0 +1,5 @@ +"""Experimental features. + +The features in this subpackage are experimental. Breaking changes may be +introduced in minor marshmallow versions. +""" diff --git a/src/marshmallow/context.py b/src/marshmallow/experimental/context.py similarity index 100% rename from src/marshmallow/context.py rename to src/marshmallow/experimental/context.py diff --git a/tests/test_context.py b/tests/test_context.py new file mode 100644 index 000000000..c24d4278f --- /dev/null +++ b/tests/test_context.py @@ -0,0 +1,173 @@ +from marshmallow import Schema, fields, validates +from marshmallow.exceptions import ValidationError +from marshmallow.experimental.context import Context +from tests.base import Blog, User + + +class UserContextSchema(Schema): + is_owner = fields.Method("get_is_owner") + is_collab = fields.Function(lambda user: user in Context.get()["blog"]) + + def get_is_owner(self, user): + return Context.get()["blog"].user.name == user.name + + +class TestContext: + def test_context_load_dump(self): + class ContextField(fields.Integer): + def _serialize(self, value, attr, obj, **kwargs): + if (context := Context.get(None)) is not None: + value *= context.get("factor", 1) + return super()._serialize(value, attr, obj, **kwargs) + + def _deserialize(self, value, attr, data, **kwargs): + val = super()._deserialize(value, attr, data, **kwargs) + if (context := Context.get(None)) is not None: + val *= context.get("factor", 1) + return val + + class ContextSchema(Schema): + ctx_fld = ContextField() + + ctx_schema = ContextSchema() + + assert ctx_schema.load({"ctx_fld": 1}) == {"ctx_fld": 1} + assert ctx_schema.dump({"ctx_fld": 1}) == {"ctx_fld": 1} + with Context({"factor": 2}): + assert ctx_schema.load({"ctx_fld": 1}) == {"ctx_fld": 2} + assert ctx_schema.dump({"ctx_fld": 1}) == {"ctx_fld": 2} + + def test_context_method(self): + owner = User("Joe") + blog = Blog(title="Joe Blog", user=owner) + serializer = UserContextSchema() + with Context({"blog": blog}): + data = serializer.dump(owner) + assert data["is_owner"] is True + nonowner = User("Fred") + data = serializer.dump(nonowner) + assert data["is_owner"] is False + + def test_context_function(self): + owner = User("Fred") + blog = Blog("Killer Queen", user=owner) + collab = User("Brian") + blog.collaborators.append(collab) + with Context({"blog": blog}): + serializer = UserContextSchema() + data = serializer.dump(collab) + assert data["is_collab"] is True + noncollab = User("Foo") + data = serializer.dump(noncollab) + assert data["is_collab"] is False + + def test_function_field_handles_bound_serializer(self): + class SerializeA: + def __call__(self, value): + return "value" + + serialize = SerializeA() + + # only has a function field + class UserFunctionContextSchema(Schema): + is_collab = fields.Function(serialize) + + owner = User("Joe") + serializer = UserFunctionContextSchema() + data = serializer.dump(owner) + assert data["is_collab"] == "value" + + def test_nested_fields_inherit_context(self): + class InnerSchema(Schema): + likes_bikes = fields.Function(lambda obj: "bikes" in Context.get()["info"]) + + class CSchema(Schema): + inner = fields.Nested(InnerSchema) + + ser = CSchema() + with Context({"info": "i like bikes"}): + obj = {"inner": {}} + result = ser.dump(obj) + assert result["inner"]["likes_bikes"] is True + + # Regression test for https://github.com/marshmallow-code/marshmallow/issues/820 + def test_nested_list_fields_inherit_context(self): + class InnerSchema(Schema): + foo = fields.Field() + + @validates("foo") + def validate_foo(self, value): + if "foo_context" not in Context.get(): + raise ValidationError("Missing context") + + class OuterSchema(Schema): + bars = fields.List(fields.Nested(InnerSchema())) + + inner = InnerSchema() + with Context({"foo_context": "foo"}): + assert inner.load({"foo": 42}) + + outer = OuterSchema() + with Context({"foo_context": "foo"}): + assert outer.load({"bars": [{"foo": 42}]}) + + # Regression test for https://github.com/marshmallow-code/marshmallow/issues/820 + def test_nested_dict_fields_inherit_context(self): + class InnerSchema(Schema): + foo = fields.Field() + + @validates("foo") + def validate_foo(self, value): + if "foo_context" not in Context.get(): + raise ValidationError("Missing context") + + class OuterSchema(Schema): + bars = fields.Dict(values=fields.Nested(InnerSchema())) + + inner = InnerSchema() + with Context({"foo_context": "foo"}): + assert inner.load({"foo": 42}) + + outer = OuterSchema() + with Context({"foo_context": "foo"}): + assert outer.load({"bars": {"test": {"foo": 42}}}) + + # Regression test for https://github.com/marshmallow-code/marshmallow/issues/1404 + def test_nested_field_with_unpicklable_object_in_context(self): + class Unpicklable: + def __deepcopy__(self, _): + raise NotImplementedError + + class InnerSchema(Schema): + foo = fields.Field() + + class OuterSchema(Schema): + inner = fields.Nested(InnerSchema()) + + outer = OuterSchema() + obj = {"inner": {"foo": 42}} + with Context({"unp": Unpicklable()}): + assert outer.dump(obj) + + def test_function_field_passed_serialize_with_context(self, user): + class Parent(Schema): + pass + + field = fields.Function( + serialize=lambda obj: obj.name.upper() + Context.get()["key"] + ) + field.parent = Parent() + with Context({"key": "BAR"}): + assert field.serialize("key", user) == "MONTYBAR" + + def test_function_field_deserialization_with_context(self): + class Parent(Schema): + pass + + field = fields.Function( + lambda x: None, + deserialize=lambda val: val.upper() + Context.get()["key"], + ) + field.parent = Parent() + with Context({"key": "BAR"}): + assert field.deserialize("foo") == "FOOBAR" diff --git a/tests/test_deserialization.py b/tests/test_deserialization.py index c294fcbde..6acd3b751 100644 --- a/tests/test_deserialization.py +++ b/tests/test_deserialization.py @@ -11,7 +11,6 @@ EXCLUDE, INCLUDE, RAISE, - Context, Schema, fields, validate, @@ -1002,18 +1001,6 @@ def test_function_field_deserialization_with_callable(self): field = fields.Function(lambda x: None, deserialize=lambda val: val.upper()) assert field.deserialize("foo") == "FOO" - def test_function_field_deserialization_with_context(self): - class Parent(Schema): - pass - - field = fields.Function( - lambda x: None, - deserialize=lambda val: val.upper() + Context.get()["key"], - ) - field.parent = Parent() - with Context({"key": "BAR"}): - assert field.deserialize("foo") == "FOOBAR" - def test_function_field_passed_deserialize_only_is_load_only(self): field = fields.Function(deserialize=lambda val: val.upper()) assert field.load_only is True diff --git a/tests/test_schema.py b/tests/test_schema.py index 5b2083149..2ad5d565c 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -10,7 +10,6 @@ EXCLUDE, INCLUDE, RAISE, - Context, Schema, class_registry, fields, @@ -2155,152 +2154,6 @@ class ValidatingSchema(Schema): assert "Color must be red or blue" in errors["color"] -class UserContextSchema(Schema): - is_owner = fields.Method("get_is_owner") - is_collab = fields.Function(lambda user: user in Context.get()["blog"]) - - def get_is_owner(self, user): - return Context.get()["blog"].user.name == user.name - - -class TestContext: - def test_context_load_dump(self): - class ContextField(fields.Integer): - def _serialize(self, value, attr, obj, **kwargs): - if (context := Context.get(None)) is not None: - value *= context.get("factor", 1) - return super()._serialize(value, attr, obj, **kwargs) - - def _deserialize(self, value, attr, data, **kwargs): - val = super()._deserialize(value, attr, data, **kwargs) - if (context := Context.get(None)) is not None: - val *= context.get("factor", 1) - return val - - class ContextSchema(Schema): - ctx_fld = ContextField() - - ctx_schema = ContextSchema() - - assert ctx_schema.load({"ctx_fld": 1}) == {"ctx_fld": 1} - assert ctx_schema.dump({"ctx_fld": 1}) == {"ctx_fld": 1} - with Context({"factor": 2}): - assert ctx_schema.load({"ctx_fld": 1}) == {"ctx_fld": 2} - assert ctx_schema.dump({"ctx_fld": 1}) == {"ctx_fld": 2} - - def test_context_method(self): - owner = User("Joe") - blog = Blog(title="Joe Blog", user=owner) - serializer = UserContextSchema() - with Context({"blog": blog}): - data = serializer.dump(owner) - assert data["is_owner"] is True - nonowner = User("Fred") - data = serializer.dump(nonowner) - assert data["is_owner"] is False - - def test_context_method_function(self): - owner = User("Fred") - blog = Blog("Killer Queen", user=owner) - collab = User("Brian") - blog.collaborators.append(collab) - with Context({"blog": blog}): - serializer = UserContextSchema() - data = serializer.dump(collab) - assert data["is_collab"] is True - noncollab = User("Foo") - data = serializer.dump(noncollab) - assert data["is_collab"] is False - - def test_function_field_handles_bound_serializer(self): - class SerializeA: - def __call__(self, value): - return "value" - - serialize = SerializeA() - - # only has a function field - class UserFunctionContextSchema(Schema): - is_collab = fields.Function(serialize) - - owner = User("Joe") - serializer = UserFunctionContextSchema() - data = serializer.dump(owner) - assert data["is_collab"] == "value" - - def test_nested_fields_inherit_context(self): - class InnerSchema(Schema): - likes_bikes = fields.Function(lambda obj: "bikes" in Context.get()["info"]) - - class CSchema(Schema): - inner = fields.Nested(InnerSchema) - - ser = CSchema() - with Context({"info": "i like bikes"}): - obj = {"inner": {}} - result = ser.dump(obj) - assert result["inner"]["likes_bikes"] is True - - # Regression test for https://github.com/marshmallow-code/marshmallow/issues/820 - def test_nested_list_fields_inherit_context(self): - class InnerSchema(Schema): - foo = fields.Field() - - @validates("foo") - def validate_foo(self, value): - if "foo_context" not in Context.get(): - raise ValidationError("Missing context") - - class OuterSchema(Schema): - bars = fields.List(fields.Nested(InnerSchema())) - - inner = InnerSchema() - with Context({"foo_context": "foo"}): - assert inner.load({"foo": 42}) - - outer = OuterSchema() - with Context({"foo_context": "foo"}): - assert outer.load({"bars": [{"foo": 42}]}) - - # Regression test for https://github.com/marshmallow-code/marshmallow/issues/820 - def test_nested_dict_fields_inherit_context(self): - class InnerSchema(Schema): - foo = fields.Field() - - @validates("foo") - def validate_foo(self, value): - if "foo_context" not in Context.get(): - raise ValidationError("Missing context") - - class OuterSchema(Schema): - bars = fields.Dict(values=fields.Nested(InnerSchema())) - - inner = InnerSchema() - with Context({"foo_context": "foo"}): - assert inner.load({"foo": 42}) - - outer = OuterSchema() - with Context({"foo_context": "foo"}): - assert outer.load({"bars": {"test": {"foo": 42}}}) - - # Regression test for https://github.com/marshmallow-code/marshmallow/issues/1404 - def test_nested_field_with_unpicklable_object_in_context(self): - class Unpicklable: - def __deepcopy__(self, _): - raise NotImplementedError - - class InnerSchema(Schema): - foo = fields.Field() - - class OuterSchema(Schema): - inner = fields.Nested(InnerSchema()) - - outer = OuterSchema() - obj = {"inner": {"foo": 42}} - with Context({"unp": Unpicklable()}): - assert outer.dump(obj) - - def test_serializer_can_specify_nested_object_as_attribute(blog): class BlogUsernameSchema(Schema): author_name = fields.String(attribute="user.name") diff --git a/tests/test_serialization.py b/tests/test_serialization.py index bea0308f2..9dd7fbcc2 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -10,7 +10,7 @@ import pytest -from marshmallow import Context, Schema, fields +from marshmallow import Schema, fields from marshmallow import missing as missing_ from tests.base import ALL_FIELDS, DateEnum, GenderEnum, HairColorEnum, User, central @@ -103,17 +103,6 @@ def test_function_field_load_only(self): field = fields.Function(deserialize=lambda obj: None) assert field.load_only - def test_function_field_passed_serialize_with_context(self, user, monkeypatch): - class Parent(Schema): - pass - - field = fields.Function( - serialize=lambda obj: obj.name.upper() + Context.get()["key"] - ) - field.parent = Parent() - with Context({"key": "BAR"}): - assert "FOOBAR" == field.serialize("key", user) - def test_function_field_passed_uncallable_object(self): with pytest.raises(TypeError): fields.Function("uncallable") From 63abfc11ef5c0d8c1aeb007f4256b8b8a28f19ab Mon Sep 17 00:00:00 2001 From: Steven Loria Date: Thu, 2 Jan 2025 17:08:06 -0500 Subject: [PATCH 13/22] Add typing to context.py --- src/marshmallow/experimental/context.py | 14 ++++++++------ tests/test_context.py | 6 +++++- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/marshmallow/experimental/context.py b/src/marshmallow/experimental/context.py index cb98b5307..c7758448c 100644 --- a/src/marshmallow/experimental/context.py +++ b/src/marshmallow/experimental/context.py @@ -1,23 +1,25 @@ -"""Objects related to serializtion/deserialization context""" +"""Helper API for setting serialization/deserialization context.""" import contextlib import contextvars +import typing +_T = typing.TypeVar("_T") _CURRENT_CONTEXT: contextvars.ContextVar = contextvars.ContextVar("context") -class Context(contextlib.AbstractContextManager): - def __init__(self, context): +class Context(contextlib.AbstractContextManager, typing.Generic[_T]): + def __init__(self, context: _T) -> None: self.context = context - def __enter__(self): + def __enter__(self) -> None: self.token = _CURRENT_CONTEXT.set(self.context) - def __exit__(self, *args, **kwargs): + def __exit__(self, *args, **kwargs) -> None: _CURRENT_CONTEXT.reset(self.token) @classmethod - def get(cls, default=...): + def get(cls, default=...) -> _T: if default is not ...: return _CURRENT_CONTEXT.get(default) return _CURRENT_CONTEXT.get() diff --git a/tests/test_context.py b/tests/test_context.py index c24d4278f..c833c7510 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -1,3 +1,5 @@ +import typing + from marshmallow import Schema, fields, validates from marshmallow.exceptions import ValidationError from marshmallow.experimental.context import Context @@ -6,7 +8,9 @@ class UserContextSchema(Schema): is_owner = fields.Method("get_is_owner") - is_collab = fields.Function(lambda user: user in Context.get()["blog"]) + is_collab = fields.Function( + lambda user: user in Context[dict[str, typing.Any]].get()["blog"] + ) def get_is_owner(self, user): return Context.get()["blog"].user.name == user.name From c6c4e889c7cc1c98a546573f1c7c59337acd67f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Sat, 4 Jan 2025 00:41:19 +0100 Subject: [PATCH 14/22] Add tests for decorated processors with context --- tests/test_context.py | 74 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 73 insertions(+), 1 deletion(-) diff --git a/tests/test_context.py b/tests/test_context.py index c833c7510..53fa83476 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -1,6 +1,17 @@ import typing -from marshmallow import Schema, fields, validates +import pytest + +from marshmallow import ( + Schema, + fields, + post_dump, + post_load, + pre_dump, + pre_load, + validates, + validates_schema, +) from marshmallow.exceptions import ValidationError from marshmallow.experimental.context import Context from tests.base import Blog, User @@ -175,3 +186,64 @@ class Parent(Schema): field.parent = Parent() with Context({"key": "BAR"}): assert field.deserialize("foo") == "FOOBAR" + + def test_decorated_processors_with_context(self): + class MySchema(Schema): + f_1 = fields.Integer() + f_2 = fields.Integer() + f_3 = fields.Integer() + f_4 = fields.Integer() + + @pre_dump + def multiply_f_1(self, item, **kwargs): + item["f_1"] *= Context.get()[1] + return item + + @pre_load + def multiply_f_2(self, data, **kwargs): + data["f_2"] *= Context.get()[2] + return data + + @post_dump + def multiply_f_3(self, item, **kwargs): + item["f_3"] *= Context.get()[3] + return item + + @post_load + def multiply_f_4(self, data, **kwargs): + data["f_4"] *= Context.get()[4] + return data + + schema = MySchema() + + with Context({1: 2, 2: 3, 3: 4, 4: 5}): + assert schema.dump({"f_1": 1, "f_2": 1, "f_3": 1, "f_4": 1}) == { + "f_1": 2, + "f_2": 1, + "f_3": 4, + "f_4": 1, + } + assert schema.load({"f_1": 1, "f_2": 1, "f_3": 1, "f_4": 1}) == { + "f_1": 1, + "f_2": 3, + "f_3": 1, + "f_4": 5, + } + + def test_validates_schema_with_context(self): + class MySchema(Schema): + f_1 = fields.Integer() + f_2 = fields.Integer() + + @validates_schema + def validate_schema(self, data, **kwargs): + if data["f_2"] != data["f_1"] * Context.get(): + raise ValidationError("Fail") + + schema = MySchema() + + with Context(2): + schema.load({"f_1": 1, "f_2": 2}) + with pytest.raises(ValidationError) as excinfo: + schema.load({"f_1": 1, "f_2": 3}) + assert excinfo.value.messages["_schema"] == ["Fail"] From 947de5179ff4e7ecc7bdeb1388c4efe7e74776aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Sat, 4 Jan 2025 01:13:56 +0100 Subject: [PATCH 15/22] Update documentation about removal of context --- docs/custom_fields.rst | 34 --------------------------------- docs/extending.rst | 13 ------------- docs/upgrading.rst | 43 ++++++++++++++++++++++++++++++++++++++++++ docs/why.rst | 33 -------------------------------- 4 files changed, 43 insertions(+), 80 deletions(-) diff --git a/docs/custom_fields.rst b/docs/custom_fields.rst index b9110080d..e829f4953 100644 --- a/docs/custom_fields.rst +++ b/docs/custom_fields.rst @@ -95,40 +95,6 @@ Both :class:`Function ` and :class:`Method 100.0 -.. _adding-context: - -Adding Context to `Method` and `Function` Fields ------------------------------------------------- - -A :class:`Function ` or :class:`Method ` field may need information about its environment to know how to serialize a value. - -In these cases, you can set the ``context`` attribute (a dictionary) of a `Schema`. :class:`Function ` and :class:`Method ` fields will have access to this dictionary. - -As an example, you might want your ``UserSchema`` to output whether or not a ``User`` is the author of a ``Blog`` or whether a certain word appears in a ``Blog's`` title. - -.. code-block:: python - - class UserSchema(Schema): - name = fields.String() - # Function fields optionally receive context argument - is_author = fields.Function(lambda user, context: user == context["blog"].author) - likes_bikes = fields.Method("writes_about_bikes") - - def writes_about_bikes(self, user): - return "bicycle" in self.context["blog"].title.lower() - - - schema = UserSchema() - - user = User("Freddie Mercury", "fred@queen.com") - blog = Blog("Bicycle Blog", author=user) - - schema.context = {"blog": blog} - result = schema.dump(user) - result["is_author"] # => True - result["likes_bikes"] # => True - - Customizing Error Messages -------------------------- diff --git a/docs/extending.rst b/docs/extending.rst index 582e07f3e..bce86c182 100644 --- a/docs/extending.rst +++ b/docs/extending.rst @@ -461,19 +461,6 @@ Our application schemas can now inherit from our custom schema class. result = ser.dump(user) result # {"user": {"name": "Keith", "email": "keith@stones.com"}} -Using Context -------------- - -The ``context`` attribute of a `Schema` is a general-purpose store for extra information that may be needed for (de)serialization. It may be used in both ``Schema`` and ``Field`` methods. - -.. code-block:: python - - schema = UserSchema() - # Make current HTTP request available to - # custom fields, schema methods, schema validators, etc. - schema.context["request"] = request - schema.dump(user) - Custom Error Messages --------------------- diff --git a/docs/upgrading.rst b/docs/upgrading.rst index 36ea0a0b1..abe758134 100644 --- a/docs/upgrading.rst +++ b/docs/upgrading.rst @@ -58,6 +58,49 @@ If you want to use anonymous functions, you can use this helper function. password = fields.String(validate=predicate(lambda x: x == "password")) +Schema Context is Removed +************************* + +The feature allowing to pass a context to the schema has been removed. Users should +use `contextvars` for that. + +marshmallow 4.0 provides an experimental `Context ` +manager class that can be used both to set and retrieve the context. + +.. code-block:: python + + # 3.x + from marshmallow import Schema, fields + + + class UserSchema(Schema): + name = fields.Function( + serialize=lambda obj, context: obj["name"].upper() + context["suffix"] + ) + + + user_schema = UserSchema() + user_schema.context = {"suffix": "BAR"} + user_schema.dump({"name": "foo"}) + # {'name': 'FOOBAR'} + + # 4.x + from marshmallow import Schema, fields + from marshmallow.experimental.context import Context + + + def transform_name(obj): + return obj["name"].upper() + Context.get()["suffix"] + + + class UserSchema(Schema): + name = fields.Function(serialize=transform_name) + + + with Context({"suffix": "BAR"}): + UserSchema().dump({"name": "foo"}) + # {'name': 'FOOBAR'} + Upgrading to 3.3 ++++++++++++++++ diff --git a/docs/why.rst b/docs/why.rst index c0d630ea2..3603c7da8 100644 --- a/docs/why.rst +++ b/docs/why.rst @@ -55,39 +55,6 @@ In this example, a single schema produced three different outputs! The dynamic n .. _Django REST Framework: https://www.django-rest-framework.org/ .. _Flask-RESTful: https://flask-restful.readthedocs.io/ - -Context-aware serialization. ----------------------------- - -Marshmallow schemas can modify their output based on the context in which they are used. Field objects have access to a ``context`` dictionary that can be changed at runtime. - -Here's a simple example that shows how a `Schema ` can anonymize a person's name when a boolean is set on the context. - -.. code-block:: python - - class PersonSchema(Schema): - id = fields.Integer() - name = fields.Method("get_name") - - def get_name(self, person, context): - if context.get("anonymize"): - return "" - return person.name - - - person = Person(name="Monty") - schema = PersonSchema() - schema.dump(person) # {'id': 143, 'name': 'Monty'} - - # In a different context, anonymize the name - schema.context["anonymize"] = True - schema.dump(person) # {'id': 143, 'name': ''} - - -.. seealso:: - - See the relevant section of the :ref:`usage guide ` to learn more about context-aware serialization. - Advanced schema nesting. ------------------------ From 5b10b846c22326bf67f42e609f17596382869b75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Sat, 4 Jan 2025 01:21:25 +0100 Subject: [PATCH 16/22] Update versionchanged in docstrings --- src/marshmallow/fields.py | 6 ++++++ src/marshmallow/schema.py | 5 ++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index b1d58bb25..a7160be1a 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -151,6 +151,9 @@ class Field(FieldABC): .. versionchanged:: 3.13.0 Replace ``missing`` and ``default`` parameters with ``load_default`` and ``dump_default``. + + .. versionchanged:: 4.0.0 + Remove ``context`` property. """ # Some fields, such as Method fields and Function fields, are not expected @@ -1919,6 +1922,9 @@ class Function(Field): .. versionchanged:: 3.0.0a1 Removed ``func`` parameter. + + .. versionchanged:: 4.0.0 + Don't pass context to serialization and deserialization functions. """ _CHECK_ATTRIBUTE = False diff --git a/src/marshmallow/schema.py b/src/marshmallow/schema.py index ed1ce2953..5c1ca335d 100644 --- a/src/marshmallow/schema.py +++ b/src/marshmallow/schema.py @@ -245,7 +245,10 @@ class AlbumSchema(Schema): fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`. .. versionchanged:: 3.0.0 - `prefix` parameter removed. + Remove ``prefix`` parameter. + + .. versionchanged:: 4.0.0 + Remove ``context`` parameter. """ TYPE_MAPPING = { From 318cae0d97db1349a66a748935797ab2fb33ac4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Sat, 4 Jan 2025 01:25:05 +0100 Subject: [PATCH 17/22] Update changelog about Context --- CHANGELOG.rst | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index bd8148030..fb0c3c00f 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -48,6 +48,29 @@ As a consequence of this change: Thanks :user:`ddelange` for the PR. +- *Backwards-incompatible*: Remove schema ``context`` property. Passing a context + should be done using a context variable. (issue:`1826`) + marshmallow 4.0 provides an experimental `Context ` + manager class that can be used both to set and retrieve the context. + +.. code-block:: python + + from marshmallow import Schema, fields + from marshmallow.experimental.context import Context + + + def transform_name(obj): + return obj["name"].upper() + Context.get()["suffix"] + + + class UserSchema(Schema): + name = fields.Function(serialize=transform_name) + + + with Context({"suffix": "BAR"}): + UserSchema().dump({"name": "foo"}) + # {'name': 'FOOBAR'} + Deprecations/Removals: - *Backwards-incompatible*: Remove implicit field creation, i.e. using the ``fields`` or ``additional`` class Meta options with undeclared fields (:issue:`1356`). From e638af3c1af8ee10b2c47cef8844ad9e95bfc65f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Sat, 4 Jan 2025 01:28:20 +0100 Subject: [PATCH 18/22] Context: initialize token at __init__ --- src/marshmallow/experimental/context.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/marshmallow/experimental/context.py b/src/marshmallow/experimental/context.py index c7758448c..f235dbf42 100644 --- a/src/marshmallow/experimental/context.py +++ b/src/marshmallow/experimental/context.py @@ -11,12 +11,13 @@ class Context(contextlib.AbstractContextManager, typing.Generic[_T]): def __init__(self, context: _T) -> None: self.context = context + self.token: contextvars.Token | None = None def __enter__(self) -> None: self.token = _CURRENT_CONTEXT.set(self.context) def __exit__(self, *args, **kwargs) -> None: - _CURRENT_CONTEXT.reset(self.token) + _CURRENT_CONTEXT.reset(typing.cast(contextvars.Token, self.token)) @classmethod def get(cls, default=...) -> _T: From 56049597bed6d5479aeb4af02e218974740f5ee0 Mon Sep 17 00:00:00 2001 From: Steven Loria Date: Sun, 5 Jan 2025 00:52:07 -0500 Subject: [PATCH 19/22] Minor edit to upgrading guide --- docs/upgrading.rst | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/upgrading.rst b/docs/upgrading.rst index 74d9ad974..0cf1db48a 100644 --- a/docs/upgrading.rst +++ b/docs/upgrading.rst @@ -170,8 +170,8 @@ if you need to change the final output type. ``pass_many`` is renamed to ``pass_collection`` in decorators ************************************************************* -The ``pass_many`` argument to `pre_load `, -`post_load `, `pre_dump `, +The ``pass_many`` argument to `pre_load `, +`post_load `, `pre_dump `, and `post_dump ` is renamed to ``pass_collection``. The behavior is unchanged. @@ -242,7 +242,7 @@ Upgrading to 3.13 ``load_default`` and ``dump_default`` +++++++++++++++++++++++++++++++++++++ -The ``missing`` and ``default`` parameters of fields are renamed to +The ``missing`` and ``default`` parameters of fields are renamed to ``load_default`` and ``dump_default``, respectively. .. code-block:: python @@ -263,11 +263,11 @@ The ``missing`` and ``default`` parameters of fields are renamed to ``load_default`` and ``dump_default`` are passed to the field constructor as keyword arguments. -Schema Context is Removed +Schema context is removed ************************* -The feature allowing to pass a context to the schema has been removed. Users should -use `contextvars` for that. +Passing context to the schema is no longer supported. Use `contextvars` for passing context to +fields and pre-/post-processing methods instead. marshmallow 4.0 provides an experimental `Context ` manager class that can be used both to set and retrieve the context. From c89c15a9f681cbfed887391a450296bcca5e1163 Mon Sep 17 00:00:00 2001 From: Steven Loria Date: Sun, 5 Jan 2025 01:22:41 -0500 Subject: [PATCH 20/22] Add more documentation for Context --- docs/api_reference.rst | 1 + docs/custom_fields.rst | 40 +++++++++++++++++++++++ docs/marshmallow.experimental.context.rst | 5 +++ src/marshmallow/experimental/context.py | 34 ++++++++++++++++++- 4 files changed, 79 insertions(+), 1 deletion(-) create mode 100644 docs/marshmallow.experimental.context.rst diff --git a/docs/api_reference.rst b/docs/api_reference.rst index 68bc78484..d38e56d2c 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -10,6 +10,7 @@ API Reference marshmallow.decorators marshmallow.validate marshmallow.utils + marshmallow.experimental.context marshmallow.error_store marshmallow.class_registry marshmallow.exceptions diff --git a/docs/custom_fields.rst b/docs/custom_fields.rst index 80a54ccef..e380abdbd 100644 --- a/docs/custom_fields.rst +++ b/docs/custom_fields.rst @@ -94,6 +94,46 @@ Both :class:`Function ` and :class:`Method 100.0 +Using context +------------- + +A field may need information about its environment to know how to (de)serialize a value. + +You can use the experimental `Context ` class +to set and retrieve context. + +As an example, you might want your ``UserSchema`` to output whether or not a ``User`` is the author of a ``Blog`` or whether a certain word appears in a ``Blog's`` title. + +.. code-block:: python + + from marshmallow import Schema, fields + from marshmallow.experimental.context import Context + + + class UserSchema(Schema): + name = fields.String() + + is_author = fields.Function(lambda user: user == Context.get()["blog"].author) + likes_bikes = fields.Method("writes_about_bikes") + + def writes_about_bikes(self, user): + return "bicycle" in Context.get()["blog"].title.lower() + + + schema = UserSchema() + + user = User("Freddie Mercury", "fred@queen.com") + blog = Blog("Bicycle Blog", author=user) + + with Context({"blog": blog}): + result = schema.dump(user) + result["is_author"] # => True + result["likes_bikes"] # => True + +.. note:: + You can use `Context.get ` + within custom fields, pre-/post-processing methods, and validators. + Customizing error messages -------------------------- diff --git a/docs/marshmallow.experimental.context.rst b/docs/marshmallow.experimental.context.rst new file mode 100644 index 000000000..50f8e0e61 --- /dev/null +++ b/docs/marshmallow.experimental.context.rst @@ -0,0 +1,5 @@ +Context (experimental) +====================== + +.. automodule:: marshmallow.experimental.context + :members: diff --git a/src/marshmallow/experimental/context.py b/src/marshmallow/experimental/context.py index f235dbf42..b1422687f 100644 --- a/src/marshmallow/experimental/context.py +++ b/src/marshmallow/experimental/context.py @@ -1,4 +1,29 @@ -"""Helper API for setting serialization/deserialization context.""" +"""Helper API for setting serialization/deserialization context. + +Example usage: + +.. code-block:: python + + import typing + + from marshmallow import Schema, fields + from marshmallow.experimental.context import Context + + + class UserContext(typing.TypedDict): + suffix: str + + + class UserSchema(Schema): + name_suffixed = fields.Function( + lambda user: user["name"] + Context[UserContext].get()["suffix"] + ) + + + with Context({"suffix": "bar"}): + print(UserSchema().dump({"name": "foo"})) + # {'name': 'foobar'} +""" import contextlib import contextvars @@ -9,6 +34,8 @@ class Context(contextlib.AbstractContextManager, typing.Generic[_T]): + """Context manager for setting and retrieving context.""" + def __init__(self, context: _T) -> None: self.context = context self.token: contextvars.Token | None = None @@ -21,6 +48,11 @@ def __exit__(self, *args, **kwargs) -> None: @classmethod def get(cls, default=...) -> _T: + """Get the current context. + + :param default: Default value to return if no context is set. + If not provided and no context is set, a :exc:`LookupError` is raised. + """ if default is not ...: return _CURRENT_CONTEXT.get(default) return _CURRENT_CONTEXT.get() From f1cbe27112cdcef5afd75c30f721287fb2ec26da Mon Sep 17 00:00:00 2001 From: Steven Loria Date: Sun, 5 Jan 2025 10:06:46 -0500 Subject: [PATCH 21/22] More complete examples --- CHANGELOG.rst | 24 ++++--- docs/custom_fields.rst | 48 ++++++++++--- docs/upgrading.rst | 91 +++++++++++++------------ src/marshmallow/experimental/context.py | 4 +- 4 files changed, 103 insertions(+), 64 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index b741d2c37..c5f36afa9 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -18,7 +18,7 @@ Other changes: As a consequence of this change: - Time with time offsets are now supported. - YYYY-MM-DD is now accepted as a datetime and deserialized as naive 00:00 AM. - - `from_iso_date`, `from_iso_time` and `from_iso_datetime` are removed from `marshmallow.utils` + - `from_iso_date`, `from_iso_time` and `from_iso_datetime` are removed from `marshmallow.utils`. - *Backwards-incompatible*: Custom validators must raise a `ValidationError ` for invalid values. Returning `False` is no longer supported (:issue:`1775`). @@ -48,28 +48,32 @@ As a consequence of this change: Thanks :user:`ddelange` for the PR. -- *Backwards-incompatible*: Remove schema ``context`` property. Passing a context - should be done using a context variable. (issue:`1826`) - marshmallow 4.0 provides an experimental `Context ` - manager class that can be used both to set and retrieve the context. +- *Backwards-incompatible*: Remove `Schema `'s ``context`` attribute. Passing a context + should be done using `contextvars.ContextVar` (:issue:`1826`). + marshmallow 4 provides an experimental `Context ` + manager class that can be used to both set and retrieve context. .. code-block:: python + import typing + from marshmallow import Schema, fields from marshmallow.experimental.context import Context - def transform_name(obj): - return obj["name"].upper() + Context.get()["suffix"] + class UserContext(typing.TypedDict): + suffix: str class UserSchema(Schema): - name = fields.Function(serialize=transform_name) + name_suffixed = fields.Function( + lambda obj: obj["name"] + Context[UserContext].get()["suffix"] + ) - with Context({"suffix": "BAR"}): + with Context[UserContext]({"suffix": "bar"}): UserSchema().dump({"name": "foo"}) - # {'name': 'FOOBAR'} + # {'name_suffixed': 'foobar'} Deprecations/Removals: diff --git a/docs/custom_fields.rst b/docs/custom_fields.rst index e380abdbd..5c2550fe0 100644 --- a/docs/custom_fields.rst +++ b/docs/custom_fields.rst @@ -94,6 +94,8 @@ Both :class:`Function ` and :class:`Method 100.0 +.. _using_context: + Using context ------------- @@ -102,37 +104,63 @@ A field may need information about its environment to know how to (de)serialize You can use the experimental `Context ` class to set and retrieve context. -As an example, you might want your ``UserSchema`` to output whether or not a ``User`` is the author of a ``Blog`` or whether a certain word appears in a ``Blog's`` title. +Let's say your ``UserSchema`` needs to output +whether or not a ``User`` is the author of a ``Blog`` or +whether a certain word appears in a ``Blog's`` title. .. code-block:: python + import typing + from dataclasses import dataclass + from marshmallow import Schema, fields from marshmallow.experimental.context import Context + @dataclass + class User: + name: str + + + @dataclass + class Blog: + title: str + author: User + + + class ContextDict(typing.TypedDict): + blog: Blog + + class UserSchema(Schema): name = fields.String() - is_author = fields.Function(lambda user: user == Context.get()["blog"].author) + is_author = fields.Function( + lambda user: user == Context[ContextDict].get()["blog"].author + ) likes_bikes = fields.Method("writes_about_bikes") - def writes_about_bikes(self, user): - return "bicycle" in Context.get()["blog"].title.lower() + def writes_about_bikes(self, user: User) -> bool: + return "bicycle" in Context[ContextDict].get()["blog"].title.lower() + +.. note:: + You can use `Context.get ` + within custom fields, pre-/post-processing methods, and validators. +When (de)serializing, set the context by using `Context ` as a context manager. + +.. code-block:: python - schema = UserSchema() user = User("Freddie Mercury", "fred@queen.com") blog = Blog("Bicycle Blog", author=user) + schema = UserSchema() with Context({"blog": blog}): result = schema.dump(user) - result["is_author"] # => True - result["likes_bikes"] # => True + print(result["is_author"]) # => True + print(result["likes_bikes"]) # => True -.. note:: - You can use `Context.get ` - within custom fields, pre-/post-processing methods, and validators. Customizing error messages -------------------------- diff --git a/docs/upgrading.rst b/docs/upgrading.rst index 0cf1db48a..e26943b27 100644 --- a/docs/upgrading.rst +++ b/docs/upgrading.rst @@ -57,6 +57,55 @@ If you want to use anonymous functions, you can use this helper function. class UserSchema(Schema): password = fields.String(validate=predicate(lambda x: x == "password")) +New context API +*************** + +Passing context to `Schema ` classes is no longer supported. Use `contextvars.ContextVar` for passing context to +fields, pre-/post-processing methods, and validators instead. + +marshmallow 4 provides an experimental `Context ` +manager class that can be used to both set and retrieve context. + +.. code-block:: python + + # 3.x + from marshmallow import Schema, fields + + + class UserSchema(Schema): + name_suffixed = fields.Function( + lambda obj, context: obj["name"] + context["suffix"] + ) + + + user_schema = UserSchema() + user_schema.context = {"suffix": "bar"} + user_schema.dump({"name": "foo"}) + # {'name_suffixed': 'foobar'} + + # 4.x + import typing + + from marshmallow import Schema, fields + from marshmallow.experimental.context import Context + + + class UserContext(typing.TypedDict): + suffix: str + + + class UserSchema(Schema): + name_suffixed = fields.Function( + lambda obj: obj["name"] + Context[UserContext].get()["suffix"] + ) + + + with Context[UserContext]({"suffix": "bar"}): + UserSchema().dump({"name": "foo"}) + # {'name_suffixed': 'foobar'} + +See :ref:`using_context` for more information. + Implicit field creation is removed ********************************** @@ -263,48 +312,6 @@ The ``missing`` and ``default`` parameters of fields are renamed to ``load_default`` and ``dump_default`` are passed to the field constructor as keyword arguments. -Schema context is removed -************************* - -Passing context to the schema is no longer supported. Use `contextvars` for passing context to -fields and pre-/post-processing methods instead. - -marshmallow 4.0 provides an experimental `Context ` -manager class that can be used both to set and retrieve the context. - -.. code-block:: python - - # 3.x - from marshmallow import Schema, fields - - - class UserSchema(Schema): - name = fields.Function( - serialize=lambda obj, context: obj["name"].upper() + context["suffix"] - ) - - - user_schema = UserSchema() - user_schema.context = {"suffix": "BAR"} - user_schema.dump({"name": "foo"}) - # {'name': 'FOOBAR'} - - # 4.x - from marshmallow import Schema, fields - from marshmallow.experimental.context import Context - - - def transform_name(obj): - return obj["name"].upper() + Context.get()["suffix"] - - - class UserSchema(Schema): - name = fields.Function(serialize=transform_name) - - - with Context({"suffix": "BAR"}): - UserSchema().dump({"name": "foo"}) - # {'name': 'FOOBAR'} Upgrading to 3.3 ++++++++++++++++ diff --git a/src/marshmallow/experimental/context.py b/src/marshmallow/experimental/context.py index b1422687f..af1445f51 100644 --- a/src/marshmallow/experimental/context.py +++ b/src/marshmallow/experimental/context.py @@ -20,9 +20,9 @@ class UserSchema(Schema): ) - with Context({"suffix": "bar"}): + with Context[UserContext]({"suffix": "bar"}): print(UserSchema().dump({"name": "foo"})) - # {'name': 'foobar'} + # {'name_suffixed': 'foobar'} """ import contextlib From 63d46aa31f3d5d6b1c260785afe6834349181873 Mon Sep 17 00:00:00 2001 From: Steven Loria Date: Sun, 5 Jan 2025 14:21:39 -0500 Subject: [PATCH 22/22] Exemplify using type aliases for Context --- docs/upgrading.rst | 7 +++++-- src/marshmallow/experimental/context.py | 7 +++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/docs/upgrading.rst b/docs/upgrading.rst index e26943b27..d0c857f06 100644 --- a/docs/upgrading.rst +++ b/docs/upgrading.rst @@ -94,13 +94,16 @@ manager class that can be used to both set and retrieve context. suffix: str + UserSchemaContext = Context[UserContext] + + class UserSchema(Schema): name_suffixed = fields.Function( - lambda obj: obj["name"] + Context[UserContext].get()["suffix"] + lambda obj: obj["name"] + UserSchemaContext.get()["suffix"] ) - with Context[UserContext]({"suffix": "bar"}): + with UserSchemaContext({"suffix": "bar"}): UserSchema().dump({"name": "foo"}) # {'name_suffixed': 'foobar'} diff --git a/src/marshmallow/experimental/context.py b/src/marshmallow/experimental/context.py index af1445f51..bd06d5fb8 100644 --- a/src/marshmallow/experimental/context.py +++ b/src/marshmallow/experimental/context.py @@ -14,13 +14,16 @@ class UserContext(typing.TypedDict): suffix: str + UserSchemaContext = Context[UserContext] + + class UserSchema(Schema): name_suffixed = fields.Function( - lambda user: user["name"] + Context[UserContext].get()["suffix"] + lambda user: user["name"] + UserSchemaContext.get()["suffix"] ) - with Context[UserContext]({"suffix": "bar"}): + with UserSchemaContext({"suffix": "bar"}): print(UserSchema().dump({"name": "foo"})) # {'name_suffixed': 'foobar'} """