Skip to content

Commit

Permalink
Refactor repository
Browse files Browse the repository at this point in the history
  • Loading branch information
nglehuy committed Oct 10, 2020
1 parent 4367a21 commit ee6a553
Show file tree
Hide file tree
Showing 20 changed files with 163 additions and 306 deletions.
125 changes: 2 additions & 123 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,129 +4,8 @@ dist
tensorflow
externals
coc-settings.json
.session.vim
Session.vim
octave-workspace
.idea
.vscode
__pycache__
venv
trained
.env
.env.production
.env.development

# Created by https://www.toptal.com/developers/gitignore/api/vim
# Edit at https://www.toptal.com/developers/gitignore?templates=vim

### Vim ###
# Swap
[._]*.s[a-v][a-z]
!*.svg # comment out if you don't need vector files
[._]*.sw[a-p]
[._]s[a-rt-v][a-z]
[._]ss[a-gi-z]
[._]sw[a-p]

# Session
Session.vim
Sessionx.vim

# Temporary
.netrwhist
*~
# Auto-generated tag files
tags
# Persistent undo
[._]*.un~

# End of https://www.toptal.com/developers/gitignore/api/vim
# Created by https://www.gitignore.io/api/pycharm+all
# Edit at https://www.gitignore.io/?templates=pycharm+all

### PyCharm+all ###
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf

# Generated files
.idea/**/contentModel.xml

# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml

# Gradle
.idea/**/gradle.xml
.idea/**/libraries

# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr

# CMake
cmake-build-*/

# Mongo Explorer plugin
.idea/**/mongoSettings.xml

# File-based project format
*.iws

# IntelliJ
out/

# mpeltonen/sbt-idea plugin
.idea_modules/

# JIRA plugin
atlassian-ide-plugin.xml

# Cursive Clojure plugin
.idea/replstate.xml

# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties

# Editor-based Rest Client
.idea/httpRequests

# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser

### PyCharm+all Patch ###
# Ignores the whole .idea folder and all .iml files
# See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360

.idea/

# Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023

*.iml
modules.xml
.idea/misc.xml
*.ipr

# Sonarlint plugin
.idea/sonarlint

# End of https://www.gitignore.io/api/pycharm+all
__pycache__
21 changes: 6 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,43 +1,34 @@
<h1 align="center">
<p>TiramisuSE :cake:</p>
<p>SASEGAN</p>
<p align="center">
<a href="https://github.com/usimarit/TiramisuSE/blob/master/LICENSE">
<img alt="GitHub" src="https://img.shields.io/github/license/usimarit/TiramisuSE?style=for-the-badge&logo=apache">
<a href="https://github.com/usimarit/selfattention-segan/blob/master/LICENSE">
<img alt="GitHub" src="https://img.shields.io/github/license/usimarit/selfattention-segan?style=for-the-badge&logo=apache">
</a>
<img alt="python" src="https://img.shields.io/badge/python-%3E%3D3.6-blue?style=for-the-badge&logo=python">
<img alt="tensorflow" src="https://img.shields.io/badge/tensorflow-%3E%3D2.3.0-orange?style=for-the-badge&logo=tensorflow">
<img alt="ubuntu" src="https://img.shields.io/badge/ubuntu-%3E%3D18.04-blueviolet?style=for-the-badge&logo=ubuntu">
</p>
</h1>
<h2 align="center">
<p>The Newest Speech Enhancement in Tensorflow 2</p>
<p>Self Attention Speech Enhancement GAN in Tensorflow 2</p>
</h2>

<p align="center">
TiramisuSE implements some speech enhancement architectures such as Speech Enhancement Generative Adversarial Network (SEGAN). These models can be converted to TFLite to reduce memory and computation for deployment :smile:
</p>

## What's New?

- Moved from [TiramisuASR](https://github.com/usimarit/TiramisuASR)

## :yum: Supported Models

- **SEGAN** (Refer to [https://github.com/santi-pdp/segan](https://github.com/santi-pdp/segan)), see [examples/segan](./examples/segan)

## Requirements

- Ubuntu distribution (`ctc-decoders` and `semetrics` require some packages from apt)
- Python 3.6+
- Tensorflow 2.2+: `pip install tensorflow`
- **SASEGAN**, see [examples/sasegan](./examples/sasegan)

## Setup Environment and Datasets

Install tensorflow: `pip3 install tensorflow` or `pip3 install tf-nightly` (for using tflite)

Install packages: `python3 setup.py install`

For **setting up datasets**, see [datasets](./tiramisu_se/datasets/README.md)
For **setting up datasets**, see [datasets](sasegan/datasets/README.md)

- For _testing_ a **Speech Enhancement Model** (e.g. SEGAN), install `octave` and run `./scripts/install_semetrics.sh`

Expand Down
6 changes: 3 additions & 3 deletions examples/sasegan/test_sasegan.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@

setup_devices([args.device])

from tiramisu_asr.runners.segan_runners import SeganTester
from tiramisu_asr.datasets.segan_dataset import SeganTestDataset
from sasegan.runners.tester import SeganTester
from sasegan.datasets.segan_dataset import SeganTestDataset
from tiramisu_asr.configs.user_config import UserConfig
from tiramisu_asr.models.sasegan import Generator
from sasegan.models.sasegan import Generator

config = UserConfig(DEFAULT_YAML, args.config, learning=True)

Expand Down
6 changes: 3 additions & 3 deletions examples/sasegan/train_sasegan.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@

strategy = setup_strategy(args.devices)

from tiramisu_asr.runners.segan_runners import SeganTrainer
from tiramisu_asr.datasets.segan_dataset import SeganTrainDataset
from sasegan.runners.trainer import SeganTrainer
from sasegan.datasets.segan_dataset import SeganTrainDataset
from tiramisu_asr.configs.user_config import UserConfig
from tiramisu_asr.models.sasegan import Generator, Discriminator
from sasegan.models.sasegan import Generator, Discriminator

config = UserConfig(DEFAULT_YAML, args.config, learning=True)

Expand Down
6 changes: 3 additions & 3 deletions examples/segan/test_segan.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@

setup_devices([args.device])

from tiramisu_asr.runners.segan_runners import SeganTester
from tiramisu_asr.datasets.segan_dataset import SeganAugTestDataset
from sasegan.runners.tester import SeganTester
from sasegan.datasets.segan_dataset import SeganAugTestDataset
from tiramisu_asr.configs.user_config import UserConfig
from tiramisu_asr.models.segan import Generator
from sasegan.models.segan import Generator

config = UserConfig(DEFAULT_YAML, args.config, learning=True)

Expand Down
6 changes: 3 additions & 3 deletions examples/segan/train_segan.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@

strategy = setup_strategy(args.devices)

from tiramisu_asr.runners.segan_runners import SeganTrainer
from tiramisu_asr.datasets.segan_dataset import SeganAugTrainDataset
from sasegan.runners.trainer import SeganTrainer
from sasegan.datasets.segan_dataset import SeganAugTrainDataset
from tiramisu_asr.configs.user_config import UserConfig
from tiramisu_asr.models.segan import Generator, Discriminator
from sasegan.models.segan import Generator, Discriminator

config = UserConfig(DEFAULT_YAML, args.config, learning=True)

Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
132 changes: 132 additions & 0 deletions sasegan/runners/tester.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from tqdm import tqdm

import numpy as np
import soundfile as sf
import tensorflow as tf

from tiramisu_asr.featurizers.speech_featurizers import deemphasis
from tiramisu_asr.featurizers.speech_featurizers import tf_merge_slices, read_raw_audio
from tiramisu_asr.runners.base_runners import BaseTester
from tiramisu_asr.utils.utils import shape_list


class SeganTester(BaseTester):
    """Run a speech-enhancement generator over a test set and score its output.

    For every test record the generated and noisy signals are de-emphasised,
    written below ``<outdir>/test/gen`` and ``<outdir>/test/noisy`` and compared
    against the clean reference through ``self.composite``.

    NOTE(review): ``self.composite``, ``self.model``, ``self.processed_records``,
    ``self.test_results`` and ``self.test_metrics`` are not defined in this
    class; presumably they are provided by ``BaseTester`` or assigned by the
    caller -- confirm.
    """

    # Scratch location of the clean reference re-written with soundfile, so the
    # reference and the compared files all go through the same writer.
    _CLEAN_REFERENCE_PATH = "/tmp/clean.wav"

    def __init__(self,
                 speech_config: dict,
                 config: dict):
        """Store the configs and make sure the output directories exist.

        Args:
            speech_config: speech settings; ``"sample_rate"`` and
                ``"preemphasis"`` are read by this class.
            config: runner settings forwarded to ``BaseTester``; ``"outdir"``
                is read here through ``self.config``.
        """
        super(SeganTester, self).__init__(config)
        self.speech_config = speech_config

        test_root = os.path.join(self.config["outdir"], "test")
        self.test_noisy_dir = os.path.join(test_root, "noisy")
        self.test_gen_dir = os.path.join(test_root, "gen")
        self.test_clean_dir = os.path.join(test_root, "clean")

        # exist_ok avoids the check-then-create race of os.path.exists().
        for directory in (self.test_noisy_dir, self.test_gen_dir, self.test_clean_dir):
            os.makedirs(directory, exist_ok=True)

    def set_test_data_loader(self, test_dataset):
        """Set test data loader (MUST)."""
        # The clean directory is needed later to map clean paths to output paths.
        self.clean_dir = test_dataset.clean_dir
        self.test_data_loader = test_dataset.create()

    def _test_epoch(self):
        """Iterate once over the test data, one batch per progress-bar tick."""
        if self.processed_records > 0:
            # Resume: skip the records that were already processed.
            self.test_data_loader = self.test_data_loader.skip(self.processed_records)
        progbar = tqdm(initial=self.processed_records, total=None,
                       unit="batch", position=0, desc="[Test]")
        test_iter = iter(self.test_data_loader)
        while True:
            try:
                self._test_function(test_iter)
            except (StopIteration, tf.errors.OutOfRangeError):
                # Either exception signals that the iterator is exhausted.
                break
            progbar.update(1)

        progbar.close()

    @tf.function
    def _test_function(self, iterator):
        """Fetch the next batch from ``iterator`` and run one test step on it."""
        batch = next(iterator)
        self._test_step(batch)

    def _test_step(self, batch):
        """Enhance one batch and hand the signals to ``_perform`` for scoring.

        Args:
            batch: pair ``(clean_wav_path, noisy_wavs)``.
        """
        # Test only available for batch size = 1
        clean_wav_path, noisy_wavs = batch
        latent = self.model.get_z(shape_list(noisy_wavs)[0])
        g_wavs = self.model([noisy_wavs, latent], training=False)

        # NOTE(review): the scores returned by _perform are discarded here, yet
        # finish() reports self.test_metrics -- confirm whether the metrics
        # should be updated from this result.
        tf.numpy_function(
            self._perform,
            inp=[clean_wav_path, tf_merge_slices(g_wavs), tf_merge_slices(noisy_wavs)],
            Tout=tf.float32
        )

    def _perform(self,
                 clean_wav_path: bytes,
                 gen_signal: np.ndarray,
                 noisy_signal: np.ndarray) -> tf.Tensor:
        """Decode the clean path and score the generated and noisy signals.

        Args:
            clean_wav_path: UTF-8 encoded path of the clean reference wav.
            gen_signal: generated (enhanced) signal.
            noisy_signal: noisy input signal.

        Returns:
            A float32 tensor holding the values returned by ``_compare``.
        """
        clean_wav_path = clean_wav_path.decode("utf-8")
        results = self._compare(clean_wav_path, gen_signal, noisy_signal)
        return tf.convert_to_tensor(results, dtype=tf.float32)

    def _save_to_outdir(self,
                        clean_wav_path: str,
                        gen_signal: np.ndarray,
                        noisy_signal: np.ndarray):
        """Write the generated, noisy and clean reference signals to disk.

        The output paths are derived from ``clean_wav_path`` by substituting
        ``self.clean_dir`` with the generated / noisy output directory.

        Returns:
            Tuple ``(gen_path, noisy_path)`` of the files that were written.
        """
        sample_rate = self.speech_config["sample_rate"]
        gen_path = clean_wav_path.replace(self.clean_dir, self.test_gen_dir)
        noisy_path = clean_wav_path.replace(self.clean_dir, self.test_noisy_dir)

        # Create each parent directory independently: an already existing
        # directory for one output must not prevent creating the other.
        for path in (gen_path, noisy_path):
            parent = os.path.dirname(path)
            if parent:
                os.makedirs(parent, exist_ok=True)

        # Avoid differences by writing original wav using sf
        clean_wav = read_raw_audio(clean_wav_path, sample_rate)
        sf.write(self._CLEAN_REFERENCE_PATH, clean_wav, sample_rate)
        sf.write(gen_path, gen_signal, sample_rate)
        sf.write(noisy_path, noisy_signal, sample_rate)
        return gen_path, noisy_path

    def _compare(self,
                 clean_wav_path: str,
                 gen_signal: np.ndarray,
                 noisy_signal: np.ndarray) -> list:
        """Score the generated and the noisy signal against the clean reference.

        Both signals are de-emphasised with the configured ``"preemphasis"``
        value, saved to the output directories and passed to ``self.composite``.

        Returns:
            List of ten values: pesq, csig, cbak, covl and ssnr for the
            generated signal, followed by the same five for the noisy signal.
        """
        preemphasis = self.speech_config["preemphasis"]
        gen_signal = deemphasis(gen_signal, preemphasis)
        noisy_signal = deemphasis(noisy_signal, preemphasis)

        gen_path, noisy_path = self._save_to_outdir(clean_wav_path, gen_signal, noisy_signal)

        (pesq_gen, csig_gen, cbak_gen,
         covl_gen, ssnr_gen) = self.composite(self._CLEAN_REFERENCE_PATH, gen_path)
        (pesq_noisy, csig_noisy, cbak_noisy,
         covl_noisy, ssnr_noisy) = self.composite(self._CLEAN_REFERENCE_PATH, noisy_path)

        return [pesq_gen, csig_gen, cbak_gen, covl_gen, ssnr_gen,
                pesq_noisy, csig_noisy, cbak_noisy, covl_noisy, ssnr_noisy]

    def finish(self):
        """Write every test metric as ``key = value`` (two decimals) to ``self.test_results``."""
        with open(self.test_results, "w", encoding="utf-8") as out:
            for key, metric in self.test_metrics.items():
                out.write(f"{key} = {metric.result().numpy():.2f}\n")
Loading

0 comments on commit ee6a553

Please sign in to comment.