Skip to content

Commit

Permalink
Add Replicate demo and Cog configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
andreasjansson committed Feb 15, 2022
1 parent 94ab8da commit 6ee881c
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 0 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
### DB-AIAT: A Dual-branch attention-in-attention transformer for single-channel SE (https://arxiv.org/abs/2110.06467)

<a href="https://replicate.com/yuguochencuc/db-aiat"><img src="https://replicate.com/yuguochencuc/db-aiat/badge"></a>

This is the repo of the manuscript "Dual-branch Attention-In-Attention Transformer for speech enhancement", which is accepted by ICASSP2022.


Expand Down
37 changes: 37 additions & 0 deletions cog.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Configuration for Cog ⚙️
# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md

build:
# set to true if your model requires a GPU
gpu: true

# a list of ubuntu apt packages to install
system_packages:
- "libsndfile1"
- "ffmpeg"

# python version in the form '3.8' or '3.8.12'
python_version: "3.8"

# a list of packages in the format <package-name>==<version>
python_packages:
- "h5py==3.6.0"
- "hdf5storage==0.1.18"
- "joblib==1.1.0"
- "librosa==0.8.0"
- "numpy==1.21.4"
- "ptflops==0.6.8"
- "scipy==1.8.0"
- "soundfile==0.10.3.post1"
- "torch==1.8.0"
- "Cython==0.29.27"

# commands run after the environment is setup
run:
# need to install pesq here since it requires numpy to be installed at build time
- "pip install pesq==0.0.1"

# predict.py defines how predictions are run on your model
predict: "predict.py:Predictor"

image: "r8.im/yuguochencuc/db-aiat"
109 changes: 109 additions & 0 deletions predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Prediction interface for Cog ⚙️
# Reference: https://github.com/replicate/cog/blob/main/docs/python.md

from pathlib import Path
import cog
import torch
import librosa
import numpy as np
from istft import ISTFT
from aia_trans import (
dual_aia_trans_merge_crm,
)
import soundfile as sf


SAMPLE_RATE = 16000
CHUNK_LENGTH = SAMPLE_RATE * 10 # 10 seconds
CHUNK_OVERLAP = int(SAMPLE_RATE * .1) # 100 ms
CHUNK_HOP = CHUNK_LENGTH - CHUNK_OVERLAP


class Predictor(cog.Predictor):
def setup(self):
"""Load the model into memory to make running multiple predictions efficient"""
self.model = dual_aia_trans_merge_crm()
checkpoint = torch.load("./BEST_MODEL/vb_aia_merge_new.pth.tar")
self.model.load_state_dict(checkpoint)
self.model.eval()
self.model.cuda()
self.istft = ISTFT(filter_length=320, hop_length=160, window="hanning")

@cog.input("audio", type=Path, help="Noisy audio input")
def predict(self, audio):
"""Run a single prediction on the model"""

# process audio in chunks to prevent running out of memory
clean_chunks = []
noisy, _ = librosa.load(str(audio), sr=SAMPLE_RATE, mono=True)
for i in range(0, len(noisy), CHUNK_HOP):
print(f"Processing samples {min(i + CHUNK_LENGTH, len(noisy))} / {len(noisy)}")
noisy_chunk = noisy[i:i + CHUNK_LENGTH]
clean_chunk = self.speech_enhance(noisy_chunk)
clean_chunks.append(clean_chunk)

last_clean_chunk = clean_chunks[-1]
if len(clean_chunks) > 1 and len(last_clean_chunk) < CHUNK_OVERLAP:
clean_chunks = clean_chunks[:-1]

# recreate clean audio by overlapping windows
clean = np.zeros(noisy.shape)
hanning = np.hanning(CHUNK_OVERLAP * 2)
fade_in = hanning[:CHUNK_OVERLAP]
fade_out = hanning[CHUNK_OVERLAP:]
for i, clean_chunk in enumerate(clean_chunks):
is_first = i == 0
is_last = i == len(clean_chunks) - 1
if not is_first:
clean_chunk[:CHUNK_OVERLAP] *= fade_in
if not is_last:
clean_chunk[CHUNK_HOP:] *= fade_out
clean[i * CHUNK_HOP:(i + 1) * CHUNK_HOP + CHUNK_OVERLAP] += clean_chunk

out_path = Path("/tmp/out.wav")
sf.write(str(out_path), clean, SAMPLE_RATE)

return out_path

def speech_enhance(self, signal):
with torch.no_grad():
c = np.sqrt(len(signal) / np.sum((signal ** 2.0)))
signal = signal * c
wav_len = len(signal)
frame_num = int(np.ceil((wav_len - 320 + 320) / 160 + 1))
fake_wav_len = (frame_num - 1) * 160 + 320 - 320
left_sample = fake_wav_len - wav_len
signal = torch.FloatTensor(
np.concatenate((signal, np.zeros([left_sample])), axis=0)
)
feat_x = torch.stft(
signal.unsqueeze(dim=0),
n_fft=320,
hop_length=160,
win_length=320,
window=torch.hann_window(320),
).permute(0, 3, 2, 1)
noisy_phase = torch.atan2(feat_x[:, -1, :, :], feat_x[:, 0, :, :])
feat_x_mag = (torch.norm(feat_x, dim=1)) ** 0.5
feat_x = torch.stack(
(
feat_x_mag * torch.cos(noisy_phase),
feat_x_mag * torch.sin(noisy_phase),
),
dim=1,
)
esti_x = self.model(feat_x.cuda())
esti_mag, esti_phase = torch.norm(esti_x, dim=1), torch.atan2(
esti_x[:, -1, :, :], esti_x[:, 0, :, :]
)
esti_mag = esti_mag ** 2
esti_com = torch.stack(
(esti_mag * torch.cos(esti_phase), esti_mag * torch.sin(esti_phase)),
dim=1,
)
esti_com = esti_com.cpu()
esti_utt = self.istft(esti_com).squeeze().numpy()
esti_utt = esti_utt[:wav_len]
esti_utt = esti_utt / c

return esti_utt

0 comments on commit 6ee881c

Please sign in to comment.