27
27
except ImportError :
28
28
kagglehub = None
29
29
30
+ try :
31
+ import huggingface_hub
32
+ from huggingface_hub .utils import HFValidationError
33
+ except ImportError :
34
+ huggingface_hub = None
35
+
30
36
KAGGLE_PREFIX = "kaggle://"
31
37
GS_PREFIX = "gs://"
38
+ HF_PREFIX = "hf://"
39
+
32
40
TOKENIZER_ASSET_DIR = "assets/tokenizer"
33
41
CONFIG_FILE = "config.json"
34
42
TOKENIZER_CONFIG_FILE = "tokenizer.json"
@@ -69,15 +77,33 @@ def get_file(preset, path):
69
77
url ,
70
78
cache_subdir = os .path .join ("models" , subdir ),
71
79
)
80
+ elif preset .startswith (HF_PREFIX ):
81
+ if huggingface_hub is None :
82
+ raise ImportError (
83
+ f"`from_preset()` requires the `huggingface_hub` package to load from '{ preset } '. "
84
+ "Please install with `pip install huggingface_hub`."
85
+ )
86
+ hf_handle = preset .removeprefix (HF_PREFIX )
87
+ try :
88
+ return huggingface_hub .hf_hub_download (
89
+ repo_id = hf_handle , filename = path
90
+ )
91
+ except HFValidationError as e :
92
+ raise ValueError (
93
+ "Unexpected Hugging Face preset. Hugging Face model handles "
94
+ "should have the form 'hf://{org}/{model}'. For example, "
95
+ f"'hf://username/bert_base_en'. Received: preset={ preset } ."
96
+ ) from e
72
97
elif os .path .exists (preset ):
73
98
# Assume a local filepath.
74
99
return os .path .join (preset , path )
75
100
else :
76
101
raise ValueError (
77
102
"Unknown preset identifier. A preset must be a one of:\n "
78
- "1) a built in preset identifier like `'bert_base_en'`\n "
103
+ "1) a built- in preset identifier like `'bert_base_en'`\n "
79
104
"2) a Kaggle Models handle like `'kaggle://keras/bert/keras/bert_base_en'`\n "
80
- "3) a path to a local preset directory like `'./bert_base_en`\n "
105
+ "3) a Hugging Face handle like `'hf://username/bert_base_en'`\n "
106
+ "4) a path to a local preset directory like `'./bert_base_en'`\n "
81
107
"Use `print(cls.presets.keys())` to view all built-in presets for "
82
108
"API symbol `cls`.\n "
83
109
f"Received: preset='{ preset } '"
@@ -245,7 +271,9 @@ def upload_preset(
245
271
uri: The URI identifying model to upload to.
246
272
URIs with format
247
273
`kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>`
248
- will be uploaded to Kaggle Hub.
274
+ will be uploaded to Kaggle Hub while URIs with format
275
+ `hf://[<HF_USERNAME>/]<MODEL>` will be uploaded to the Hugging
276
+ Face Hub.
249
277
preset: The path to the local model preset directory.
250
278
allow_incomplete: If True, allows the upload of presets without
251
279
a tokenizer configuration. Otherwise, a tokenizer
@@ -262,10 +290,33 @@ def upload_preset(
262
290
if uri .startswith (KAGGLE_PREFIX ):
263
291
kaggle_handle = uri .removeprefix (KAGGLE_PREFIX )
264
292
kagglehub .model_upload (kaggle_handle , preset )
293
+ elif uri .startswith (HF_PREFIX ):
294
+ if huggingface_hub is None :
295
+ raise ImportError (
296
+ f"`upload_preset()` requires the `huggingface_hub` package to upload to '{ uri } '. "
297
+ "Please install with `pip install huggingface_hub`."
298
+ )
299
+ hf_handle = uri .removeprefix (HF_PREFIX )
300
+ try :
301
+ repo_url = huggingface_hub .create_repo (
302
+ repo_id = hf_handle , exist_ok = True
303
+ )
304
+ except HFValidationError as e :
305
+ raise ValueError (
306
+ "Unexpected Hugging Face URI. Hugging Face model handles "
307
+ "should have the form 'hf://[{org}/]{model}'. For example, "
308
+ "'hf://username/bert_base_en' or 'hf://bert_base_en' to implicitly "
309
+ f"upload to your user account. Received: URI={ uri } ."
310
+ ) from e
311
+ huggingface_hub .upload_folder (
312
+ repo_id = repo_url .repo_id , folder_path = preset
313
+ )
265
314
else :
266
315
raise ValueError (
267
- f"Unexpected URI `'{ uri } '`. Kaggle upload format should follow "
268
- "`kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>`."
316
+ "Unknown URI. A URI must be one of:\n "
317
+ "1) a Kaggle Model handle like `'kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>'`\n "
318
+ "2) a Hugging Face handle like `'hf://[<HF_USERNAME>/]<MODEL>'`\n "
319
+ f"Received: uri='{ uri } '."
269
320
)
270
321
271
322
0 commit comments