-
Notifications
You must be signed in to change notification settings - Fork 289
Allow saving / loading from Huggingface Hub preset #1510
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,8 +27,16 @@ | |
except ImportError: | ||
kagglehub = None | ||
|
||
try: | ||
import huggingface_hub | ||
from huggingface_hub.utils import HFValidationError | ||
except ImportError: | ||
huggingface_hub = None | ||
|
||
KAGGLE_PREFIX = "kaggle://" | ||
GS_PREFIX = "gs://" | ||
HF_PREFIX = "hf://" | ||
|
||
TOKENIZER_ASSET_DIR = "assets/tokenizer" | ||
CONFIG_FILE = "config.json" | ||
TOKENIZER_CONFIG_FILE = "tokenizer.json" | ||
|
@@ -69,15 +77,33 @@ def get_file(preset, path): | |
url, | ||
cache_subdir=os.path.join("models", subdir), | ||
) | ||
elif preset.startswith(HF_PREFIX): | ||
if huggingface_hub is None: | ||
raise ImportError( | ||
f"`from_preset()` requires the `huggingface_hub` package to load from '{preset}'. " | ||
"Please install with `pip install huggingface_hub`." | ||
) | ||
hf_handle = preset.removeprefix(HF_PREFIX) | ||
try: | ||
return huggingface_hub.hf_hub_download( | ||
repo_id=hf_handle, filename=path | ||
) | ||
except HFValidationError as e: | ||
raise ValueError( | ||
"Unexpected Hugging Face preset. Hugging Face model handles " | ||
"should have the form 'hf://{org}/{model}'. For example, " | ||
f"'hf://username/bert_base_en'. Received: preset={preset}." | ||
) from e | ||
elif os.path.exists(preset): | ||
# Assume a local filepath. | ||
return os.path.join(preset, path) | ||
else: | ||
raise ValueError( | ||
"Unknown preset identifier. A preset must be a one of:\n" | ||
"1) a built in preset identifier like `'bert_base_en'`\n" | ||
"1) a built-in preset identifier like `'bert_base_en'`\n" | ||
"2) a Kaggle Models handle like `'kaggle://keras/bert/keras/bert_base_en'`\n" | ||
"3) a path to a local preset directory like `'./bert_base_en`\n" | ||
"3) a Hugging Face handle like `'hf://username/bert_base_en'`\n" | ||
"4) a path to a local preset directory like `'./bert_base_en`\n" | ||
"Use `print(cls.presets.keys())` to view all built-in presets for " | ||
"API symbol `cls`.\n" | ||
f"Received: preset='{preset}'" | ||
|
@@ -245,7 +271,9 @@ def upload_preset( | |
uri: The URI identifying model to upload to. | ||
URIs with format | ||
`kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>` | ||
will be uploaded to Kaggle Hub. | ||
will be uploaded to Kaggle Hub while URIs with format | ||
`hf://[<HF_USERNAME>/]<MODEL>` will be uploaded to the Hugging | ||
Face Hub. | ||
preset: The path to the local model preset directory. | ||
allow_incomplete: If True, allows the upload of presets without | ||
a tokenizer configuration. Otherwise, a tokenizer | ||
|
@@ -262,10 +290,33 @@ def upload_preset( | |
if uri.startswith(KAGGLE_PREFIX): | ||
kaggle_handle = uri.removeprefix(KAGGLE_PREFIX) | ||
kagglehub.model_upload(kaggle_handle, preset) | ||
elif uri.startswith(HF_PREFIX): | ||
if huggingface_hub is None: | ||
raise ImportError( | ||
f"`upload_preset()` requires the `huggingface_hub` package to upload to '{uri}'. " | ||
"Please install with `pip install huggingface_hub`." | ||
) | ||
hf_handle = uri.removeprefix(HF_PREFIX) | ||
try: | ||
repo_url = huggingface_hub.create_repo( | ||
repo_id=hf_handle, exist_ok=True | ||
) | ||
except HFValidationError as e: | ||
raise ValueError( | ||
"Unexpected Hugging Face URI. Hugging Face model handles " | ||
"should have the form 'hf://[{org}/]{model}'. For example, " | ||
"'hf://username/bert_base_en' or 'hf://bert_case_en' to implicitly" | ||
f"upload to your user account. Received: URI={uri}." | ||
) from e | ||
huggingface_hub.upload_folder( | ||
repo_id=repo_url.repo_id, folder_path=preset | ||
) | ||
else: | ||
raise ValueError( | ||
f"Unexpected URI `'{uri}'`. Kaggle upload format should follow " | ||
"`kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>`." | ||
"Unknown URI. An URI must be a one of:\n" | ||
"1) a Kaggle Model handle like `'kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>'`\n" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here I followed the existing message but I find it inconsistent with the error in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we can merge this as is, and I'll chat with folk later to figuring out our broader naming a push a small fix. |
||
"2) a Hugging Face handle like `'hf://[<HF_USERNAME>/]<MODEL>'`\n" | ||
f"Received: uri='{uri}'." | ||
) | ||
|
||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.