diff --git a/README.md b/README.md
index ee7c876..40f4eb3 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,7 @@
### DB-AIAT: A Dual-branch attention-in-attention transformer for single-channel SE (https://arxiv.org/abs/2110.06467)
+
+Try out the web demo on Replicate: https://replicate.com/yuguochencuc/db-aiat
+
This is the repo of the manuscript "Dual-branch Attention-In-Attention Transformer for speech enhancement", which has been accepted by ICASSP 2022.
diff --git a/cog.yaml b/cog.yaml
new file mode 100644
index 0000000..d6f99d5
--- /dev/null
+++ b/cog.yaml
@@ -0,0 +1,41 @@
+# Configuration for Cog ⚙️
+# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
+
+build:
+  # set to true if your model requires a GPU
+  gpu: true
+
+  # a list of ubuntu apt packages to install
+  system_packages:
+    - "libsndfile1"
+    - "ffmpeg"
+
+  # python version in the form '3.8' or '3.8.12'
+  python_version: "3.8"
+
+  # a list of packages in the format <package-name>==<version>
+  python_packages:
+    - "h5py==3.6.0"
+    - "hdf5storage==0.1.18"
+    - "joblib==1.1.0"
+    - "librosa==0.8.0"
+    - "numpy==1.21.4"
+    - "ptflops==0.6.8"
+    - "scipy==1.8.0"
+    - "soundfile==0.10.3.post1"
+    - "torch==1.8.0"
+    - "Cython==0.29.27"
+
+  # commands run after the environment is set up
+  run:
+    # pesq is installed here because it needs numpy present at build time
+    - "pip install pesq==0.0.1"
+
+# predict.py defines how predictions are run on your model
+predict: "predict.py:Predictor"
+
+image: "r8.im/yuguochencuc/db-aiat"
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000..76b1f74
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,131 @@
+# Prediction interface for Cog ⚙️
+# Reference: https://github.com/replicate/cog/blob/main/docs/python.md
+
+from pathlib import Path
+import cog
+import torch
+import librosa
+import numpy as np
+from istft import ISTFT
+from aia_trans import (
+ dual_aia_trans_merge_crm,
+)
+import soundfile as sf
+
+
+SAMPLE_RATE = 16000
+CHUNK_LENGTH = SAMPLE_RATE * 10 # 10 seconds
+CHUNK_OVERLAP = int(SAMPLE_RATE * .1) # 100 ms
+CHUNK_HOP = CHUNK_LENGTH - CHUNK_OVERLAP
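+# e.g. at 16 kHz: CHUNK_LENGTH = 160000 samples, CHUNK_OVERLAP = 1600 samples,
+# so consecutive chunks start every CHUNK_HOP = 158400 samples and share a
+# 100 ms crossfade region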
+
+
+class Predictor(cog.Predictor):
+    def setup(self):
+        """Load the model into memory to make running multiple predictions efficient"""
+        self.model = dual_aia_trans_merge_crm()
+        checkpoint = torch.load("./BEST_MODEL/vb_aia_merge_new.pth.tar")
+        self.model.load_state_dict(checkpoint)
+        self.model.eval()
+        self.model.cuda()
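+        # inverse STFT mirroring the 320-sample (20 ms) window and 160-sample
+        # (10 ms) hop of the analysis STFT in speech_enhance below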
+        self.istft = ISTFT(filter_length=320, hop_length=160, window="hanning")
+
+    @cog.input("audio", type=Path, help="Noisy audio input")
+    def predict(self, audio):
+        """Run a single prediction on the model"""
+
+        # process the audio in overlapping chunks to avoid running out of memory
+        clean_chunks = []
+        noisy, _ = librosa.load(str(audio), sr=SAMPLE_RATE, mono=True)
+        for i in range(0, len(noisy), CHUNK_HOP):
+            print(f"Processing samples {min(i + CHUNK_LENGTH, len(noisy))} / {len(noisy)}")
+            noisy_chunk = noisy[i:i + CHUNK_LENGTH]
+            clean_chunk = self.speech_enhance(noisy_chunk)
+            clean_chunks.append(clean_chunk)
+
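+        # a trailing chunk shorter than the overlap is already fully covered
+        # by the previous chunk's tail, so drop it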
+        last_clean_chunk = clean_chunks[-1]
+        if len(clean_chunks) > 1 and len(last_clean_chunk) < CHUNK_OVERLAP:
+            clean_chunks = clean_chunks[:-1]
+
+        # reassemble the enhanced audio by overlap-add with a crossfade
+        clean = np.zeros(noisy.shape)
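+        # split a Hann window into complementary fade-in/fade-out ramps that
+        # sum to ~1 across the overlap, smoothing the joins between chunks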
+        hanning = np.hanning(CHUNK_OVERLAP * 2)
+        fade_in = hanning[:CHUNK_OVERLAP]
+        fade_out = hanning[CHUNK_OVERLAP:]
+        for i, clean_chunk in enumerate(clean_chunks):
+            is_first = i == 0
+            is_last = i == len(clean_chunks) - 1
+            if not is_first:
+                clean_chunk[:CHUNK_OVERLAP] *= fade_in
+            if not is_last:
+                clean_chunk[CHUNK_HOP:] *= fade_out
+            clean[i * CHUNK_HOP:(i + 1) * CHUNK_HOP + CHUNK_OVERLAP] += clean_chunk
+
+        out_path = Path("/tmp/out.wav")
+        sf.write(str(out_path), clean, SAMPLE_RATE)
+
+        return out_path
+
+    def speech_enhance(self, signal):
+        with torch.no_grad():
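+            # scale the input to unit RMS; the same constant c is used to
+            # undo the scaling after enhancement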
+            c = np.sqrt(len(signal) / np.sum(signal ** 2.0))
+            signal = signal * c
+            wav_len = len(signal)
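+            # zero-pad so the signal fills a whole number of 160-sample hops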
+            frame_num = int(np.ceil(wav_len / 160 + 1))
+            fake_wav_len = (frame_num - 1) * 160
+            left_sample = fake_wav_len - wav_len
+            signal = torch.FloatTensor(
+                np.concatenate((signal, np.zeros([left_sample])), axis=0)
+            )
+            feat_x = torch.stft(
+                signal.unsqueeze(dim=0),
+                n_fft=320,
+                hop_length=160,
+                win_length=320,
+                window=torch.hann_window(320),
+            ).permute(0, 3, 2, 1)
+            noisy_phase = torch.atan2(feat_x[:, -1, :, :], feat_x[:, 0, :, :])
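+            # compress the spectral magnitude with a 0.5 power (the model
+            # operates on power-compressed spectra); undone after enhancement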
+            feat_x_mag = torch.norm(feat_x, dim=1) ** 0.5
+            feat_x = torch.stack(
+                (
+                    feat_x_mag * torch.cos(noisy_phase),
+                    feat_x_mag * torch.sin(noisy_phase),
+                ),
+                dim=1,
+            )
+            esti_x = self.model(feat_x.cuda())
+            esti_mag, esti_phase = torch.norm(esti_x, dim=1), torch.atan2(
+                esti_x[:, -1, :, :], esti_x[:, 0, :, :]
+            )
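+            # undo the 0.5-power magnitude compression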
+            esti_mag = esti_mag ** 2
+            esti_com = torch.stack(
+                (esti_mag * torch.cos(esti_phase), esti_mag * torch.sin(esti_phase)),
+                dim=1,
+            )
+            esti_com = esti_com.cpu()
+            esti_utt = self.istft(esti_com).squeeze().numpy()
+            esti_utt = esti_utt[:wav_len]
+            esti_utt = esti_utt / c
+
+        return esti_utt
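+
+# A minimal, hypothetical smoke test (assumes a CUDA GPU and the checkpoint in
+# ./BEST_MODEL; "noisy.wav" is a placeholder path, not a file in this repo):
+#
+#   predictor = Predictor()
+#   predictor.setup()
+#   print(predictor.predict(Path("noisy.wav")))  # writes /tmp/out.wav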