Allow saving / loading from Huggingface Hub preset #1510

Merged
merged 6 commits into from
Mar 27, 2024
Changes from 4 commits
24 changes: 24 additions & 0 deletions keras_nlp/utils/preset_utils.py
@@ -27,8 +27,15 @@
except ImportError:
    kagglehub = None

try:
    import huggingface_hub
except ImportError:
    huggingface_hub = None

KAGGLE_PREFIX = "kaggle://"
GS_PREFIX = "gs://"
HF_PREFIX = "hf://"

TOKENIZER_ASSET_DIR = "assets/tokenizer"
CONFIG_FILE = "config.json"
TOKENIZER_CONFIG_FILE = "tokenizer.json"
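
The optional-import pattern in this hunk (bind the module name to `None` when the package is missing, then fail only when the feature is actually used) can be sketched in isolation. This is a minimal illustration, not the code in this PR: `optional_import` and `require` are hypothetical helper names, and the error wording is only modeled on the messages in the diff.

```python
import importlib


def optional_import(name):
    """Return the imported module, or None if it is not installed."""
    try:
        return importlib.import_module(name)
    except ImportError:
        return None


def require(module, package, caller):
    """Raise a readable ImportError if an optional module is missing.

    `module` is the result of `optional_import`, `package` is the pip
    package name, and `caller` is the public function that needs it.
    All three names are hypothetical, for illustration only.
    """
    if module is None:
        raise ImportError(
            f"`{caller}` requires the `{package}` package. "
            f"Please install with `pip install {package}`."
        )
    return module


# A standard-library module is always importable, so this succeeds.
json_module = require(optional_import("json"), "json", "from_preset()")
```

Deferring the error to call time means that users who only load Kaggle or local presets never need `huggingface_hub` installed.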
@@ -69,6 +76,14 @@ def get_file(preset, path):
            url,
            cache_subdir=os.path.join("models", subdir),
        )
    elif preset.startswith(HF_PREFIX):
        if huggingface_hub is None:
            raise ImportError(
                f"`from_preset()` requires the `huggingface_hub` package to load from '{preset}'. "
                "Please install with `pip install huggingface_hub`."
            )
        hf_handle = preset.removeprefix(HF_PREFIX)
        return huggingface_hub.hf_hub_download(repo_id=hf_handle, filename=path)
Member

What do the error messages look like if a handle is malformed? Will they read well enough, or should we validate roughly here so we can have a message similar to the Kaggle error above?

Contributor Author

At this point, hf_handle must be a repo_id of the form username/repo_name (e.g. "google/gemma-7b"), since the hf:// prefix has already been removed. If it is not, an HFValidationError (a custom ValueError) is raised. Here are the validation rules we are checking.

Contributor Author

I just pushed 24ef262 to raise a ValueError similar to the one for Kaggle handles. I tried to be as consistent as possible. Please let me know what you think :)
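
For readers following the thread, a rough pre-check of the handle shape might look like the sketch below. It is not the code from 24ef262: `parse_hf_handle` is a hypothetical name, the error wording is invented, and the "exactly one slash" rule is an assumption based on the username/repo_name form described above, so it may differ from the validation rules that huggingface_hub itself applies.

```python
HF_PREFIX = "hf://"


def parse_hf_handle(preset):
    """Strip the hf:// prefix and roughly check the handle shape.

    Hypothetical helper: it assumes a handle must look like
    `username/repo_name`, which is an assumption of this sketch.
    """
    if not preset.startswith(HF_PREFIX):
        raise ValueError(f"Expected a preset starting with '{HF_PREFIX}'.")
    hf_handle = preset[len(HF_PREFIX):]
    parts = hf_handle.split("/")
    # Require exactly two non-empty segments: username and repo name.
    if len(parts) != 2 or not all(parts):
        raise ValueError(
            "Unexpected Hugging Face preset. Hugging Face model handles "
            f"should have the form 'hf://{{org}}/{{model}}'. "
            f"Received: preset='{preset}'."
        )
    return hf_handle


handle = parse_hf_handle("hf://google/gemma-7b")
```

Checking early keeps the message in the library's own wording and points at the original preset string, rather than surfacing only the stripped repo_id.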

    elif os.path.exists(preset):
        # Assume a local filepath.
        return os.path.join(preset, path)
mattdangerw marked this conversation as resolved.
@@ -262,6 +277,15 @@ def upload_preset(
    if uri.startswith(KAGGLE_PREFIX):
        kaggle_handle = uri.removeprefix(KAGGLE_PREFIX)
        kagglehub.model_upload(kaggle_handle, preset)
    elif uri.startswith(HF_PREFIX):
        if huggingface_hub is None:
            raise ImportError(
                f"`upload_preset()` requires the `huggingface_hub` package to upload to '{uri}'. "
                "Please install with `pip install huggingface_hub`."
            )
        hf_handle = uri.removeprefix(HF_PREFIX)
        repo_url = huggingface_hub.create_repo(repo_id=hf_handle, exist_ok=True)
        huggingface_hub.upload_folder(repo_id=repo_url.repo_id, folder_path=preset)
    else:
        raise ValueError(
            f"Unexpected URI `'{uri}'`. Kaggle upload format should follow "
Member

Same here, might want to reword this error message.

Side note: we are kinda inconsistent in how we refer to these model handles. I'm not sure if we should call these URIs or something else, but we should be consistent in our wording. No need to fix on this PR.

Contributor Author

I've updated the error message in 24ef262. I kept the uri naming to be consistent with the rest of the logic, but I agree it would be good to harmonize the naming. I think the inconsistency comes from the fact that in get_file, preset refers to either a local directory or a URI, while in upload_preset, preset refers to a local directory and uri refers to the URI. So maybe not so inconsistent after all?
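
To make the distinction concrete, here is a small sketch of how a preset string in get_file can be either a prefixed URI or a local directory. The prefix constants are the ones from the diff above; `preset_source` is a hypothetical function name that does not exist in keras_nlp, and the "unknown" fallback is an assumption of this sketch.

```python
import os

KAGGLE_PREFIX = "kaggle://"
GS_PREFIX = "gs://"
HF_PREFIX = "hf://"


def preset_source(preset):
    """Classify a preset string by where it would be loaded from.

    Hypothetical helper: checks the URI prefixes first, then falls
    back to treating the string as a local directory if it exists.
    """
    prefixes = (
        ("kaggle", KAGGLE_PREFIX),
        ("gs", GS_PREFIX),
        ("hf", HF_PREFIX),
    )
    for name, prefix in prefixes:
        if preset.startswith(prefix):
            return name
    if os.path.exists(preset):
        return "local"
    return "unknown"


source = preset_source("hf://google/gemma-7b")
```

In upload_preset the two roles are split across two arguments, so only the uri argument would go through the prefix checks.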
