Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow saving / loading from Huggingface Hub preset #1510

Merged
merged 6 commits into from
Mar 27, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 56 additions & 5 deletions keras_nlp/utils/preset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,16 @@
except ImportError:
kagglehub = None

try:
import huggingface_hub
from huggingface_hub.utils import HFValidationError
except ImportError:
huggingface_hub = None

KAGGLE_PREFIX = "kaggle://"
GS_PREFIX = "gs://"
HF_PREFIX = "hf://"

TOKENIZER_ASSET_DIR = "assets/tokenizer"
CONFIG_FILE = "config.json"
TOKENIZER_CONFIG_FILE = "tokenizer.json"
Expand Down Expand Up @@ -69,15 +77,33 @@ def get_file(preset, path):
url,
cache_subdir=os.path.join("models", subdir),
)
elif preset.startswith(HF_PREFIX):
if huggingface_hub is None:
raise ImportError(
f"`from_preset()` requires the `huggingface_hub` package to load from '{preset}'. "
"Please install with `pip install huggingface_hub`."
)
hf_handle = preset.removeprefix(HF_PREFIX)
try:
return huggingface_hub.hf_hub_download(
repo_id=hf_handle, filename=path
)
except HFValidationError as e:
raise ValueError(
"Unexpected Hugging Face preset. Hugging Face model handles "
"should have the form 'hf://{org}/{model}'. For example, "
f"'hf://username/bert_base_en'. Received: preset={preset}."
) from e
elif os.path.exists(preset):
# Assume a local filepath.
return os.path.join(preset, path)
mattdangerw marked this conversation as resolved.
Show resolved Hide resolved
else:
raise ValueError(
"Unknown preset identifier. A preset must be a one of:\n"
"1) a built in preset identifier like `'bert_base_en'`\n"
"1) a built-in preset identifier like `'bert_base_en'`\n"
"2) a Kaggle Models handle like `'kaggle://keras/bert/keras/bert_base_en'`\n"
"3) a path to a local preset directory like `'./bert_base_en`\n"
"3) a Hugging Face handle like `'hf://username/bert_base_en'`\n"
"4) a path to a local preset directory like `'./bert_base_en`\n"
"Use `print(cls.presets.keys())` to view all built-in presets for "
"API symbol `cls`.\n"
f"Received: preset='{preset}'"
Expand Down Expand Up @@ -245,7 +271,9 @@ def upload_preset(
uri: The URI identifying model to upload to.
URIs with format
`kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>`
will be uploaded to Kaggle Hub.
will be uploaded to Kaggle Hub while URIs with format
`hf://[<HF_USERNAME>/]<MODEL>` will be uploaded to the Hugging
Face Hub.
preset: The path to the local model preset directory.
allow_incomplete: If True, allows the upload of presets without
a tokenizer configuration. Otherwise, a tokenizer
Expand All @@ -262,10 +290,33 @@ def upload_preset(
if uri.startswith(KAGGLE_PREFIX):
kaggle_handle = uri.removeprefix(KAGGLE_PREFIX)
kagglehub.model_upload(kaggle_handle, preset)
elif uri.startswith(HF_PREFIX):
if huggingface_hub is None:
raise ImportError(
f"`upload_preset()` requires the `huggingface_hub` package to upload to '{uri}'. "
"Please install with `pip install huggingface_hub`."
)
hf_handle = uri.removeprefix(HF_PREFIX)
try:
repo_url = huggingface_hub.create_repo(
repo_id=hf_handle, exist_ok=True
)
except HFValidationError as e:
raise ValueError(
"Unexpected Hugging Face URI. Hugging Face model handles "
"should have the form 'hf://[{org}/]{model}'. For example, "
"'hf://username/bert_base_en' or 'hf://bert_case_en' to implicitly"
f"upload to your user account. Received: URI={uri}."
) from e
huggingface_hub.upload_folder(
repo_id=repo_url.repo_id, folder_path=preset
)
else:
raise ValueError(
f"Unexpected URI `'{uri}'`. Kaggle upload format should follow "
"`kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>`."
"Unknown URI. An URI must be a one of:\n"
"1) a Kaggle Model handle like `'kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>'`\n"
Copy link
Contributor Author

@Wauplin Wauplin Mar 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here I followed the existing message but I find it inconsistent with the error in get_file. In get_file, we provide real examples (e.g. 'kaggle://keras/bert/keras/bert_base_en') while here we only provide the format ('kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>'). Both are fine IMO but if you prefer one or the other, please let me know and I can update in this PR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can merge this as is, and I'll chat with folk later to figuring out our broader naming a push a small fix.

"2) a Hugging Face handle like `'hf://[<HF_USERNAME>/]<MODEL>'`\n"
f"Received: uri='{uri}'."
)


Expand Down
Loading