
Commit e05c393

Merge branch 'main' into generate-modelcard-in-model-hub-mixin

2 parents a78cbc1 + 5a40707

9 files changed: +310 -107 lines changed

docs/source/en/guides/download.md

Lines changed: 1 addition & 1 deletion
@@ -144,7 +144,7 @@ repo. For example if `filename="data/train.csv"` and `local_dir="path/to/folder"
 - If `local_dir_use_symlinks=True` is set, all files are symlinked for an optimal disk space optimization. This is
   for example useful when downloading a huge dataset with thousands of small files.
 - Finally, if you don't want symlinks at all you can disable them (`local_dir_use_symlinks=False`). The cache directory
-  will still be used to check wether the file is already cached or not. If already cached, the file is **duplicated**
+  will still be used to check whether the file is already cached or not. If already cached, the file is **duplicated**
   from the cache (i.e. saves bandwidth but increases disk usage). If the file is not already cached, it will be
   downloaded and moved directly to the local dir. This means that if you need to reuse it somewhere else later, it
   will be **re-downloaded**.
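A minimal sketch of the behavior this doc change describes, reusing the `filename`/`local_dir` placeholders from the hunk header (the `repo_id` is hypothetical):

```python
from huggingface_hub import hf_hub_download

# With symlinks disabled, the cache is still checked first: a cached file is
# *duplicated* into local_dir; an uncached file is downloaded and moved there
# directly, so reusing it elsewhere later triggers a re-download.
path = hf_hub_download(
    repo_id="username/my-dataset",  # hypothetical repo
    repo_type="dataset",
    filename="data/train.csv",      # placeholder from the doc example
    local_dir="path/to/folder",
    local_dir_use_symlinks=False,
)
print(path)  # "path/to/folder/data/train.csv"
```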

docs/source/en/guides/integrations.md

Lines changed: 28 additions & 44 deletions
@@ -148,7 +148,7 @@ are ready to go. You don't need to worry about stuff like repo creation, commits
 of this is handled by the mixin and is available to your users. The Mixin also ensures that public methods are well
 documented and type annotated.
 
-As a bonus, [`ModelHubMixin`] handles the model configuration for you. In some cases, you have a `config` input parameter when initializing your class (dictionary or dataclass containing high-level settings). In such cases, the `config` value is automatically serialized into a `config.json` dictionary for you. When re-loading the model from the Hub, the configuration is correctly deserialized. Make sure to use type annotation if you want to deserialize it as a dataclass. The big advantage of having a `config.json` file in your model repository is that it automatically enables the analytics on the Hub (e.g. the "downloads" count).
+As a bonus, [`ModelHubMixin`] handles the model configuration for you. If your `__init__` method expects a `config` input, it will be automatically saved in the repo when calling `save_pretrained` and reloaded correctly by `from_pretrained`. Moreover, if the `config` input parameter is annotated with a dataclass type (e.g. `config: Optional[MyConfigClass] = None`), then the `config` value will be correctly deserialized for you. Finally, all jsonable values passed at initialization will also be stored in the config file. This means you don't necessarily have to expect a `config` input to benefit from it. The big advantage of having a `config.json` file in your model repository is that it automatically enables the analytics on the Hub (e.g. the "downloads" count).
 
 ### A concrete example: PyTorch
 
@@ -159,40 +159,38 @@ A good example of what we saw above is [`PyTorchModelHubMixin`], our integration
 Here is how any user can load/save a PyTorch model from/to the Hub:
 
 ```python
->>> from dataclasses import dataclass
 >>> import torch
 >>> import torch.nn as nn
 >>> from huggingface_hub import PyTorchModelHubMixin
 
-# 0. (optional) define a config class
->>> @dataclass
-... class Config:
-...     hidden_size: int = 512
-...     vocab_size: int = 30000
-...     output_size: int = 4
 
-# 1. Define your Pytorch model exactly the same way you are used to
+# Define your Pytorch model exactly the same way you are used to
 >>> class MyModel(nn.Module, PyTorchModelHubMixin): # multiple inheritance
-...     def __init__(self, config: Config):
+...     def __init__(self, hidden_size: int = 512, vocab_size: int = 30000, output_size: int = 4):
 ...         super().__init__()
-...         self.param = nn.Parameter(torch.rand(config.hidden_size, config.vocab_size))
-...         self.linear = nn.Linear(config.output_size, config.vocab_size)
+...         self.param = nn.Parameter(torch.rand(hidden_size, vocab_size))
+...         self.linear = nn.Linear(output_size, vocab_size)
 
 ...     def forward(self, x):
 ...         return self.linear(x + self.param)
 
->>> model = MyModel(Config(hidden_size=128))
+# 1. Create model
+>>> model = MyModel(hidden_size=128)
+
+# Config is automatically created based on input + default values
+>>> model.config
+{"hidden_size": 128, "vocab_size": 30000, "output_size": 4}
 
 # 2. (optional) Save model to local directory
 >>> model.save_pretrained("path/to/my-awesome-model")
 
 # 3. Push model weights to the Hub
 >>> model.push_to_hub("my-awesome-model")
 
-# 4. Initialize model from the Hub
+# 4. Initialize model from the Hub => config has been preserved
 >>> model = MyModel.from_pretrained("username/my-awesome-model")
 >>> model.config
-Config(hidden_size=128, vocab_size=30000, output_size=4)
+{"hidden_size": 128, "vocab_size": 30000, "output_size": 4}
 ```
 
 #### Implementation
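To illustrate the dataclass case called out in the updated paragraph earlier in this file's diff, here is a hedged sketch; the class and field names are illustrative (not part of this commit), and the annotation mirrors the `config: Optional[MyConfigClass] = None` example from the docs:

```python
from dataclasses import dataclass
from typing import Optional

import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

@dataclass
class MyConfigClass:
    hidden_size: int = 512

class MyTypedModel(nn.Module, PyTorchModelHubMixin):
    def __init__(self, config: Optional[MyConfigClass] = None):
        super().__init__()
        config = config or MyConfigClass()
        self.linear = nn.Linear(config.hidden_size, config.hidden_size)

# Because `config` is annotated with a dataclass type, the mixin should
# serialize it into config.json on save and rebuild a MyConfigClass instance
# (rather than a plain dict) when the model is reloaded via from_pretrained.
```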
@@ -211,25 +209,15 @@ class PyTorchModelHubMixin(ModelHubMixin):
 2. Implement the `_save_pretrained` method:
 
 ```py
-from huggingface_hub import ModelCard, ModelCardData
+from huggingface_hub import ModelHubMixin
 
 class PyTorchModelHubMixin(ModelHubMixin):
     (...)
 
-    def _save_pretrained(self, save_directory: Path):
-        """Generate Model Card and save weights from a Pytorch model to a local directory."""
-        model_card = ModelCard.from_template(
-            card_data=ModelCardData(
-                license='mit',
-                library_name="pytorch",
-                ...
-            ),
-            model_summary=...,
-            model_type=...,
-            ...
-        )
-        (save_directory / "README.md").write_text(str(model))
-        torch.save(obj=self.module.state_dict(), f=save_directory / "pytorch_model.bin")
+    def _save_pretrained(self, save_directory: Path) -> None:
+        """Save weights from a Pytorch model to a local directory."""
+        save_model_as_safetensor(self.module, str(save_directory / SAFETENSORS_SINGLE_FILE))
+
 ```
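The snippet above references `save_model_as_safetensor` and `SAFETENSORS_SINGLE_FILE` without showing their imports. A sketch of plausible definitions, modeled on how `huggingface_hub` aliases the `safetensors` helpers (treat the exact names as assumptions when adapting this):

```python
# Plausible origin of the helpers used above (hedged; not shown in this diff).
from safetensors.torch import save_model as save_model_as_safetensor
from safetensors.torch import load_model as load_model_as_safetensor

SAFETENSORS_SINGLE_FILE = "model.safetensors"  # conventional single-file weights name
```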
 
 3. Implement the `_from_pretrained` method:
@@ -255,28 +243,24 @@ class PyTorchModelHubMixin(ModelHubMixin):
         **model_kwargs,
     ):
         """Load Pytorch pretrained weights and return the loaded model."""
-        if os.path.isdir(model_id): # Can either be a local directory
-            print("Loading weights from local directory")
-            model_file = os.path.join(model_id, "pytorch_model.bin")
-        else: # Or a model on the Hub
-            model_file = hf_hub_download( # Download from the hub, passing same input args
+        model = cls(**model_kwargs)
+        if os.path.isdir(model_id):
+            print("Loading weights from local directory")
+            model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
+            return cls._load_as_safetensor(model, model_file, map_location, strict)
+
+        model_file = hf_hub_download(
             repo_id=model_id,
-            filename="pytorch_model.bin",
+            filename=SAFETENSORS_SINGLE_FILE,
             revision=revision,
             cache_dir=cache_dir,
             force_download=force_download,
             proxies=proxies,
             resume_download=resume_download,
             token=token,
             local_files_only=local_files_only,
-        )
-
-        # Load model and return - custom logic depending on your framework
-        model = cls(**model_kwargs)
-        state_dict = torch.load(model_file, map_location=torch.device(map_location))
-        model.load_state_dict(state_dict, strict=strict)
-        model.eval()
-        return model
+        )
+        return cls._load_as_safetensor(model, model_file, map_location, strict)
 ```
 
 And that's it! Your library now enables users to upload and download files to and from the Hub.
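One loose end: `_from_pretrained` calls `cls._load_as_safetensor`, which this diff does not show. A hedged sketch of what such a helper could look like, mirroring the removed `torch.load`-based flow on top of the `safetensors` API:

```python
import torch.nn as nn
from safetensors.torch import load_model as load_model_as_safetensor

def _load_as_safetensor(model: nn.Module, model_file: str, map_location: str, strict: bool) -> nn.Module:
    """Sketch: load safetensors weights into `model` in place, then prep it for inference."""
    load_model_as_safetensor(model, model_file, strict=strict)
    model.to(map_location)  # same role as torch.load's map_location in the removed code
    model.eval()            # matches the removed model.eval() call
    return model
```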

src/huggingface_hub/file_download.py

Lines changed: 0 additions & 1 deletion
@@ -1211,7 +1211,6 @@ def hf_hub_download(
         raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(REPO_TYPES)}")
 
     storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type))
-    os.makedirs(storage_folder, exist_ok=True)
 
     # cross platform transcription of filename, to be used as a local file path.
     relative_filename = os.path.join(*filename.split("/"))
