
Commit 256a191

Limit to pydantic<2.x on python3.8 (#1828)
* limit pydantic version on python3.8 to 1.x
* fix
* fix comma
* light fix/refacto for keras tests
1 parent a1b90fa commit 256a191

File tree: 4 files changed, +19 −29 lines

.github/workflows/python-tests.yml
setup.py
src/huggingface_hub/keras_mixin.py
tests/test_keras_integration.py

.github/workflows/python-tests.yml

Lines changed: 0 additions & 1 deletion
```diff
@@ -73,7 +73,6 @@ jobs:
             sudo apt update
             sudo apt install -y graphviz
             pip install .[tensorflow]
-            pip install typing_extensions>=4.8.0
             ;;
 
           esac
```

setup.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -30,7 +30,11 @@ def get_version() -> str:
 
 extras["inference"] = [
     "aiohttp",  # for AsyncInferenceClient
-    "pydantic>1.1,<3.0",  # match text-generation-inference v1.1.0
+    # On Python 3.8, Pydantic 2.x and tensorflow don't play well together
+    # Let's limit pydantic to 1.x for now. Since Tensorflow 2.14, Python3.8 is not supported anyway so impact should be
+    # limited. We still trigger some CIs on Python 3.8 so we need this workaround.
+    "pydantic>1.1,<3.0; python_version>'3.8'",
+    "pydantic>1.1,<2.0; python_version=='3.8'",
 ]
 
 extras["torch"] = [
@@ -76,7 +80,6 @@ def get_version() -> str:
     "types-toml",
     "types-tqdm",
     "types-urllib3",
-    "pydantic>1.1,<3.0",  # for text-generation-interface dataclasses
 ]
 
 extras["quality"] = [
```

src/huggingface_hub/keras_mixin.py

Lines changed: 7 additions & 8 deletions
```diff
@@ -92,13 +92,18 @@ def _create_model_card(
 ):
     """
     Creates a model card for the repository.
+
+    Do not overwrite an existing README.md file.
     """
+    readme_path = repo_dir / "README.md"
+    if readme_path.exists():
+        return
+
     hyperparameters = _create_hyperparameter_table(model)
     if plot_model and is_graphviz_available() and is_pydot_available():
         _plot_network(model, repo_dir)
     if metadata is None:
         metadata = {}
-    readme_path = f"{repo_dir}/README.md"
     metadata["library_name"] = "keras"
     model_card: str = "---\n"
     model_card += yaml_dump(metadata, default_flow_style=False)
@@ -120,13 +125,7 @@ def _create_model_card(
         model_card += f"\n![Model Image]({path_to_plot})\n"
         model_card += "\n</details>"
 
-    if os.path.exists(readme_path):
-        with open(readme_path, "r", encoding="utf8") as f:
-            readme = f.read()
-    else:
-        readme = model_card
-    with open(readme_path, "w", encoding="utf-8") as f:
-        f.write(readme)
+    readme_path.write_text(model_card)
 
 
 def save_pretrained_keras(
```
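Beyond the early return, the refactor drops the manual open/read/write dance in favour of `pathlib.Path`. A small standalone sketch of the resulting behaviour, with a hypothetical helper name chosen only for illustration:

```python
# Illustrative only: mirror the new guard + pathlib write behaviour.
from pathlib import Path
import tempfile

def write_card_if_missing(repo_dir: Path, model_card: str) -> bool:
    """Write the card unless a README.md already exists (hypothetical helper)."""
    readme_path = repo_dir / "README.md"
    if readme_path.exists():
        return False  # existing README.md is left untouched
    readme_path.write_text(model_card)
    return True

with tempfile.TemporaryDirectory() as tmp:
    repo_dir = Path(tmp)
    print(write_card_if_missing(repo_dir, "---\nlibrary_name: keras\n---\n"))  # True: card written
    print(write_card_if_missing(repo_dir, "something else"))  # False: first card is kept
```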

tests/test_keras_integration.py

Lines changed: 7 additions & 18 deletions
```diff
@@ -1,6 +1,5 @@
 import json
 import os
-import re
 import unittest
 from pathlib import Path
 
@@ -13,25 +12,14 @@
     push_to_hub_keras,
     save_pretrained_keras,
 )
-from huggingface_hub.utils import (
-    is_graphviz_available,
-    is_pydot_available,
-    is_tf_available,
-    logging,
-)
+from huggingface_hub.utils import is_graphviz_available, is_pydot_available, is_tf_available, logging
 
 from .testing_constants import ENDPOINT_STAGING, TOKEN, USER
-from .testing_utils import (
-    repo_name,
-)
+from .testing_utils import repo_name
 
 
 logger = logging.get_logger(__name__)
 
-WORKING_REPO_SUBDIR = f"fixtures/working_repo_{__name__.split('.')[-1]}"
-WORKING_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), WORKING_REPO_SUBDIR)
-
-PUSH_TO_HUB_KERAS_WARNING_REGEX = re.escape("Deprecated argument(s) used in 'push_to_hub_keras':")
 
 if is_tf_available():
     import tensorflow as tf
@@ -259,8 +247,9 @@ def test_push_to_hub_keras_sequential_via_http_basic(self):
         push_to_hub_keras(model, repo_id=repo_id, token=TOKEN, api_endpoint=ENDPOINT_STAGING)
         model_info = self._api.model_info(repo_id)
         self.assertEqual(model_info.modelId, repo_id)
-        self.assertTrue("README.md" in [f.rfilename for f in model_info.siblings])
-        self.assertTrue("model.png" in [f.rfilename for f in model_info.siblings])
+        repo_files = self._api.list_repo_files(repo_id)
+        self.assertIn("README.md", repo_files)
+        self.assertIn("model.png", repo_files)
         self._api.delete_repo(repo_id=repo_id)
 
     def test_push_to_hub_keras_sequential_via_http_plot_false(self):
@@ -269,8 +258,8 @@ def test_push_to_hub_keras_sequential_via_http_plot_false(self):
         model = self.model_fit(model)
 
         push_to_hub_keras(model, repo_id=repo_id, token=TOKEN, api_endpoint=ENDPOINT_STAGING, plot_model=False)
-        model_info = self._api.model_info(repo_id)
-        self.assertFalse("model.png" in [f.rfilename for f in model_info.siblings])
+        repo_files = self._api.list_repo_files(repo_id)
+        self.assertNotIn("model.png", repo_files)
         self._api.delete_repo(repo_id=repo_id)
 
     def test_push_to_hub_keras_via_http_override_tensorboard(self):
```
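The rewritten assertions go through `HfApi.list_repo_files`, which returns a flat list of filenames, instead of collecting `rfilename` attributes from `ModelInfo.siblings`. An illustrative, standalone comparison of the two lookups (the repo id is only an example and the calls need network access):

```python
# Illustrative only: two ways to check that a file exists in a Hub repo.
from huggingface_hub import HfApi

api = HfApi()
repo_id = "gpt2"  # any public model repo works as an example

# Old style (removed above): walk the siblings listed on the model info object.
model_info = api.model_info(repo_id)
has_readme_old = "README.md" in [f.rfilename for f in model_info.siblings]

# New style: list_repo_files returns a plain list of filenames.
repo_files = api.list_repo_files(repo_id)
has_readme_new = "README.md" in repo_files

print(has_readme_old, has_readme_new)  # both True for a repo with a README.md
```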
