From 6ee881c62f00e97044965e6c93bdd93c737cade3 Mon Sep 17 00:00:00 2001
From: andreasjansson
Date: Mon, 14 Feb 2022 17:22:18 -0800
Subject: [PATCH] Add Replicate demo and Cog configuration

---
 README.md  |   3 ++
 cog.yaml   |  37 ++++++++++++++++++
 predict.py | 109 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 149 insertions(+)
 create mode 100644 cog.yaml
 create mode 100644 predict.py

diff --git a/README.md b/README.md
index ee7c876..40f4eb3 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,7 @@
 ### DB-AIAT: A Dual-branch attention-in-attention transformer for single-channel SE (https://arxiv.org/abs/2110.06467)
+
+
+
 This is the repo of the manuscript "Dual-branch Attention-In-Attention Transformer for speech enhancement", which is accepted by ICASSP2022.
diff --git a/cog.yaml b/cog.yaml
new file mode 100644
index 0000000..d6f99d5
--- /dev/null
+++ b/cog.yaml
@@ -0,0 +1,37 @@
+# Configuration for Cog ⚙️
+# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
+
+build:
+  # set to true if your model requires a GPU
+  gpu: true
+
+  # a list of ubuntu apt packages to install
+  system_packages:
+    - "libsndfile1"
+    - "ffmpeg"
+
+  # python version in the form '3.8' or '3.8.12'
+  python_version: "3.8"
+
+  # a list of packages in the format <package-name>==<version>
+  python_packages:
+    - "h5py==3.6.0"
+    - "hdf5storage==0.1.18"
+    - "joblib==1.1.0"
+    - "librosa==0.8.0"
+    - "numpy==1.21.4"
+    - "ptflops==0.6.8"
+    - "scipy==1.8.0"
+    - "soundfile==0.10.3.post1"
+    - "torch==1.8.0"
+    - "Cython==0.29.27"
+
+  # commands run after the environment is set up
+  run:
+    # need to install pesq here since it requires numpy to be installed at build time
+    - "pip install pesq==0.0.1"
+
+# predict.py defines how predictions are run on your model
+predict: "predict.py:Predictor"
+
+image: "r8.im/yuguochencuc/db-aiat"
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000..76b1f74
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,109 @@
+# Prediction interface for Cog ⚙️
+# Reference: https://github.com/replicate/cog/blob/main/docs/python.md
+
+from pathlib import Path
+import cog
+import torch
+import librosa
+import numpy as np
+from istft import ISTFT
+from aia_trans import (
+    dual_aia_trans_merge_crm,
+)
+import soundfile as sf
+
+
+SAMPLE_RATE = 16000
+CHUNK_LENGTH = SAMPLE_RATE * 10  # 10 seconds
+CHUNK_OVERLAP = int(SAMPLE_RATE * .1)  # 100 ms
+CHUNK_HOP = CHUNK_LENGTH - CHUNK_OVERLAP
+
+
+class Predictor(cog.Predictor):
+    def setup(self):
+        """Load the model into memory to make running multiple predictions efficient"""
+        self.model = dual_aia_trans_merge_crm()
+        checkpoint = torch.load("./BEST_MODEL/vb_aia_merge_new.pth.tar")
+        self.model.load_state_dict(checkpoint)
+        self.model.eval()
+        self.model.cuda()
+        self.istft = ISTFT(filter_length=320, hop_length=160, window="hanning")
+
+    @cog.input("audio", type=Path, help="Noisy audio input")
+    def predict(self, audio):
+        """Run a single prediction on the model"""
+
+        # process audio in chunks to prevent running out of memory
+        clean_chunks = []
+        noisy, _ = librosa.load(str(audio), sr=SAMPLE_RATE, mono=True)
+        for i in range(0, len(noisy), CHUNK_HOP):
+            print(f"Processing samples {min(i + CHUNK_LENGTH, len(noisy))} / {len(noisy)}")
+            noisy_chunk = noisy[i:i + CHUNK_LENGTH]
+            clean_chunk = self.speech_enhance(noisy_chunk)
+            clean_chunks.append(clean_chunk)
+
+        last_clean_chunk = clean_chunks[-1]
+        if len(clean_chunks) > 1 and len(last_clean_chunk) < CHUNK_OVERLAP:
+            clean_chunks = clean_chunks[:-1]
+
+        # recreate clean audio by overlapping windows
+        clean = np.zeros(noisy.shape)
+        hanning = np.hanning(CHUNK_OVERLAP * 2)
+        fade_in = hanning[:CHUNK_OVERLAP]
+        fade_out = hanning[CHUNK_OVERLAP:]
+        for i, clean_chunk in enumerate(clean_chunks):
+            is_first = i == 0
+            is_last = i == len(clean_chunks) - 1
+            if not is_first:
+                clean_chunk[:CHUNK_OVERLAP] *= fade_in
+            if not is_last:
+                clean_chunk[CHUNK_HOP:] *= fade_out
+            clean[i * CHUNK_HOP:(i + 1) * CHUNK_HOP + CHUNK_OVERLAP] += clean_chunk
+
+        out_path = Path("/tmp/out.wav")
+        sf.write(str(out_path), clean, SAMPLE_RATE)
+
+        return out_path
+
+    def speech_enhance(self, signal):
+        with torch.no_grad():
+            c = np.sqrt(len(signal) / np.sum((signal ** 2.0)))
+            signal = signal * c
+            wav_len = len(signal)
+            frame_num = int(np.ceil((wav_len - 320 + 320) / 160 + 1))
+            fake_wav_len = (frame_num - 1) * 160 + 320 - 320
+            left_sample = fake_wav_len - wav_len
+            signal = torch.FloatTensor(
+                np.concatenate((signal, np.zeros([left_sample])), axis=0)
+            )
+            feat_x = torch.stft(
+                signal.unsqueeze(dim=0),
+                n_fft=320,
+                hop_length=160,
+                win_length=320,
+                window=torch.hann_window(320),
+            ).permute(0, 3, 2, 1)
+            noisy_phase = torch.atan2(feat_x[:, -1, :, :], feat_x[:, 0, :, :])
+            feat_x_mag = (torch.norm(feat_x, dim=1)) ** 0.5
+            feat_x = torch.stack(
+                (
+                    feat_x_mag * torch.cos(noisy_phase),
+                    feat_x_mag * torch.sin(noisy_phase),
+                ),
+                dim=1,
+            )
+            esti_x = self.model(feat_x.cuda())
+            esti_mag, esti_phase = torch.norm(esti_x, dim=1), torch.atan2(
+                esti_x[:, -1, :, :], esti_x[:, 0, :, :]
+            )
+            esti_mag = esti_mag ** 2
+            esti_com = torch.stack(
+                (esti_mag * torch.cos(esti_phase), esti_mag * torch.sin(esti_phase)),
+                dim=1,
+            )
+            esti_com = esti_com.cpu()
+            esti_utt = self.istft(esti_com).squeeze().numpy()
+            esti_utt = esti_utt[:wav_len]
+            esti_utt = esti_utt / c
+
+            return esti_utt
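
Below is a minimal local smoke-test sketch for the predictor added by this patch (it is not part of the patch itself). It assumes a CUDA-capable machine with the repository modules (aia_trans, istft) and the BEST_MODEL checkpoint in place, a placeholder input file noisy.wav, and that the legacy @cog.input decorator leaves predict() directly callable, as in Cog's old Python API; the equivalent containerized run would be `cog predict -i audio=@noisy.wav`.

    # smoke_test.py (hypothetical helper, not included in the patch)
    from pathlib import Path

    from predict import Predictor

    predictor = Predictor()
    predictor.setup()  # loads dual_aia_trans_merge_crm and the vb_aia_merge_new checkpoint onto the GPU
    out_path = predictor.predict(Path("noisy.wav"))  # noisy.wav is a placeholder filename
    print(f"Enhanced audio written to {out_path}")  # predict.py writes /tmp/out.wav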