forked from idiap/coqui-ai-TTS
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
317 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from dataclasses import dataclass, field | ||
|
||
from coqpit import Coqpit | ||
|
||
from TTS.config.shared_configs import BaseAudioConfig | ||
from TTS.vc.configs.shared_configs import BaseVCConfig | ||
|
||
|
||
@dataclass
class KNNVCAudioConfig(BaseAudioConfig):
    """Audio settings used by the kNN-VC model.

    Extends ``BaseAudioConfig`` and overrides only the sampling rate.

    Args:
        sample_rate (int):
            Sampling rate of the input waveform. Defaults to 16000.
    """

    sample_rate: int = 16000
|
||
|
||
@dataclass
class KNNVCArgs(Coqpit):
    """Architecture arguments for the kNN-VC model.

    Args:
        ssl_dim (int):
            Dimension of the self-supervised learning embedding.
            Defaults to 1024.
    """

    ssl_dim: int = 1024
|
||
|
||
@dataclass
class KNNVCConfig(BaseVCConfig):
    """Top-level configuration for the kNN-VC voice conversion model.

    Args:
        model (str):
            Model name. Do not change unless you know what you are doing.
            Defaults to ``"knnvc"``.
        model_args (KNNVCArgs):
            Model architecture arguments. Defaults to ``KNNVCArgs()``.
        audio (KNNVCAudioConfig):
            Audio processing configuration. Defaults to
            ``KNNVCAudioConfig()``.
        wavlm_layer (int):
            WavLM layer to use for feature extraction. Defaults to 6.
        topk (int):
            The ``k`` in kNN, i.e. the number of nearest neighbors to
            average over. Defaults to 4.
    """

    model: str = field(default="knnvc")
    # Nested configs go through ``default_factory`` so each instance gets
    # its own freshly constructed object.
    model_args: KNNVCArgs = field(default_factory=KNNVCArgs)
    audio: KNNVCAudioConfig = field(default_factory=KNNVCAudioConfig)

    wavlm_layer: int = field(default=6)
    topk: int = field(default=4)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,22 @@ | ||
import importlib | ||
import logging | ||
import re | ||
from typing import Dict, List, Union | ||
from typing import Dict, List, Optional, Union | ||
|
||
from TTS.vc.configs.shared_configs import BaseVCConfig | ||
from TTS.vc.models.base_vc import BaseVC | ||
|
||
logger = logging.getLogger(__name__)


def setup_model(config: BaseVCConfig) -> BaseVC:
    """Build the voice conversion model selected by ``config``.

    The model name is read from ``config["model"]`` and compared
    case-insensitively against the supported names (``"freevc"`` and
    ``"knnvc"``). The matching implementation module is imported on demand
    and the model is created through its ``init_from_config`` constructor.

    Args:
        config: Configuration object exposing the model name both as
            ``config.model`` and ``config["model"]``.

    Returns:
        The object returned by the selected class's
        ``init_from_config(config)``.

    Raises:
        ValueError: If the model name matches none of the supported models.
    """
    logger.info("Using model: %s", config.model)
    # Lowercase once so every branch compares against the same value.
    model_name = config["model"].lower()
    # Import inside the branch so only the selected model's module is loaded.
    if model_name == "freevc":
        MyModel = importlib.import_module("TTS.vc.models.freevc").FreeVC
    elif model_name == "knnvc":
        MyModel = importlib.import_module("TTS.vc.models.knnvc").KNNVC
    else:
        msg = f"Model {config.model} does not exist!"
        raise ValueError(msg)
    return MyModel.init_from_config(config)
Oops, something went wrong.