Skip to content

Commit

Permalink
Replace type comments with inline typing
Browse files Browse the repository at this point in the history
  • Loading branch information
sloria committed Jan 4, 2025
1 parent 1a58c1f commit ad5a10b
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 36 deletions.
18 changes: 9 additions & 9 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def __init__(
)

# Collect default error message from self and parent classes
messages = {} # type: dict[str, str]
messages: dict[str, str] = {}
for cls in reversed(self.__class__.__mro__):
messages.update(getattr(cls, "default_error_messages", {}))
messages.update(error_messages or {})
Expand Down Expand Up @@ -919,7 +919,7 @@ class Number(Field):
:param kwargs: The same keyword arguments that :class:`Field` receives.
"""

num_type = float # type: typing.Type
num_type: type = float

#: Default error messages.
default_error_messages = {
Expand Down Expand Up @@ -956,7 +956,7 @@ def _serialize(self, value, attr, obj, **kwargs) -> str | _T | None:
"""Return a string if `self.as_string=True`, otherwise return this field's `num_type`."""
if value is None:
return None
ret = self._format_num(value) # type: _T
ret: _T = self._format_num(value)
return self._to_string(ret) if self.as_string else ret

def _deserialize(self, value, attr, data, **kwargs) -> _T | None:
Expand Down Expand Up @@ -1213,23 +1213,23 @@ class DateTime(Field):
Add timestamp as a format.
"""

SERIALIZATION_FUNCS = {
SERIALIZATION_FUNCS: dict[str, typing.Callable[[typing.Any], str | float]] = {
"iso": utils.isoformat,
"iso8601": utils.isoformat,
"rfc": utils.rfcformat,
"rfc822": utils.rfcformat,
"timestamp": utils.timestamp,
"timestamp_ms": utils.timestamp_ms,
} # type: dict[str, typing.Callable[[typing.Any], str | float]]
}

DESERIALIZATION_FUNCS = {
DESERIALIZATION_FUNCS: dict[str, typing.Callable[[str], typing.Any]] = {
"iso": utils.from_iso_datetime,
"iso8601": utils.from_iso_datetime,
"rfc": utils.from_rfc,
"rfc822": utils.from_rfc,
"timestamp": utils.from_timestamp,
"timestamp_ms": utils.from_timestamp_ms,
} # type: dict[str, typing.Callable[[str], typing.Any]]
}

DEFAULT_FORMAT = "iso"

Expand Down Expand Up @@ -1732,7 +1732,7 @@ class IP(Field):

default_error_messages = {"invalid_ip": "Not a valid IP address."}

DESERIALIZATION_CLASS = None # type: typing.Optional[typing.Type]
DESERIALIZATION_CLASS: type | None = None

def __init__(self, *args, exploded=False, **kwargs):
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -1796,7 +1796,7 @@ class IPInterface(Field):

default_error_messages = {"invalid_ip_interface": "Not a valid IP interface."}

DESERIALIZATION_CLASS = None # type: typing.Optional[typing.Type]
DESERIALIZATION_CLASS: type | None = None

def __init__(self, *args, exploded: bool = False, **kwargs):
super().__init__(*args, **kwargs)
Expand Down
34 changes: 18 additions & 16 deletions src/marshmallow/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def resolve_hooks(cls) -> dict[str, list[tuple[str, bool, dict]]]:
"""
mro = inspect.getmro(cls)

hooks = defaultdict(list) # type: dict[str, list[tuple[str, bool, dict]]]
hooks: dict[str, list[tuple[str, bool, dict]]] = defaultdict(list)

for attr_name in dir(cls):
# Need to look up the actual descriptor, not whatever might be
Expand All @@ -174,7 +174,9 @@ def resolve_hooks(cls) -> dict[str, list[tuple[str, bool, dict]]]:
continue

try:
hook_config = attr.__marshmallow_hook__ # type: dict[str, list[tuple[bool, dict]]]
hook_config: dict[str, list[tuple[bool, dict]]] = (
attr.__marshmallow_hook__
)
except AttributeError:
pass
else:
Expand Down Expand Up @@ -282,7 +284,7 @@ class AlbumSchema(Schema):
`prefix` parameter removed.
"""

TYPE_MAPPING = {
TYPE_MAPPING: dict[type, type[ma_fields.Field]] = {
str: ma_fields.String,
bytes: ma_fields.String,
dt.datetime: ma_fields.DateTime,
Expand All @@ -297,23 +299,23 @@ class AlbumSchema(Schema):
dt.date: ma_fields.Date,
dt.timedelta: ma_fields.TimeDelta,
decimal.Decimal: ma_fields.Decimal,
} # type: dict[type, typing.Type[ma_fields.Field]]
}
#: Overrides for default schema-level error messages
error_messages = {} # type: dict[str, str]
error_messages: dict[str, str] = {}

_default_error_messages = {
_default_error_messages: dict[str, str] = {
"type": "Invalid input type.",
"unknown": "Unknown field.",
} # type: dict[str, str]
}

OPTIONS_CLASS = SchemaOpts # type: type
OPTIONS_CLASS: type = SchemaOpts

set_class = OrderedSet

# These get set by SchemaMeta
opts = None # type: SchemaOpts
_declared_fields = {} # type: dict[str, ma_fields.Field]
_hooks = {} # type: dict[str, list[tuple[str, bool, dict]]]
opts: SchemaOpts
_declared_fields: dict[str, ma_fields.Field] = {}
_hooks: dict[str, list[tuple[str, bool, dict]]] = {}

class Meta:
"""Options object for a Schema.
Expand Down Expand Up @@ -391,9 +393,9 @@ def __init__(
self.context = context or {}
self._normalize_nested_options()
#: Dictionary mapping field_names -> :class:`Field` objects
self.fields = {} # type: dict[str, ma_fields.Field]
self.load_fields = {} # type: dict[str, ma_fields.Field]
self.dump_fields = {} # type: dict[str, ma_fields.Field]
self.fields: dict[str, ma_fields.Field] = {}
self.load_fields: dict[str, ma_fields.Field] = {}
self.dump_fields: dict[str, ma_fields.Field] = {}
self._init_fields()
messages = {}
messages.update(self._default_error_messages)
Expand Down Expand Up @@ -821,7 +823,7 @@ def _do_load(
:return: Deserialized data
"""
error_store = ErrorStore()
errors = {} # type: dict[str, list[str]]
errors: dict[str, list[str]] = {}
many = self.many if many is None else bool(many)
unknown = (
self.unknown
Expand All @@ -838,7 +840,7 @@ def _do_load(
)
except ValidationError as err:
errors = err.normalized_messages()
result = None # type: list | dict | None
result: list | dict | None = None
else:
processed_data = data
if not errors:
Expand Down
18 changes: 9 additions & 9 deletions src/marshmallow/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class Validator(ABC):
add a useful `__repr__` implementation for validators.
"""

error = None # type: str | None
error: str | None = None

def __repr__(self) -> str:
args = self._repr_args()
Expand Down Expand Up @@ -65,7 +65,7 @@ def is_even(value):

def __init__(self, *validators: types.Validator, error: str | None = None):
self.validators = tuple(validators)
self.error = error or self.default_error_message # type: str
self.error: str = error or self.default_error_message

def _repr_args(self) -> str:
return f"validators={self.validators!r}"
Expand Down Expand Up @@ -191,7 +191,7 @@ def __init__(
)
self.relative = relative
self.absolute = absolute
self.error = error or self.default_message # type: str
self.error: str = error or self.default_message
self.schemes = schemes or self.default_schemes
self.require_tld = require_tld

Expand Down Expand Up @@ -250,7 +250,7 @@ class Email(Validator):
default_message = "Not a valid email address."

def __init__(self, *, error: str | None = None):
self.error = error or self.default_message # type: str
self.error: str = error or self.default_message

def _format_error(self, value: str) -> str:
return self.error.format(input=value)
Expand Down Expand Up @@ -436,7 +436,7 @@ class Equal(Validator):

def __init__(self, comparable, *, error: str | None = None):
self.comparable = comparable
self.error = error or self.default_message # type: str
self.error: str = error or self.default_message

def _repr_args(self) -> str:
return f"comparable={self.comparable!r}"
Expand Down Expand Up @@ -477,7 +477,7 @@ def __init__(
self.regex = (
re.compile(regex, flags) if isinstance(regex, (str, bytes)) else regex
)
self.error = error or self.default_message # type: str
self.error: str = error or self.default_message

def _repr_args(self) -> str:
return f"regex={self.regex!r}"
Expand Down Expand Up @@ -514,7 +514,7 @@ class Predicate(Validator):

def __init__(self, method: str, *, error: str | None = None, **kwargs):
self.method = method
self.error = error or self.default_message # type: str
self.error: str = error or self.default_message
self.kwargs = kwargs

def _repr_args(self) -> str:
Expand Down Expand Up @@ -545,7 +545,7 @@ class NoneOf(Validator):
def __init__(self, iterable: typing.Iterable, *, error: str | None = None):
self.iterable = iterable
self.values_text = ", ".join(str(each) for each in self.iterable)
self.error = error or self.default_message # type: str
self.error: str = error or self.default_message

def _repr_args(self) -> str:
return f"iterable={self.iterable!r}"
Expand Down Expand Up @@ -585,7 +585,7 @@ def __init__(
self.choices_text = ", ".join(str(choice) for choice in self.choices)
self.labels = labels if labels is not None else []
self.labels_text = ", ".join(str(label) for label in self.labels)
self.error = error or self.default_message # type: str
self.error: str = error or self.default_message

def _repr_args(self) -> str:
return f"choices={self.choices!r}, labels={self.labels!r}"
Expand Down
4 changes: 2 additions & 2 deletions tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def get_lowername(obj):

class UserSchema(Schema):
name = fields.String()
age = fields.Float() # type: fields.Field
age: fields.Field = fields.Float()
created = fields.DateTime()
created_formatted = fields.DateTime(
format="%Y-%m-%d", attribute="created", dump_only=True
Expand All @@ -193,7 +193,7 @@ class UserSchema(Schema):
homepage = fields.Url()
email = fields.Email()
balance = fields.Decimal()
is_old = fields.Method("get_is_old") # type: fields.Field
is_old: fields.Field = fields.Method("get_is_old")
lowername = fields.Function(get_lowername)
registered = fields.Boolean()
hair_colors = fields.List(fields.Raw)
Expand Down

0 comments on commit ad5a10b

Please sign in to comment.