diff --git a/toot/entities.py b/toot/entities.py index d41a236a..8ef51e3c 100644 --- a/toot/entities.py +++ b/toot/entities.py @@ -12,7 +12,8 @@ from dataclasses import dataclass, is_dataclass from datetime import date, datetime -from typing import Dict, List, Optional, Type, TypeVar, Union +from functools import lru_cache +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union from typing import get_type_hints from toot.typing_compat import get_args, get_origin @@ -435,17 +436,27 @@ def from_dict(cls: Type[T], data: Dict) -> T: data = prepare(data) def _fields(): - hints = get_type_hints(cls) - for field in dataclasses.fields(cls): - field_type = _prune_optional(hints[field.name]) - default_value = _get_default_value(field) - value = data.get(field.name, default_value) - converted = _convert_with_error_handling(cls, field.name, field_type, value) - yield field.name, converted + for name, type, default in get_fields(cls): + value = data.get(name, default) + converted = _convert_with_error_handling(cls, name, type, value) + yield name, converted return cls(**dict(_fields())) +@lru_cache(maxsize=100) +def get_fields(cls: Type) -> List[Tuple[str, Type, Any]]: + hints = get_type_hints(cls) + return [ + ( + field.name, + _prune_optional(hints[field.name]), + _get_default_value(field) + ) + for field in dataclasses.fields(cls) + ] + + def from_dict_list(cls: Type[T], data: List[Dict]) -> List[T]: return [from_dict(cls, x) for x in data] @@ -497,7 +508,7 @@ def _convert(field_type, value): raise ValueError(f"Not implemented for type '{field_type}'") -def _prune_optional(field_type): +def _prune_optional(field_type: Type) -> Type: """For `Optional[]` returns the encapsulated ``.""" if get_origin(field_type) == Union: args = get_args(field_type)