Skip to content

Commit

Permalink
feat(vc): add knnvc model
Browse files Browse the repository at this point in the history
  • Loading branch information
eginhard committed Jan 15, 2025
1 parent e88b4b6 commit ea21777
Show file tree
Hide file tree
Showing 13 changed files with 317 additions and 18 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ repository are also still a useful source of information.

### Voice Conversion
- [FreeVC](https://arxiv.org/abs/2210.15418)
- [kNN-VC](https://doi.org/10.21437/Interspeech.2023-419)
- [OpenVoice](https://arxiv.org/abs/2312.01479)

### Others
Expand Down
26 changes: 26 additions & 0 deletions TTS/.models.json
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,22 @@
"license": "apache 2.0"
}
},
"librispeech100": {
"wavlm-hifigan": {
"description": "HiFiGAN vocoder for WavLM features from kNN-VC",
"github_rls_url": "https://github.com/idiap/coqui-ai-TTS/releases/download/v0.25.2_models/vocoder_models--en--librispeech100--wavlm-hifigan.zip",
"commit": "cfba7e0",
"author": "Benjamin van Niekerk @bshall, Matthew Baas @RF5",
"license": "MIT"
},
"wavlm-hifigan_prematched": {
"description": "Prematched HiFiGAN vocoder for WavLM features from kNN-VC",
"github_rls_url": "https://github.com/idiap/coqui-ai-TTS/releases/download/v0.25.2_models/vocoder_models--en--librispeech100--wavlm-hifigan_prematched.zip",
"commit": "cfba7e0",
"author": "Benjamin van Niekerk @bshall, Matthew Baas @RF5",
"license": "MIT"
}
},
"ljspeech": {
"multiband-melgan": {
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.6.1_models/vocoder_models--en--ljspeech--multiband-melgan.zip",
Expand Down Expand Up @@ -927,18 +943,27 @@
"freevc24": {
"github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/voice_conversion_models--multilingual--vctk--freevc24.zip",
"description": "FreeVC model trained on VCTK dataset from https://github.com/OlaWod/FreeVC",
"default_vocoder": null,
"author": "Jing-Yi Li @OlaWod",
"license": "MIT",
"commit": null
}
},
"multi-dataset": {
"knnvc": {
"description": "kNN-VC model from https://github.com/bshall/knn-vc",
"default_vocoder": "vocoder_models/en/librispeech100/wavlm-hifigan_prematched",
"author": "Benjamin van Niekerk @bshall, Matthew Baas @RF5",
"license": "MIT",
"commit": null
},
"openvoice_v1": {
"hf_url": [
"https://huggingface.co/myshell-ai/OpenVoice/resolve/main/checkpoints/converter/config.json",
"https://huggingface.co/myshell-ai/OpenVoice/resolve/main/checkpoints/converter/checkpoint.pth"
],
"description": "OpenVoice VC model from https://huggingface.co/myshell-ai/OpenVoiceV2",
"default_vocoder": null,
"author": "MyShell.ai",
"license": "MIT",
"commit": null
Expand All @@ -949,6 +974,7 @@
"https://huggingface.co/myshell-ai/OpenVoiceV2/resolve/main/converter/checkpoint.pth"
],
"description": "OpenVoice VC model from https://huggingface.co/myshell-ai/OpenVoiceV2",
"default_vocoder": null,
"author": "MyShell.ai",
"license": "MIT",
"commit": null
Expand Down
1 change: 1 addition & 0 deletions TTS/utils/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def to_camel(text):
text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
text = text.replace("Tts", "TTS")
text = text.replace("vc", "VC")
text = text.replace("Knn", "KNN")
return text


Expand Down
13 changes: 10 additions & 3 deletions TTS/utils/manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing_extensions import Required

from TTS.config import load_config, read_json_with_comments
from TTS.vc.configs.knnvc_config import KNNVCConfig

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -267,9 +268,9 @@ def set_model_url(model_item: ModelItem) -> ModelItem:
model_item["model_url"] = model_item["github_rls_url"]
elif "hf_url" in model_item:
model_item["model_url"] = model_item["hf_url"]
elif "fairseq" in model_item["model_name"]:
elif "fairseq" in model_item.get("model_name", ""):
model_item["model_url"] = "https://dl.fbaipublicfiles.com/mms/tts/"
elif "xtts" in model_item["model_name"]:
elif "xtts" in model_item.get("model_name", ""):
model_item["model_url"] = "https://huggingface.co/coqui/"
return model_item

Expand Down Expand Up @@ -367,6 +368,9 @@ def create_dir_and_download_model(self, model_name: str, model_item: ModelItem,
logger.exception("Failed to download the model file to %s", output_path)
rmtree(output_path)
raise e
checkpoints = list(Path(output_path).glob("*.pt*"))
if len(checkpoints) == 1:
checkpoints[0].rename(checkpoints[0].parent / "model.pth")
self.print_model_license(model_item=model_item)

def check_if_configs_are_equal(self, model_name: str, model_item: ModelItem, output_path: Path) -> None:
Expand Down Expand Up @@ -431,11 +435,14 @@ def download_model(self, model_name: str) -> tuple[Path, Optional[Path], ModelIt
output_model_path = output_path
output_config_path = None
if (
model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name and "xtts" not in model_name
model not in ["tortoise-v2", "bark", "knnvc"] and "fairseq" not in model_name and "xtts" not in model_name
): # TODO:This is stupid but don't care for now.
output_model_path, output_config_path = self._find_files(output_path)
else:
output_config_path = output_model_path / "config.json"
if model == "knnvc" and not output_config_path.exists():
knnvc_config = KNNVCConfig()
knnvc_config.save_json(output_config_path)
# update paths in the config.json
self._update_paths(output_path, output_config_path)
return output_model_path, output_config_path, model_item
Expand Down
4 changes: 3 additions & 1 deletion TTS/utils/synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,9 @@ def _load_vc(self, vc_checkpoint: str, vc_config_path: str, use_cuda: bool) -> N
"""
# pylint: disable=global-statement
self.vc_config = load_config(vc_config_path)
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
self.output_sample_rate = self.vc_config.audio.get(
"output_sample_rate", self.vc_config.audio.get("sample_rate", None)
)
self.vc_model = setup_vc_model(config=self.vc_config)
self.vc_model.load_checkpoint(self.vc_config, vc_checkpoint)
if use_cuda:
Expand Down
59 changes: 59 additions & 0 deletions TTS/vc/configs/knnvc_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from dataclasses import dataclass, field

from coqpit import Coqpit

from TTS.config.shared_configs import BaseAudioConfig
from TTS.vc.configs.shared_configs import BaseVCConfig


@dataclass
class KNNVCAudioConfig(BaseAudioConfig):
"""Audio configuration.
Args:
sample_rate (int):
The sampling rate of the input waveform.
"""

sample_rate: int = field(default=16000)


@dataclass
class KNNVCArgs(Coqpit):
"""Model arguments.
Args:
ssl_dim (int):
The dimension of the self-supervised learning embedding.
"""

ssl_dim: int = field(default=1024)


@dataclass
class KNNVCConfig(BaseVCConfig):
"""Parameters.
Args:
model (str):
Model name. Do not change unless you know what you are doing.
model_args (KNNVCArgs):
Model architecture arguments. Defaults to `KNNVCArgs()`.
audio (KNNVCAudioConfig):
Audio processing configuration. Defaults to `KNNVCAudioConfig()`.
wavlm_layer (int):
WavLM layer to use for feature extraction.
topk (int):
k in the kNN -- the number of nearest neighbors to average over
"""

model: str = "knnvc"
model_args: KNNVCArgs = field(default_factory=KNNVCArgs)
audio: KNNVCAudioConfig = field(default_factory=KNNVCAudioConfig)

wavlm_layer: int = 6
topk: int = 4
2 changes: 1 addition & 1 deletion TTS/vc/layers/freevc/wavlm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
model_uri = "https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/WavLM-Large.pt"


def get_wavlm(device="cpu"):
def get_wavlm(device="cpu") -> WavLM:
"""Download the model and return the model object."""

output_path = get_user_data_dir("tts")
Expand Down
15 changes: 11 additions & 4 deletions TTS/vc/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
import importlib
import logging
import re
from typing import Dict, List, Union
from typing import Dict, List, Optional, Union

from TTS.vc.configs.shared_configs import BaseVCConfig
from TTS.vc.models.base_vc import BaseVC

logger = logging.getLogger(__name__)


def setup_model(config: BaseVCConfig) -> BaseVC:
logger.info("Using model: %s", config.model)
# fetch the right model implementation.
if "model" in config and config["model"].lower() == "freevc":
if config["model"].lower() == "freevc":
MyModel = importlib.import_module("TTS.vc.models.freevc").FreeVC
model = MyModel.init_from_config(config)
return model
elif config["model"].lower() == "knnvc":
MyModel = importlib.import_module("TTS.vc.models.knnvc").KNNVC
else:
msg = f"Model {config.model} does not exist!"
raise ValueError(msg)
return MyModel.init_from_config(config)
Loading

0 comments on commit ea21777

Please sign in to comment.