Allow saving / loading from Huggingface Hub preset #1510

Merged
merged 6 commits into from
Mar 27, 2024
27 changes: 27 additions & 0 deletions keras_nlp/utils/preset_utils.py
@@ -24,8 +24,14 @@
except ImportError:
    kagglehub = None

try:
    import huggingface_hub
except ImportError:
    huggingface_hub = None

KAGGLE_PREFIX = "kaggle://"
GS_PREFIX = "gs://"
HF_PREFIX = "hf://"
TOKENIZER_ASSET_DIR = "assets/tokenizer"


@@ -64,6 +70,14 @@ def get_file(preset, path):
            url,
            cache_subdir=os.path.join("models", subdir),
        )
    elif preset.startswith(HF_PREFIX):
        if huggingface_hub is None:
            raise ImportError(
                f"`from_preset()` requires the `huggingface_hub` package to load from '{preset}'. "
                "Please install with `pip install huggingface_hub`."
            )
        hf_handle = preset.removeprefix(HF_PREFIX)
        return huggingface_hub.hf_hub_download(repo_id=hf_handle, filename=path)
Member

What do the error messages look like if a handle is malformed? Will they read well enough, or should we do some rough validation here so we can raise a message similar to the Kaggle error above?

Contributor Author

At this point the hf:// prefix has been removed, so hf_handle must be a valid repo_id, i.e. something of the form username/repo_name (e.g. "google/gemma-7b"). If it is not, an HFValidationError (which is a custom ValueError) is raised. Here are the validation rules we are checking.

Contributor Author

I just pushed 24ef262 to raise a ValueError similar to the one for Kaggle handles. I tried to be as consistent as possible. Please let me know what you think :)
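
For readers following the thread, here is a minimal sketch of what such a rough handle check could look like. It is not the code from 24ef262: the helper name `parse_hf_handle` and the exact error wording are assumptions made for illustration, and the `user/repo` check is deliberately simple and may be stricter than the Hub's own validation rules.

```python
HF_PREFIX = "hf://"


def parse_hf_handle(preset):
    """Hypothetical helper: return `user/repo` from `hf://user/repo`, or raise."""
    if not preset.startswith(HF_PREFIX):
        raise ValueError(
            f"Expected a preset starting with '{HF_PREFIX}'. "
            f"Received: preset={preset}"
        )
    # Slice instead of str.removeprefix so the sketch also runs on Python < 3.9.
    hf_handle = preset[len(HF_PREFIX):]
    parts = hf_handle.split("/")
    # Assumption: exactly two non-empty segments, `{org}/{model}`.
    if len(parts) != 2 or not all(parts):
        raise ValueError(
            "Unexpected Hugging Face preset. Hugging Face model handles "
            "should have the form `hf://{org}/{model}`. "
            f"Received: preset={preset}"
        )
    return hf_handle
```

Failing early like this keeps the message consistent with the Kaggle branch; anything the check lets through would still be validated by `huggingface_hub` itself when the download is attempted.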

    elif os.path.exists(preset):
        # Assume a local filepath.
        return os.path.join(preset, path)
mattdangerw marked this conversation as resolved.
@@ -109,6 +123,9 @@ def save_to_preset(
    weights_filename="model.weights.h5",
):
    """Save a KerasNLP layer to a preset directory."""
    push_to_hf = preset.startswith(HF_PREFIX)
Member

We might end up making some tweaks here. We want to support a split flow for saving preprocessing and model weights, so our final flow might need to allow something like this:

tokenizer.save_to_preset(dir)
backbone.save_to_preset(dir)
upload_preset("hf://user/model", dir)

I don't think we need to solve that here though! @SamanehSaadat is working on a draft of our upload flow currently.

Contributor Author

Makes perfect sense!
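
The split flow discussed above could be sketched roughly as follows. Everything here is an assumption for illustration: `upload_preset` did not exist at the time of this PR, the `upload_fn` parameter is a hypothetical hook added only so the example runs offline, and the two JSON files stand in for what `tokenizer.save_to_preset(dir)` and `backbone.save_to_preset(dir)` would write. The default upload path reuses the two `huggingface_hub` calls that appear in this diff.

```python
import json
import os
import tempfile

HF_PREFIX = "hf://"


def upload_preset(uri, preset_dir, upload_fn=None):
    """Hypothetical helper: push an already-saved preset directory to the Hub."""
    if not uri.startswith(HF_PREFIX):
        raise ValueError(
            f"Expected a URI starting with '{HF_PREFIX}'. Received: {uri}"
        )
    repo_id = uri[len(HF_PREFIX):]
    if upload_fn is None:
        # Default path: the same huggingface_hub calls used in this diff.
        import huggingface_hub

        repo_url = huggingface_hub.create_repo(repo_id=repo_id, exist_ok=True)
        huggingface_hub.upload_folder(
            repo_id=repo_url.repo_id, folder_path=preset_dir
        )
    else:
        # Injected uploader (assumption), used here to avoid network access.
        upload_fn(repo_id, preset_dir)
    return repo_id


uploaded = []  # Records what the fake uploader was asked to push.
with tempfile.TemporaryDirectory() as preset_dir:
    # Stand-in for tokenizer.save_to_preset(preset_dir).
    with open(os.path.join(preset_dir, "tokenizer.json"), "w") as f:
        json.dump({"stand_in": "tokenizer"}, f)
    # Stand-in for backbone.save_to_preset(preset_dir).
    with open(os.path.join(preset_dir, "config.json"), "w") as f:
        json.dump({"stand_in": "backbone"}, f)
    # A single upload step, run only after both pieces are on disk.
    upload_preset(
        "hf://user/model",
        preset_dir,
        upload_fn=lambda repo_id, folder: uploaded.append(
            (repo_id, sorted(os.listdir(folder)))
        ),
    )
```

Separating the upload from the save means each component only writes local files, and the decision about where the directory goes is made once at the end.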

    preset = preset.removeprefix(HF_PREFIX)

    os.makedirs(preset, exist_ok=True)

    # Save tokenizers assets.
@@ -154,6 +171,16 @@ def save_to_preset(
    with open(metadata_path, "w") as metadata_file:
        metadata_file.write(json.dumps(metadata, indent=4))

    # If preset starts with `hf://`, push to the Hugging Face Hub.
    if push_to_hf:
        if huggingface_hub is None:
            raise ImportError(
                f"`save_to_preset()` requires the `huggingface_hub` package to save to '{preset}'. "
                "Please install with `pip install huggingface_hub`."
            )
        repo_url = huggingface_hub.create_repo(repo_id=preset, exist_ok=True)
        huggingface_hub.upload_folder(repo_id=repo_url.repo_id, folder_path=preset)
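
One consequence of the hunks above is that, after the prefix is stripped, the same string serves as both the local output directory and the Hub repo id. The sketch below only mirrors that string handling; `strip_prefix` and `plan_save` are hypothetical names, not part of `preset_utils.py`. The slicing helper is used because `str.removeprefix` is only available on Python 3.9 and later.

```python
HF_PREFIX = "hf://"


def strip_prefix(text, prefix):
    """Equivalent of str.removeprefix that also works on Python < 3.9."""
    return text[len(prefix):] if text.startswith(prefix) else text


def plan_save(preset):
    """Hypothetical helper: show where files are written and what is pushed."""
    push_to_hf = preset.startswith(HF_PREFIX)
    local_dir = strip_prefix(preset, HF_PREFIX)
    return {
        "push_to_hf": push_to_hf,
        # Directory created with os.makedirs(..., exist_ok=True) in the diff.
        "local_dir": local_dir,
        # Repo id passed to create_repo in the diff; None for local-only saves.
        "repo_id": local_dir if push_to_hf else None,
    }


hub_plan = plan_save("hf://user/model")
local_plan = plan_save("./my_preset")
```

So saving to `hf://user/model` would also leave a `user/model` directory relative to the current working directory, which is worth keeping in mind when choosing where to run the save.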


def load_from_preset(
    preset,