-
Notifications
You must be signed in to change notification settings - Fork 251
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow saving / loading from Huggingface Hub preset #1510
Changes from 4 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,8 +27,15 @@ | |
except ImportError: | ||
kagglehub = None | ||
|
||
try: | ||
import huggingface_hub | ||
except ImportError: | ||
huggingface_hub = None | ||
|
||
KAGGLE_PREFIX = "kaggle://" | ||
GS_PREFIX = "gs://" | ||
HF_PREFIX = "hf://" | ||
|
||
TOKENIZER_ASSET_DIR = "assets/tokenizer" | ||
CONFIG_FILE = "config.json" | ||
TOKENIZER_CONFIG_FILE = "tokenizer.json" | ||
|
@@ -69,6 +76,14 @@ def get_file(preset, path): | |
url, | ||
cache_subdir=os.path.join("models", subdir), | ||
) | ||
elif preset.startswith(HF_PREFIX): | ||
if huggingface_hub is None: | ||
raise ImportError( | ||
f"`from_preset()` requires the `huggingface_hub` package to load from '{preset}'. " | ||
"Please install with `pip install huggingface_hub`." | ||
) | ||
hf_handle = preset.removeprefix(HF_PREFIX) | ||
return huggingface_hub.hf_hub_download(repo_id=hf_handle, filename=path) | ||
elif os.path.exists(preset): | ||
# Assume a local filepath. | ||
return os.path.join(preset, path) | ||
mattdangerw marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
@@ -262,6 +277,15 @@ def upload_preset( | |
if uri.startswith(KAGGLE_PREFIX): | ||
kaggle_handle = uri.removeprefix(KAGGLE_PREFIX) | ||
kagglehub.model_upload(kaggle_handle, preset) | ||
elif uri.startswith(HF_PREFIX): | ||
if huggingface_hub is None: | ||
raise ImportError( | ||
f"`upload_preset()` requires the `huggingface_hub` package to upload to '{uri}'. " | ||
"Please install with `pip install huggingface_hub`." | ||
) | ||
hf_handle = uri.removeprefix(HF_PREFIX) | ||
repo_url = huggingface_hub.create_repo(repo_id=hf_handle, exist_ok=True) | ||
huggingface_hub.upload_folder(repo_id=repo_url.repo_id, folder_path=preset) | ||
else: | ||
raise ValueError( | ||
f"Unexpected URI `'{uri}'`. Kaggle upload format should follow " | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here, might want to reword this error message. Side note: we are kinda inconsistent in how we refer to these model handles. I'm not sure if we should call these URIs or something else, but we should be consistent in our wording. No need to fix on this PR. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've update the error message in 24ef262. I kept the |
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do the error messages look like if a handle is unformed? Will it read well enough, or should we validate roughly here so we can have a message similar to the Kaggle error above?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At this point,
hf_handle
must correspond to a repo_id so something in the formusername/repo_name
(e.g. "google/gemma-7b") since the hf prefix has been removed. If it's not the case, anHFValidationError
(which is a customValueError
) is raised. Here are the validation rules we are checking.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just pushed 24ef262 to raise a ValueError similar to the kaggle handle one. I tried to be as consistent as possible. Please let me know what you think :)