Skip to content

Commit 316f18c

Browse files
authored
Allow saving / loading from Huggingface Hub preset (#1510)
* first draft * update upload_preset * lint * consistent error messages * lint
1 parent a6700eb commit 316f18c

File tree

1 file changed

+56
-5
lines changed

1 file changed

+56
-5
lines changed

keras_nlp/utils/preset_utils.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,16 @@
2727
except ImportError:
2828
kagglehub = None
2929

30+
try:
31+
import huggingface_hub
32+
from huggingface_hub.utils import HFValidationError
33+
except ImportError:
34+
huggingface_hub = None
35+
3036
KAGGLE_PREFIX = "kaggle://"
3137
GS_PREFIX = "gs://"
38+
HF_PREFIX = "hf://"
39+
3240
TOKENIZER_ASSET_DIR = "assets/tokenizer"
3341
CONFIG_FILE = "config.json"
3442
TOKENIZER_CONFIG_FILE = "tokenizer.json"
@@ -69,15 +77,33 @@ def get_file(preset, path):
6977
url,
7078
cache_subdir=os.path.join("models", subdir),
7179
)
80+
elif preset.startswith(HF_PREFIX):
81+
if huggingface_hub is None:
82+
raise ImportError(
83+
f"`from_preset()` requires the `huggingface_hub` package to load from '{preset}'. "
84+
"Please install with `pip install huggingface_hub`."
85+
)
86+
hf_handle = preset.removeprefix(HF_PREFIX)
87+
try:
88+
return huggingface_hub.hf_hub_download(
89+
repo_id=hf_handle, filename=path
90+
)
91+
except HFValidationError as e:
92+
raise ValueError(
93+
"Unexpected Hugging Face preset. Hugging Face model handles "
94+
"should have the form 'hf://{org}/{model}'. For example, "
95+
f"'hf://username/bert_base_en'. Received: preset={preset}."
96+
) from e
7297
elif os.path.exists(preset):
7398
# Assume a local filepath.
7499
return os.path.join(preset, path)
75100
else:
76101
raise ValueError(
77102
"Unknown preset identifier. A preset must be one of:\n"
78-
"1) a built in preset identifier like `'bert_base_en'`\n"
103+
"1) a built-in preset identifier like `'bert_base_en'`\n"
79104
"2) a Kaggle Models handle like `'kaggle://keras/bert/keras/bert_base_en'`\n"
80-
"3) a path to a local preset directory like `'./bert_base_en`\n"
105+
"3) a Hugging Face handle like `'hf://username/bert_base_en'`\n"
106+
"4) a path to a local preset directory like `'./bert_base_en'`\n"
81107
"Use `print(cls.presets.keys())` to view all built-in presets for "
82108
"API symbol `cls`.\n"
83109
f"Received: preset='{preset}'"
@@ -245,7 +271,9 @@ def upload_preset(
245271
uri: The URI identifying model to upload to.
246272
URIs with format
247273
`kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>`
248-
will be uploaded to Kaggle Hub.
274+
will be uploaded to Kaggle Hub while URIs with format
275+
`hf://[<HF_USERNAME>/]<MODEL>` will be uploaded to the Hugging
276+
Face Hub.
249277
preset: The path to the local model preset directory.
250278
allow_incomplete: If True, allows the upload of presets without
251279
a tokenizer configuration. Otherwise, a tokenizer
@@ -262,10 +290,33 @@ def upload_preset(
262290
if uri.startswith(KAGGLE_PREFIX):
263291
kaggle_handle = uri.removeprefix(KAGGLE_PREFIX)
264292
kagglehub.model_upload(kaggle_handle, preset)
293+
elif uri.startswith(HF_PREFIX):
294+
if huggingface_hub is None:
295+
raise ImportError(
296+
f"`upload_preset()` requires the `huggingface_hub` package to upload to '{uri}'. "
297+
"Please install with `pip install huggingface_hub`."
298+
)
299+
hf_handle = uri.removeprefix(HF_PREFIX)
300+
try:
301+
repo_url = huggingface_hub.create_repo(
302+
repo_id=hf_handle, exist_ok=True
303+
)
304+
except HFValidationError as e:
305+
raise ValueError(
306+
"Unexpected Hugging Face URI. Hugging Face model handles "
307+
"should have the form 'hf://[{org}/]{model}'. For example, "
308+
"'hf://username/bert_base_en' or 'hf://bert_base_en' to implicitly "
309+
f"upload to your user account. Received: URI={uri}."
310+
) from e
311+
huggingface_hub.upload_folder(
312+
repo_id=repo_url.repo_id, folder_path=preset
313+
)
265314
else:
266315
raise ValueError(
267-
f"Unexpected URI `'{uri}'`. Kaggle upload format should follow "
268-
"`kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>`."
316+
"Unknown URI. A URI must be one of:\n"
317+
"1) a Kaggle Model handle like `'kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>'`\n"
318+
"2) a Hugging Face handle like `'hf://[<HF_USERNAME>/]<MODEL>'`\n"
319+
f"Received: uri='{uri}'."
269320
)
270321

271322

0 commit comments

Comments
 (0)