Skip to content

Commit

Permalink
Improve from_dict performance by caching fields
Browse files Browse the repository at this point in the history
  • Loading branch information
ihabunek committed Nov 26, 2023
1 parent 48d9cae commit 1c5abb8
Showing 1 changed file with 20 additions and 9 deletions.
29 changes: 20 additions & 9 deletions toot/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

from dataclasses import dataclass, is_dataclass
from datetime import date, datetime
from typing import Dict, List, Optional, Type, TypeVar, Union
from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
from typing import get_type_hints

from toot.typing_compat import get_args, get_origin
Expand Down Expand Up @@ -435,17 +436,27 @@ def from_dict(cls: Type[T], data: Dict) -> T:
data = prepare(data)

def _fields():
hints = get_type_hints(cls)
for field in dataclasses.fields(cls):
field_type = _prune_optional(hints[field.name])
default_value = _get_default_value(field)
value = data.get(field.name, default_value)
converted = _convert_with_error_handling(cls, field.name, field_type, value)
yield field.name, converted
for name, type, default in get_fields(cls):
value = data.get(name, default)
converted = _convert_with_error_handling(cls, name, type, value)
yield name, converted

return cls(**dict(_fields()))


@lru_cache(maxsize=100)
def get_fields(cls: Type) -> List[Tuple[str, Type, Any]]:
hints = get_type_hints(cls)
return [
(
field.name,
_prune_optional(hints[field.name]),
_get_default_value(field)
)
for field in dataclasses.fields(cls)
]


def from_dict_list(cls: Type[T], data: List[Dict]) -> List[T]:
return [from_dict(cls, x) for x in data]

Expand Down Expand Up @@ -497,7 +508,7 @@ def _convert(field_type, value):
raise ValueError(f"Not implemented for type '{field_type}'")


def _prune_optional(field_type):
def _prune_optional(field_type: Type) -> Type:
"""For `Optional[<type>]` returns the encapsulated `<type>`."""
if get_origin(field_type) == Union:
args = get_args(field_type)
Expand Down

0 comments on commit 1c5abb8

Please sign in to comment.