[Serialization] support loading torch state dict from disk #2687

Merged 18 commits on Dec 13, 2024
22 changes: 19 additions & 3 deletions docs/source/en/package_reference/serialization.md
@@ -4,19 +4,22 @@ rendered properly in your Markdown viewer.

# Serialization

- `huggingface_hub` contains helpers to help ML libraries serialize models weights in a standardized way. This part of the lib is still under development and will be improved in future releases. The goal is to harmonize how weights are serialized on the Hub, both to remove code duplication across libraries and to foster conventions on the Hub.
+ `huggingface_hub` provides helpers to save and load ML model weights in a standardized way. This part of the library is still under development and will be improved in future releases. The goal is to harmonize how weights are saved and loaded across the Hub, both to remove code duplication across libraries and to establish consistent conventions.

- ## Save torch state dict
+ ## Saving

The main helper of the `serialization` module takes a torch `nn.Module` as input and saves it to disk. It handles the logic to save shared tensors (see [safetensors explanation](https://huggingface.co/docs/safetensors/torch_shared_tensors)) as well as the logic to split the state dictionary into shards, using [`split_torch_state_dict_into_shards`] under the hood. At the moment, only the `torch` framework is supported.

If you want to save a state dictionary (i.e. a mapping between layer names and related tensors) instead of an `nn.Module`, you can use [`save_torch_state_dict`], which provides the same features. This is useful, for example, if you want to apply custom logic to the state dict before saving it.

### save_torch_model

[[autodoc]] huggingface_hub.save_torch_model

### save_torch_state_dict

[[autodoc]] huggingface_hub.save_torch_state_dict

## Split state dict into shards

The `serialization` module also contains low-level helpers to split a state dictionary into several shards, while creating a proper index in the process. These helpers are available for `torch` and `tensorflow` tensors and are designed to be easily extended to any other ML frameworks.
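For example, with the torch flavor (a sketch: the tensor names are arbitrary, and `max_shard_size` is deliberately tiny so that sharding kicks in on a small state dict):

```python
import torch
from huggingface_hub import split_torch_state_dict_into_shards

state_dict = {f"layer_{i}.weight": torch.zeros(100, 100) for i in range(4)}

# Each float32 tensor is ~40kB, so a 100kB limit forces several shards.
split = split_torch_state_dict_into_shards(state_dict, max_shard_size=100_000)

print(split.is_sharded)  # whether more than one shard was needed
for filename, tensor_names in split.filename_to_tensors.items():
    print(filename, tensor_names)
```

Note that the helper only computes the split; writing the shard files (and the index file when `is_sharded` is true) is left to the caller, which is exactly what [`save_torch_state_dict`] does on top of it.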

@@ -34,6 +37,19 @@ This is the underlying factory from which each framework-specific helper is derived.

[[autodoc]] huggingface_hub.split_state_dict_into_shards_factory

## Loading

The loading helpers support both single-file and sharded checkpoints in either safetensors or pickle format. [`load_torch_model`] takes an `nn.Module` and a checkpoint path (either a single file or a directory) as input and loads the weights into the model.

### load_torch_model

[[autodoc]] huggingface_hub.load_torch_model

### load_state_dict_from_file

[[autodoc]] huggingface_hub.load_state_dict_from_file


## Helpers

### get_torch_storage_id
4 changes: 4 additions & 0 deletions src/huggingface_hub/__init__.py
@@ -461,6 +461,8 @@
"get_tf_storage_size",
"get_torch_storage_id",
"get_torch_storage_size",
"load_state_dict_from_file",
"load_torch_model",
"save_torch_model",
"save_torch_state_dict",
"split_state_dict_into_shards_factory",
@@ -987,6 +989,8 @@ def __dir__():
get_tf_storage_size, # noqa: F401
get_torch_storage_id, # noqa: F401
get_torch_storage_size, # noqa: F401
load_state_dict_from_file, # noqa: F401
load_torch_model, # noqa: F401
save_torch_model, # noqa: F401
save_torch_state_dict, # noqa: F401
split_state_dict_into_shards_factory, # noqa: F401
2 changes: 2 additions & 0 deletions src/huggingface_hub/serialization/__init__.py
@@ -19,6 +19,8 @@
from ._torch import (
get_torch_storage_id,
get_torch_storage_size,
load_state_dict_from_file,
load_torch_model,
save_torch_model,
save_torch_state_dict,
split_torch_state_dict_into_shards,