Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
hanouticelina committed Jan 6, 2025
1 parent 8bbd0e1 commit a5c49e4
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 13 deletions.
22 changes: 10 additions & 12 deletions src/huggingface_hub/hub_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,6 @@
)


# Type alias for dataclass instances
class DataclassInstance(Protocol):
__dataclass_fields__: ClassVar[Dict[str, Field[Any]]]


Dataclass = TypeVar("Dataclass", bound=DataclassInstance)
DataclassType = Type[Dataclass]

if is_torch_available():
import torch # type: ignore

Expand All @@ -43,6 +35,12 @@ class DataclassInstance(Protocol):

logger = logging.get_logger(__name__)


# Type alias for dataclass instances, copied from https://github.com/python/typeshed/blob/9f28171658b9ca6c32a7cb93fbb99fc92b17858b/stdlib/_typeshed/__init__.pyi#L349
class DataclassInstance(Protocol):
__dataclass_fields__: ClassVar[Dict[str, Field[Any]]]


# Generic variable that is either ModelHubMixin or a subclass thereof
T = TypeVar("T", bound="ModelHubMixin")
# Generic variable to represent an args type
Expand Down Expand Up @@ -180,7 +178,7 @@ class ModelHubMixin:
```
"""

_hub_mixin_config: Optional[Union[dict, DataclassType]] = None
_hub_mixin_config: Optional[Union[dict, DataclassInstance]] = None
# ^ optional config attribute automatically set in `from_pretrained`
_hub_mixin_info: MixinInfo
# ^ information about the library integrating ModelHubMixin (used to generate model card)
Expand Down Expand Up @@ -371,7 +369,7 @@ def save_pretrained(
self,
save_directory: Union[str, Path],
*,
config: Optional[Union[dict, DataclassType]] = None,
config: Optional[Union[dict, DataclassInstance]] = None,
repo_id: Optional[str] = None,
push_to_hub: bool = False,
model_card_kwargs: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -623,7 +621,7 @@ def push_to_hub(
self,
repo_id: str,
*,
config: Optional[Union[dict, DataclassType]] = None,
config: Optional[Union[dict, DataclassInstance]] = None,
commit_message: str = "Push model using huggingface_hub.",
private: Optional[bool] = None,
token: Optional[str] = None,
Expand Down Expand Up @@ -830,7 +828,7 @@ def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, stric
return model


def _load_dataclass(datacls: DataclassType, data: dict) -> DataclassType:
def _load_dataclass(datacls: Type[DataclassInstance], data: dict) -> DataclassInstance:
"""Load a dataclass instance from a dictionary.
Fields not expected by the dataclass are ignored.
Expand Down
23 changes: 22 additions & 1 deletion tests/test_hub_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import unittest
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional, Union
from typing import Dict, Optional, Union, get_type_hints
from unittest.mock import Mock, patch

import jedi
Expand Down Expand Up @@ -474,3 +474,24 @@ def dummy_example_for_test(self, x: str) -> str:
source_lines = source.split("\n")
completions = script.complete(len(source_lines), len(source_lines[-1]))
assert any(completion.name == "dummy_example_for_test" for completion in completions)

def test_get_type_hints_works_as_expected(self):
"""
Ensure that `typing.get_type_hints` works as expected when inheriting from `ModelHubMixin`.
See https://github.com/huggingface/huggingface_hub/issues/2727.
"""

class ModelWithHints(ModelHubMixin):
def method_with_hints(self, x: int) -> str:
return str(x)

assert get_type_hints(ModelWithHints) != {}

# Test method type hints on class
hints = get_type_hints(ModelWithHints.method_with_hints)
assert hints == {"x": int, "return": str}

# Test method type hints on instance
model = ModelWithHints()
assert get_type_hints(model.method_with_hints) == {"x": int, "return": str}

0 comments on commit a5c49e4

Please sign in to comment.