Skip to content

Commit af835b2

Browse files
committed
add tests
1 parent ff58d42 commit af835b2

File tree

5 files changed

+54
-5
lines changed

5 files changed

+54
-5
lines changed

src/huggingface_hub/hub_mixin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def save_pretrained(
204204
# save model card
205205
model_card_path = save_directory / "README.md"
206206
if not model_card_path.exists(): # do not overwrite if already exists
207-
self._generate_model_card().save(save_directory / "README.md")
207+
self.generate_model_card().save(save_directory / "README.md")
208208

209209
# push to the Hub if required
210210
if push_to_hub:
@@ -466,7 +466,7 @@ def push_to_hub(
466466
delete_patterns=delete_patterns,
467467
)
468468

469-
def _generate_model_card(self, *args, **kwargs) -> ModelCard:
469+
def generate_model_card(self, *args, **kwargs) -> ModelCard:
470470
card = ModelCard.from_template(
471471
card_data=ModelCardData(**asdict(self.library_info)),
472472
template_str=DEFAULT_MODEL_CARD,

src/huggingface_hub/repocard_data.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,10 +309,17 @@ def __init__(
309309
ignore_metadata_errors: bool = False,
310310
**kwargs,
311311
):
312+
unique_tags = tags
313+
if tags is not None:
314+
unique_tags = [] # make tags unique + keep order explicitly
315+
for tag in tags:
316+
if tag not in unique_tags:
317+
unique_tags.append(tag)
318+
312319
self.language = language
313320
self.license = license
314321
self.library_name = library_name
315-
self.tags = tags
322+
self.tags = unique_tags
316323
self.base_model = base_model
317324
self.datasets = datasets
318325
self.metrics = metrics

tests/test_hub_mixin_pytorch.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import pytest
1010

11-
from huggingface_hub import HfApi, hf_hub_download
11+
from huggingface_hub import HfApi, ModelCard, hf_hub_download
1212
from huggingface_hub.constants import PYTORCH_WEIGHTS_NAME
1313
from huggingface_hub.hub_mixin import ModelHubMixin, PyTorchModelHubMixin
1414
from huggingface_hub.utils import EntryNotFoundError, HfHubHTTPError, SoftTemporaryDirectory, is_torch_available
@@ -32,6 +32,14 @@ def __init__(self, **kwargs):
3232
def forward(self, x):
3333
return self.l1(x)
3434

35+
class DummyModelWithTags(nn.Module, PyTorchModelHubMixin, tags=["tag1", "tag2"], library_name="my-dummy-lib"):
36+
def __init__(self, linear_layer: int = 4):
37+
super().__init__()
38+
self.l1 = nn.Linear(linear_layer, linear_layer)
39+
40+
def forward(self, x):
41+
return self.l1(x)
42+
3543
else:
3644
DummyModel = None
3745

@@ -231,3 +239,14 @@ def test_push_to_hub(self):
231239

232240
# Delete repo
233241
self._api.delete_repo(repo_id=repo_id)
242+
243+
def test_generate_model_card(self):
244+
model = DummyModelWithTags()
245+
card = model.generate_model_card()
246+
assert card.data.tags == ["tag1", "tag2", "pytorch_model_hub_mixin", "model_hub_mixin"]
247+
248+
model.save_pretrained(self.cache_dir)
249+
card_reloaded = ModelCard.load(self.cache_dir / "README.md")
250+
251+
assert str(card) == str(card_reloaded)
252+
assert card.data == card_reloaded.data

tests/test_repocard.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,15 @@
178178
Some cool dataset card.
179179
"""
180180

181+
DUMMY_MODEL_CARD_TEMPLATE = """
182+
---
183+
{{ card_data }}
184+
---
185+
186+
Custom template passed as a string.
187+
{{ repo_url | default("[More Information Needed]", true) }}
188+
"""
189+
181190

182191
def require_jinja(test_case):
183192
"""
@@ -586,7 +595,8 @@ def test_repo_card_from_default_template_with_model_id(self):
586595
)
587596

588597
@require_jinja
589-
def test_repo_card_from_custom_template(self):
598+
def test_repo_card_from_custom_template_path(self):
599+
# Template is passed as a path (not a raw string)
590600
template_path = SAMPLE_CARDS_DIR / "sample_template.md"
591601
card = RepoCard.from_template(
592602
card_data=CardData(
@@ -605,6 +615,15 @@ def test_repo_card_from_custom_template(self):
605615
"Custom template didn't set jinja variable correctly",
606616
)
607617

618+
@require_jinja
619+
def test_repo_card_from_custom_template_string(self):
620+
# Template is passed as a raw string (not a path)
621+
card = RepoCard.from_template(
622+
card_data=CardData(language="en", license="mit"),
623+
template_str=DUMMY_MODEL_CARD_TEMPLATE,
624+
)
625+
assert "Custom template passed as a string." in str(card)
626+
608627
def test_repo_card_data_must_be_dict(self):
609628
sample_path = SAMPLE_CARDS_DIR / "sample_invalid_card_data.md"
610629
with pytest.raises(ValueError, match="repo card metadata block should be a dict"):

tests/test_repocard_data.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,10 @@ def test_eval_result_with_incomplete_source(self):
233233
source_name="Open LLM Leaderboard",
234234
)
235235

236+
def test_model_card_unique_tags(self):
237+
data = ModelCardData(tags=["tag2", "tag1", "tag2", "tag3"])
238+
assert data.tags == ["tag2", "tag1", "tag3"]
239+
236240

237241
class DatasetCardDataTest(unittest.TestCase):
238242
def test_train_eval_index_keys_updated(self):

0 commit comments

Comments
 (0)