[Serialization] support loading torch state dict from disk (#2687)
* add first version of state dict loading helpers

* Rename function

* Update documentation

* Update documentation

* Fix typo

* change titles

* remove file

* fix docstrings

* fix test for torch<=2.1.0

* changes post-review

* fix importing

* fix static imports

* fix documentation

* add requires decorator to the test

* Add mmap parameter

* fix Windows path escaping issue in regex match

* pass device when loading safetensors
hanouticelina authored Dec 13, 2024
1 parent 51b866f commit b75f8d9
Showing 5 changed files with 656 additions and 17 deletions.
22 changes: 19 additions & 3 deletions docs/source/en/package_reference/serialization.md
@@ -4,19 +4,22 @@ rendered properly in your Markdown viewer.

# Serialization

`huggingface_hub` contains helpers to help ML libraries serialize models weights in a standardized way. This part of the lib is still under development and will be improved in future releases. The goal is to harmonize how weights are serialized on the Hub, both to remove code duplication across libraries and to foster conventions on the Hub.
`huggingface_hub` provides helpers to save and load ML model weights in a standardized way. This part of the library is still under development and will be improved in future releases. The goal is to harmonize how weights are saved and loaded across the Hub, both to remove code duplication across libraries and to establish consistent conventions.

## Save torch state dict
## Saving

The main helper of the `serialization` module takes a torch `nn.Module` as input and saves it to disk. It handles the logic to save shared tensors (see the [safetensors explanation](https://huggingface.co/docs/safetensors/torch_shared_tensors)) as well as the logic to split the state dictionary into shards, using [`split_torch_state_dict_into_shards`] under the hood. At the moment, only the `torch` framework is supported.

If you want to save a state dictionary (i.e. a mapping between layer names and their tensors) instead of an `nn.Module`, you can use [`save_torch_state_dict`], which provides the same features. This is useful, for example, if you want to apply custom logic to the state dict before saving it.

### save_torch_model

[[autodoc]] huggingface_hub.save_torch_model

### save_torch_state_dict

[[autodoc]] huggingface_hub.save_torch_state_dict

## Split state dict into shards

The `serialization` module also contains low-level helpers to split a state dictionary into several shards, while creating a proper index in the process. These helpers are available for `torch` and `tensorflow` tensors and are designed to be easily extended to any other ML frameworks.
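As a rough illustration of the torch variant (a sketch, not the library's documented example; the 10MB limit is chosen only to force sharding with small tensors):

```python
import torch
from huggingface_hub import split_torch_state_dict_into_shards

# Five ~4 MB float32 tensors: with a 10 MB shard limit the state dict
# cannot fit in a single file.
state_dict = {f"layer_{i}.weight": torch.zeros(1000, 1000) for i in range(5)}

split = split_torch_state_dict_into_shards(state_dict, max_shard_size="10MB")

# `split.filename_to_tensors` maps each shard filename to the tensor
# names it contains; `split.metadata` holds the data for the index file.
print(split.is_sharded)
print(list(split.filename_to_tensors))
```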

@@ -34,6 +37,19 @@ This is the underlying factory from which each framework-specific helper is derived.

[[autodoc]] huggingface_hub.split_state_dict_into_shards_factory

## Loading

The loading helpers support both single-file and sharded checkpoints in either safetensors or pickle format. [`load_torch_model`] takes an `nn.Module` and a checkpoint path (either a single file or a directory) as input and loads the weights into the model.

### load_torch_model

[[autodoc]] huggingface_hub.load_torch_model

### load_state_dict_from_file

[[autodoc]] huggingface_hub.load_state_dict_from_file


## Helpers

### get_torch_storage_id
4 changes: 4 additions & 0 deletions src/huggingface_hub/__init__.py
@@ -461,6 +461,8 @@
"get_tf_storage_size",
"get_torch_storage_id",
"get_torch_storage_size",
"load_state_dict_from_file",
"load_torch_model",
"save_torch_model",
"save_torch_state_dict",
"split_state_dict_into_shards_factory",
@@ -987,6 +989,8 @@ def __dir__():
get_tf_storage_size, # noqa: F401
get_torch_storage_id, # noqa: F401
get_torch_storage_size, # noqa: F401
load_state_dict_from_file, # noqa: F401
load_torch_model, # noqa: F401
save_torch_model, # noqa: F401
save_torch_state_dict, # noqa: F401
split_state_dict_into_shards_factory, # noqa: F401
2 changes: 2 additions & 0 deletions src/huggingface_hub/serialization/__init__.py
@@ -19,6 +19,8 @@
from ._torch import (
get_torch_storage_id,
get_torch_storage_size,
load_state_dict_from_file,
load_torch_model,
save_torch_model,
save_torch_state_dict,
split_torch_state_dict_into_shards,
