diff --git a/audiocraft/models/loaders.py b/audiocraft/models/loaders.py index 7fd49d84..9411a913 100644 --- a/audiocraft/models/loaders.py +++ b/audiocraft/models/loaders.py @@ -26,6 +26,7 @@ from omegaconf import OmegaConf, DictConfig import torch +import audiocraft from . import builders from .encodec import CompressionModel @@ -60,7 +61,7 @@ def _get_state_dict( else: assert filename is not None, "filename needs to be defined if using HF checkpoints" - file = hf_hub_download(repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir) + file = hf_hub_download(repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir, library_name="audiocraft", library_version=audiocraft.__version__) return torch.load(file, map_location=device)