-
Notifications
You must be signed in to change notification settings - Fork 251
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow saving / loading from Huggingface Hub preset #1510
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,8 +24,14 @@ | |
except ImportError: | ||
kagglehub = None | ||
|
||
try: | ||
import huggingface_hub | ||
except ImportError: | ||
huggingface_hub = None | ||
|
||
KAGGLE_PREFIX = "kaggle://" | ||
GS_PREFIX = "gs://" | ||
HF_PREFIX = "hf://" | ||
TOKENIZER_ASSET_DIR = "assets/tokenizer" | ||
|
||
|
||
|
@@ -64,6 +70,14 @@ def get_file(preset, path): | |
url, | ||
cache_subdir=os.path.join("models", subdir), | ||
) | ||
elif preset.startswith(HF_PREFIX): | ||
if huggingface_hub is None: | ||
raise ImportError( | ||
f"`from_preset()` requires the `huggingface_hub` package to load from '{preset}'. " | ||
"Please install with `pip install huggingface_hub`." | ||
) | ||
hf_handle = preset.removeprefix(HF_PREFIX) | ||
return huggingface_hub.hf_hub_download(repo_id=hf_handle, filename=path) | ||
elif os.path.exists(preset): | ||
# Assume a local filepath. | ||
return os.path.join(preset, path) | ||
mattdangerw marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
@@ -109,6 +123,9 @@ def save_to_preset( | |
weights_filename="model.weights.h5", | ||
): | ||
"""Save a KerasNLP layer to a preset directory.""" | ||
push_to_hf = preset.startswith(HF_PREFIX) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here we might end up doing some tweaks. We want to support a split flow for saving preprocessing and model weights. So our final flow might need to allow something like this tokenizer.save_to_preset(dir)
backbone.save_to_preset(dir)
upload_preset("hf://user/model", dir) I don't think we need to solve that here though! @SamanehSaadat is working on a draft of our upload flow currently. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Makes perfect sense! |
||
preset = preset.removeprefix(HF_PREFIX) | ||
|
||
os.makedirs(preset, exist_ok=True) | ||
|
||
# Save tokenizers assets. | ||
|
@@ -154,6 +171,16 @@ def save_to_preset( | |
with open(metadata_path, "w") as metadata_file: | ||
metadata_file.write(json.dumps(metadata, indent=4)) | ||
|
||
# If preset starts with `hf://`, push to the Hugging Face Hub. | ||
if push_to_hf: | ||
if huggingface_hub is None: | ||
raise ImportError( | ||
f"`save_to_preset()` requires the `huggingface_hub` package to save to '{preset}'. " | ||
"Please install with `pip install huggingface_hub`." | ||
) | ||
repo_url = huggingface_hub.create_repo(repo_id=preset, exist_ok=True) | ||
huggingface_hub.upload_folder(repo_id=repo_url.repo_id, folder_path=preset) | ||
|
||
|
||
def load_from_preset( | ||
preset, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do the error messages look like if a handle is unformed? Will it read well enough, or should we validate roughly here so we can have a message similar to the Kaggle error above?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At this point,
hf_handle
must correspond to a repo_id so something in the formusername/repo_name
(e.g. "google/gemma-7b") since the hf prefix has been removed. If it's not the case, anHFValidationError
(which is a customValueError
) is raised. Here are the validation rules we are checking.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just pushed 24ef262 to raise a ValueError similar to the kaggle handle one. I tried to be as consistent as possible. Please let me know what you think :)