Skip to content

Commit

Permalink
refactor(xtts): remove duplicate xtts audio config
Browse files Browse the repository at this point in the history
  • Loading branch information
eginhard committed Dec 5, 2024
1 parent ce20253 commit fe14ca6
Show file tree
Hide file tree
Showing 7 changed files with 14 additions and 13 deletions.
3 changes: 2 additions & 1 deletion TTS/demos/xtts_ft_demo/utils/gpt_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
from TTS.tts.models.xtts import XttsAudioConfig
from TTS.utils.manage import ModelManager


Expand Down
7 changes: 1 addition & 6 deletions TTS/tts/layers/xtts/trainer/gpt_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
from TTS.tts.models.xtts import Xtts, XttsArgs
from TTS.utils.generic_utils import is_pytorch_at_least_2_4

logger = logging.getLogger(__name__)
Expand All @@ -34,11 +34,6 @@ class GPTTrainerConfig(XttsConfig):
test_sentences: List[dict] = field(default_factory=lambda: [])


@dataclass
class XttsAudioConfig(XttsAudioConfig):
dvae_sample_rate: int = 22050


@dataclass
class GPTArgs(XttsArgs):
min_conditioning_length: int = 66150
Expand Down
5 changes: 3 additions & 2 deletions TTS/tts/models/xtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from coqpit import Coqpit
from trainer.io import load_fsspec

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from TTS.tts.layers.xtts.stream_generator import init_stream_support
Expand Down Expand Up @@ -103,10 +102,12 @@ class XttsAudioConfig(Coqpit):
Args:
sample_rate (int): The sample rate in which the GPT operates.
output_sample_rate (int): The sample rate of the output audio waveform.
dvae_sample_rate (int): The sample rate of the DVAE
"""

sample_rate: int = 22050
output_sample_rate: int = 24000
dvae_sample_rate: int = 22050


@dataclass
Expand Down Expand Up @@ -721,7 +722,7 @@ def get_compatible_checkpoint_state_dict(self, model_path):

def load_checkpoint(
self,
config: XttsConfig,
config: "XttsConfig",
checkpoint_dir: Optional[str] = None,
checkpoint_path: Optional[str] = None,
vocab_path: Optional[str] = None,
Expand Down
3 changes: 2 additions & 1 deletion recipes/ljspeech/xtts_v1/train_gpt_xtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
from TTS.tts.models.xtts import XttsAudioConfig
from TTS.utils.manage import ModelManager

# Logging parameters
Expand Down
3 changes: 2 additions & 1 deletion recipes/ljspeech/xtts_v2/train_gpt_xtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
from TTS.tts.models.xtts import XttsAudioConfig
from TTS.utils.manage import ModelManager

# Logging parameters
Expand Down
3 changes: 2 additions & 1 deletion tests/xtts_tests/test_xtts_gpt_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.dvae import DiscreteVAE
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
from TTS.tts.models.xtts import XttsAudioConfig

config_dataset = BaseDatasetConfig(
formatter="ljspeech",
Expand Down
3 changes: 2 additions & 1 deletion tests/xtts_tests/test_xtts_v2-0_gpt_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.dvae import DiscreteVAE
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
from TTS.tts.models.xtts import XttsAudioConfig

config_dataset = BaseDatasetConfig(
formatter="ljspeech",
Expand Down

0 comments on commit fe14ca6

Please sign in to comment.