diff --git a/src/huggingface_hub/hub_mixin.py b/src/huggingface_hub/hub_mixin.py index cbcdd74b90..d2da195014 100644 --- a/src/huggingface_hub/hub_mixin.py +++ b/src/huggingface_hub/hub_mixin.py @@ -1,9 +1,9 @@ import inspect import json import os -from dataclasses import asdict, dataclass, is_dataclass +from dataclasses import Field, asdict, dataclass, is_dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Optional, Protocol, Tuple, Type, TypeVar, Union import packaging.version @@ -25,7 +25,15 @@ if TYPE_CHECKING: - from _typeshed import DataclassInstance + from _typeshed import DataclassInstance # type: ignore +else: + + class DataclassInstance(Protocol): # type: ignore + __dataclass_fields__: ClassVar[Dict[str, Field]] + + +Dataclass = TypeVar("Dataclass", bound=DataclassInstance) +DataclassType = Type[Dataclass] if is_torch_available(): import torch # type: ignore @@ -175,7 +183,7 @@ class ModelHubMixin: ``` """ - _hub_mixin_config: Optional[Union[dict, "DataclassInstance"]] = None + _hub_mixin_config: Optional[Union[dict, DataclassType]] = None # ^ optional config attribute automatically set in `from_pretrained` _hub_mixin_info: MixinInfo # ^ information about the library integrating ModelHubMixin (used to generate model card) @@ -366,7 +374,7 @@ def save_pretrained( self, save_directory: Union[str, Path], *, - config: Optional[Union[dict, "DataclassInstance"]] = None, + config: Optional[Union[dict, DataclassType]] = None, repo_id: Optional[str] = None, push_to_hub: bool = False, model_card_kwargs: Optional[Dict[str, Any]] = None, @@ -618,7 +626,7 @@ def push_to_hub( self, repo_id: str, *, - config: Optional[Union[dict, "DataclassInstance"]] = None, + config: Optional[Union[dict, DataclassType]] = None, commit_message: str = "Push model using huggingface_hub.", private: Optional[bool] = None, token: Optional[str] = None, @@ -825,7 +833,7 @@ def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, stric return model -def _load_dataclass(datacls: Type["DataclassInstance"], data: dict) -> "DataclassInstance": +def _load_dataclass(datacls: DataclassType, data: dict) -> DataclassType: """Load a dataclass instance from a dictionary. Fields not expected by the dataclass are ignored.