From 85473bbe878ce761468d450158d374faa37f045a Mon Sep 17 00:00:00 2001 From: Vaibhav Srivastav Date: Thu, 19 Oct 2023 14:47:47 +0200 Subject: [PATCH 1/3] add library_name for better model tracking on the Hub. --- audiocraft/models/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/audiocraft/models/loaders.py b/audiocraft/models/loaders.py index 7fd49d84..37e829d3 100644 --- a/audiocraft/models/loaders.py +++ b/audiocraft/models/loaders.py @@ -60,7 +60,7 @@ def _get_state_dict( else: assert filename is not None, "filename needs to be defined if using HF checkpoints" - file = hf_hub_download(repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir) + file = hf_hub_download(repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir, library_name="audiocraft") return torch.load(file, map_location=device) From 029b6d6d6e53758bbb2bdf6bf9033e8be680b8ac Mon Sep 17 00:00:00 2001 From: Vaibhav Srivastav Date: Thu, 19 Oct 2023 15:31:07 +0200 Subject: [PATCH 2/3] add library_version. --- audiocraft/models/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/audiocraft/models/loaders.py b/audiocraft/models/loaders.py index 37e829d3..c11cc1e7 100644 --- a/audiocraft/models/loaders.py +++ b/audiocraft/models/loaders.py @@ -60,7 +60,7 @@ def _get_state_dict( else: assert filename is not None, "filename needs to be defined if using HF checkpoints" - file = hf_hub_download(repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir, library_name="audiocraft") + file = hf_hub_download(repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir, library_name="audiocraft", library_version="1.0.0") return torch.load(file, map_location=device) From 4aec50a294d0078afcbbfc93f6017843b0815ba5 Mon Sep 17 00:00:00 2001 From: Vaibhav Srivastav Date: Mon, 30 Oct 2023 18:27:29 +0100 Subject: [PATCH 3/3] up --- audiocraft/models/loaders.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/audiocraft/models/loaders.py b/audiocraft/models/loaders.py index c11cc1e7..9411a913 100644 --- a/audiocraft/models/loaders.py +++ b/audiocraft/models/loaders.py @@ -26,6 +26,7 @@ from omegaconf import OmegaConf, DictConfig import torch +import audiocraft from . import builders from .encodec import CompressionModel @@ -60,7 +61,7 @@ def _get_state_dict( else: assert filename is not None, "filename needs to be defined if using HF checkpoints" - file = hf_hub_download(repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir, library_name="audiocraft", library_version="1.0.0") + file = hf_hub_download(repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir, library_name="audiocraft", library_version=audiocraft.__version__) return torch.load(file, map_location=device)