Skip to content

Commit

Permalink
Add DataclassInstance for runtime type_checking
Browse files Browse the repository at this point in the history
  • Loading branch information
aliberts committed Jan 3, 2025
1 parent 3618a38 commit 7d5d3dd
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions src/huggingface_hub/hub_mixin.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import inspect
import json
import os
from dataclasses import asdict, dataclass, is_dataclass
from dataclasses import Field, asdict, dataclass, is_dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Optional, Protocol, Tuple, Type, TypeVar, Union

import packaging.version

Expand All @@ -25,7 +25,15 @@


if TYPE_CHECKING:
from _typeshed import DataclassInstance
from _typeshed import DataclassInstance # type: ignore
else:

class DataclassInstance(Protocol): # type: ignore
__dataclass_fields__: ClassVar[Dict[str, Field]]


Dataclass = TypeVar("Dataclass", bound=DataclassInstance)
DataclassType = Type[Dataclass]

if is_torch_available():
import torch # type: ignore
Expand Down Expand Up @@ -175,7 +183,7 @@ class ModelHubMixin:
```
"""

_hub_mixin_config: Optional[Union[dict, "DataclassInstance"]] = None
_hub_mixin_config: Optional[Union[dict, DataclassType]] = None
# ^ optional config attribute automatically set in `from_pretrained`
_hub_mixin_info: MixinInfo
# ^ information about the library integrating ModelHubMixin (used to generate model card)
Expand Down Expand Up @@ -366,7 +374,7 @@ def save_pretrained(
self,
save_directory: Union[str, Path],
*,
config: Optional[Union[dict, "DataclassInstance"]] = None,
config: Optional[Union[dict, DataclassType]] = None,
repo_id: Optional[str] = None,
push_to_hub: bool = False,
model_card_kwargs: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -618,7 +626,7 @@ def push_to_hub(
self,
repo_id: str,
*,
config: Optional[Union[dict, "DataclassInstance"]] = None,
config: Optional[Union[dict, DataclassType]] = None,
commit_message: str = "Push model using huggingface_hub.",
private: Optional[bool] = None,
token: Optional[str] = None,
Expand Down Expand Up @@ -825,7 +833,7 @@ def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, stric
return model


def _load_dataclass(datacls: Type["DataclassInstance"], data: dict) -> "DataclassInstance":
def _load_dataclass(datacls: DataclassType, data: dict) -> DataclassType:
"""Load a dataclass instance from a dictionary.
Fields not expected by the dataclass are ignored.
Expand Down

0 comments on commit 7d5d3dd

Please sign in to comment.