Commit 959e549

Merge pull request #77 from nickovchinnikov/nick/review
Add review page
2 parents bcaeeba + 0685def commit 959e549

22 files changed: 859 additions and 100 deletions

README.md

Lines changed: 11 additions & 68 deletions
@@ -1,53 +1,23 @@
 # TTS-Framework
-Modified version of DelightfulTTS and UnivNet
-
-### Conda env
-
-Create / activate env
-
-```
-conda create --name tts_framework python=3.11
-conda activate tts_framework
-```
-
-Export / import env
-
-```
-conda env export > environment.yml
 
-```
-
-By default, conda will export your environment with builds, but builds can be platform-specific.
-A solution that worked for me is to use the `--no-build` flag:
-
-```
-conda env export --no-build > environment.yml
-```
+Modified version of DelightfulTTS and UnivNet
 
-Create an env
-```
-conda env create -f environment.yml
-```
+## Install deps
 
-If you have troubles with export, like:
-```
-InvalidVersionSpec: Invalid version '3.0<3.3': invalid character(s)
+```bash
+sudo apt install ffmpeg libasound2-dev build-essential espeak-ng -y
 ```
 
-Find a problem by this way:
+Create env from the `environment.yml` file:
 
-```
-cd /mnt/Data/anaconda3/envs/tts_framework/lib/python3.11/site-packages/
-
-grep -Rnw . -e "3.0<3.3"
+```bash
+conda env create -f ./tts-framework/environment.yml python=3.11
 
+# After the setup
+conda activate tts_framework
 ```
 
-A Faster Solver for Conda: [Libmamba](https://www.anaconda.com/blog/a-faster-conda-for-a-growing-community)
-
-
-Generate docs:
-
+## Generate docs:
 
 ```
 # live preview server
@@ -57,35 +27,8 @@ mkdocs serve
 mkdocs build
 ```
 
-Test cases:
+## Test cases:
 
 ```
 python -m unittest discover -v
 ```
-
-### [Libmamba solver](https://www.anaconda.com/blog/a-faster-conda-for-a-growing-community):
-
-```
-conda update -n base conda
-```
-
-And then:
-
-```
-conda install -n base conda-libmamba-solver
-conda config --set solver libmamba
-```
-
-### Env Installation process
-
-Install separately
-
-```
-# First - pytorch
-# conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch-nightly -c nvidia
-
-pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
-
-# Second - lightning
-pip3 install lightning
-```
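Note on the `python -m unittest discover -v` command kept in the README: by default, discovery collects modules whose file names match `test*.py`. The following is a minimal sketch of such a module, not code from this repository. The file name `test_example.py` and the helper `is_valid_duration` are hypothetical; the default bounds simply mirror the `min_seconds` and `max_seconds` defaults visible in `models/config/configs.py`.

```python
# test_example.py (hypothetical file name; discovery matches test*.py by default)
import unittest


def is_valid_duration(
    seconds: float, min_seconds: float = 0.5, max_seconds: float = 6.0
) -> bool:
    """Hypothetical helper: check that a clip length lies inside the allowed range."""
    return min_seconds <= seconds <= max_seconds


class TestIsValidDuration(unittest.TestCase):
    def test_inside_range(self):
        self.assertTrue(is_valid_duration(3.0))

    def test_outside_range(self):
        self.assertFalse(is_valid_duration(0.1))
        self.assertFalse(is_valid_duration(10.0))
```

With a file like this anywhere under the project root, the discover command should list both test methods in verbose mode.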

docs-md/review.md

Lines changed: 2 additions & 2 deletions
@@ -701,9 +701,9 @@ Additionally, you can explore the dataset code I prepared for various experiment
 
 I firmly believe that a good project starts with comprehensive documentation, and good code is built upon a solid foundation of test cases. With this in mind, I have made concerted efforts to maintain consistent documentation and ensure thorough test coverage for my code. The repository serves as a comprehensive resource where you can explore the implementation details, review the documentation, and examine the test cases that ensure the code's reliability and correctness.
 
-You can find all the documentation markdown inside the `docs-md` directory, run the docs locally with `mkdocs serve`
+You can find all the documentation inside the `docs` directory and run the docs locally with `mkdocs serve`.
 
-Also here you can the [docs online](https://storage.googleapis.com/tts-docs/index.html)
+You can also read the [docs online](https://peechapp.github.io/tts-peech/).
 
 ### Acoustic model

docs/assets/1-Figure1-1.png: 28.6 KB, 28.5 KB. Binary file not shown.
docs/assets/20240514174225.png: 56.2 KB
docs/assets/20240517155909.png: 34.1 KB
docs/assets/20240520150027.png: 76.1 KB
docs/assets/20240524164947.png: 20.5 KB
docs/assets/20240527121619.png: 104 KB
docs/assets/20240527161523.png: 117 KB
docs/assets/AR_Model.png: 11.3 KB
docs/assets/NAR_schema.png: 6.91 KB
docs/assets/audio-animation.gif: 3.9 MB
docs/assets/mel_loss.png: 47.9 KB
docs/assets/total_loss.png: 50.2 KB

docs/review.md

Lines changed: 744 additions & 0 deletions
Large diffs are not rendered by default.

models/config/configs.py

Lines changed: 1 addition & 26 deletions
@@ -1,8 +1,6 @@
 from dataclasses import dataclass, field
 from typing import List, Literal, Tuple, Union
 
-import torch
-
 PreprocessLangType = Literal["english_only", "multilingual"]
 
 
@@ -22,15 +20,10 @@ class PreprocessingConfig:
     language: PreprocessLangType
     stft: STFTConfig
     sampling_rate: int = 22050
-    val_size: float = 0.05
     min_seconds: float = 0.5
     max_seconds: float = 6.0
     use_audio_normalization: bool = True
     workers: int = 8
-    forced_alignment_batch_size: int = 200000
-    skip_on_error: bool = True
-    pitch_fmin: int = 1
-    pitch_fmax: int = 640
 
 
 @dataclass
@@ -61,10 +54,7 @@ class PreprocessingConfigHifiGAN(PreprocessingConfig):
     )
 
     def __post_init__(self):
-        r"""Post-initialization method for the `PreprocessingConfig` dataclass.
-
-        This method is automatically called after the instance is initialized.
-        It modifies the 'stft' attribute based on the 'sampling_rate' attribute.
+        r"""It modifies the 'stft' attribute based on the 'sampling_rate' attribute.
         If 'sampling_rate' is 44100, 'stft' is set with specific values for this rate.
         If 'sampling_rate' is not 22050 or 44100, a ValueError is raised.
 
@@ -84,21 +74,6 @@ def __post_init__(self):
             raise ValueError("Sampling rate must be 22050 or 44100")
 
 
-@dataclass
-class SampleSplittingRunConfig:
-    workers: int
-    device: torch.device
-    skip_on_error: bool
-    forced_alignment_batch_size: int
-
-
-@dataclass
-class CleaningRunConfig:
-    workers: int
-    device: torch.device
-    skip_on_error: bool
-
-
 @dataclass
 class AcousticTrainingOptimizerConfig:
     learning_rate: float
models/tts/delightful_tts/acoustic_model/acoustic_model.py

Lines changed: 0 additions & 2 deletions
@@ -237,8 +237,6 @@ def freeze_params(self) -> None:
         for par in self.parameters():
             par.requires_grad = False
         self.speaker_embed.requires_grad = True
-        # NOTE: requires_grad prop
-        # self.pitch_adaptor.pitch_embedding.embeddings.requires_grad = True
 
     # NOTE: freeze/unfreeze params changed, because of the conflict with the lightning module
     def unfreeze_params(self, freeze_text_embed: bool, freeze_lang_embed: bool) -> None:

models/tts/delightful_tts/train/train.py

Lines changed: 1 addition & 2 deletions
@@ -9,7 +9,7 @@
 from lightning.pytorch.tuner.tuning import Tuner
 import torch
 
-from models.tts.delightful_tts.delightful_tts_refined import DelightfulTTS
+from models.tts.delightful_tts.delightful_tts import DelightfulTTS
 
 # Node rank in the cluster
 node_rank = 0
@@ -87,7 +87,6 @@
     # NOTE: Preload the cached dataset into the RAM
     cache_dir="/dev/shm/",
     cache=True,
-    mem_cache=False,
 )
 
 trainer.fit(
File renamed without changes.
File renamed without changes.

train.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
+from datetime import datetime
+import logging
+import os
+import sys
+
+from lightning.pytorch import Trainer
+from lightning.pytorch.accelerators import find_usable_cuda_devices # type: ignore
+from lightning.pytorch.strategies import DDPStrategy
+from lightning.pytorch.tuner.tuning import Tuner
+import torch
+
+from models.config import PreprocessingConfigUnivNet as PreprocessingConfig
+from models.tts.delightful_tts.delightful_tts import DelightfulTTS
+
+# Number of nodes in the cluster
+num_nodes = 1
+# Node rank in the cluster
+node_rank = 0
+
+os.environ["WORLD_SIZE"] = f"{num_nodes}"
+os.environ["NODE_RANK"] = f"{node_rank}"
+
+# IP/Port of the master node
+os.environ["MASTER_PORT"] = "12355"
+os.environ["MASTER_ADDR"] = "10.148.0.6"
+
+# Create a logger
+# Set the level of the logger to ERROR
+logger = logging.getLogger("my_logger")
+logger.setLevel(logging.ERROR)
+
+# Format the current date and time as a string
+timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+
+# Create a file handler that logs error messages to a file with the current timestamp in its name
+handler = logging.FileHandler(f"logs/error_{timestamp}.log")
+
+# Create a formatter and add it to the handler
+formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+handler.setFormatter(formatter)
+
+# Add the handler to the logger
+logger.addHandler(handler)
+
+print("usable_cuda_devices: ", find_usable_cuda_devices())
+
+# Let float32 matrix multiplications use faster, lower-precision internal formats (e.g. TF32) to speed up training
+torch.set_float32_matmul_precision("high")
+
+# Set the logs dir and the checkpoint paths
+default_root_dir = "logs"
+ckpt_acoustic = "./checkpoints/epoch=301-step=124630.ckpt"
+ckpt_vocoder = "./checkpoints/vocoder.ckpt"
+
+try:
+    trainer = Trainer(
+        accelerator="cuda",
+        devices=-1,
+        num_nodes=num_nodes,
+        strategy=DDPStrategy(
+            gradient_as_bucket_view=True,
+            find_unused_parameters=True,
+        ),
+        # Save checkpoints to the `default_root_dir` directory
+        default_root_dir=default_root_dir,
+        enable_checkpointing=True,
+        accumulate_grad_batches=5,
+        max_epochs=-1,
+        log_every_n_steps=10,
+        gradient_clip_val=0.5,
+    )
+
+    preprocessing_config = PreprocessingConfig("multilingual")
+    model = DelightfulTTS(preprocessing_config)
+    # NOTE: Load the model from the checkpoint file
+    # In case of loading the model from the checkpoint file, model states will be restored
+    # from the checkpoint file but the training states will be reset
+    # model = DelightfulTTS.load_from_checkpoint(ckpt_acoustic, strict=False)
+
+    tuner = Tuner(trainer)
+    # NOTE: Tune the learning rate of the model if needed
+    # tuner.lr_find(model)
+
+    train_dataloader = model.train_dataloader(
+        # NOTE: Preload the cached dataset into the RAM
+        cache_dir="/dev/shm/",
+        cache=True,
+    )
+
+    trainer.fit(
+        model=model,
+        train_dataloaders=train_dataloader,
+        # Resume training states from the checkpoint file
+        # ckpt_path=ckpt_acoustic,
+    )
+
+except Exception as e:
+    # Log the error message
+    logger.error(f"An error occurred: {e}")
+    sys.exit(1)
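Two details of the error logging in `train.py` are worth a note. `logging.FileHandler` opens its file immediately and does not create missing parent directories, so the script fails at startup if `logs/` does not exist yet. Also, `logger.error(f"... {e}")` records only the exception message, not the traceback. The following is a possible variant, sketched with the standard library only; the function names `build_error_logger` and `run_guarded` are hypothetical and are not part of this commit.

```python
import logging
import os
from datetime import datetime


def build_error_logger(log_dir: str = "logs", name: str = "my_logger") -> logging.Logger:
    """Return a logger that writes ERROR records to a timestamped file in log_dir."""
    # FileHandler does not create missing parent directories, so create them first
    os.makedirs(log_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    handler = logging.FileHandler(os.path.join(log_dir, f"error_{timestamp}.log"))
    handler.setFormatter(
        logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    )

    logger = logging.getLogger(name)
    logger.setLevel(logging.ERROR)
    logger.addHandler(handler)
    return logger


def run_guarded(step, logger: logging.Logger) -> int:
    """Run step(); on failure log the full traceback and return exit code 1."""
    try:
        step()
    except Exception:
        # logger.exception logs at ERROR level and appends the traceback
        logger.exception("An error occurred")
        return 1
    return 0
```

In a script like the one above, the training body would be wrapped in a function and passed to `run_guarded`, with the returned code handed to `sys.exit`.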
