-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
7bf311f
commit 9203816
Showing
7 changed files
with
382 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from setuptools import setup, find_packages | ||
|
||
requirements = [ | ||
"tqdm", | ||
"kaldiio", | ||
"torch>=1.12.0", | ||
"torchaudio>=0.12.0", | ||
"silero-vad", | ||
] | ||
|
||
setup( | ||
name="wesep", | ||
install_requires=requirements, | ||
packages=find_packages(), | ||
entry_points={ | ||
"console_scripts": [ | ||
"wesep = wesep.cli.extractor:main", | ||
], | ||
}, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from wesep.cli.extractor import load_model # noqa | ||
from wesep.cli.extractor import load_model_local # noqa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
import os | ||
import sys | ||
|
||
import numpy as np | ||
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps | ||
import torch | ||
import torchaudio | ||
import torchaudio.compliance.kaldi as kaldi | ||
import yaml | ||
from tqdm import tqdm | ||
import soundfile | ||
|
||
from wesep.cli.hub import Hub | ||
from wesep.cli.utils import get_args | ||
from wesep.models import get_model | ||
from wesep.utils.checkpoint import load_pretrained_model | ||
from wesep.utils.utils import set_seed | ||
|
||
|
||
class Extractor: | ||
|
||
def __init__(self, model_dir: str): | ||
set_seed() | ||
|
||
config_path = os.path.join(model_dir, "config.yaml") | ||
model_path = os.path.join(model_dir, "avg_model.pt") | ||
with open(config_path, "r") as fin: | ||
configs = yaml.load(fin, Loader=yaml.FullLoader) | ||
if 'spk_model_init' in configs['model_args']['tse_model']: | ||
configs['model_args']['tse_model']['spk_model_init'] = False | ||
self.model = get_model(configs["model"]["tse_model"])( | ||
**configs["model_args"]["tse_model"] | ||
) | ||
load_pretrained_model(self.model, model_path) | ||
self.model.eval() | ||
self.vad = load_silero_vad() | ||
self.table = {} | ||
self.resample_rate = configs["dataset_args"].get("resample_rate", 16000) | ||
self.apply_vad = False | ||
self.device = torch.device("cpu") | ||
self.wavform_norm = False | ||
|
||
self.speaker_feat = configs["model_args"]["tse_model"].get("spk_feat", False) | ||
self.joint_training = configs["model_args"]["tse_model"].get( | ||
"joint_training", False | ||
) | ||
|
||
def set_wavform_norm(self, wavform_norm: bool): | ||
self.wavform_norm = wavform_norm | ||
|
||
def set_resample_rate(self, resample_rate: int): | ||
self.resample_rate = resample_rate | ||
|
||
def set_vad(self, apply_vad: bool): | ||
self.apply_vad = apply_vad | ||
|
||
def set_device(self, device: str): | ||
self.device = torch.device(device) | ||
self.model = self.model.to(self.device) | ||
|
||
def compute_fbank( | ||
self, | ||
wavform, | ||
sample_rate=16000, | ||
num_mel_bins=80, | ||
frame_length=25, | ||
frame_shift=10, | ||
cmn=True, | ||
): | ||
feat = kaldi.fbank( | ||
wavform, | ||
num_mel_bins=num_mel_bins, | ||
frame_length=frame_length, | ||
frame_shift=frame_shift, | ||
sample_frequency=sample_rate, | ||
) | ||
if cmn: | ||
feat = feat - torch.mean(feat, 0) | ||
return feat | ||
|
||
def extract_speech(self, audio_path: str, audio_path_2: str): | ||
pcm_mix, sample_rate_mix = torchaudio.load( | ||
audio_path, normalize=self.wavform_norm | ||
) | ||
pcm_enroll, sample_rate_enroll = torchaudio.load( | ||
audio_path_2, normalize=self.wavform_norm | ||
) | ||
return self.extract_speech_from_pcm(pcm_mix, sample_rate_mix, pcm_enroll, sample_rate_enroll) | ||
|
||
def extract_speech_from_pcm(self, pcm_mix: torch.Tensor, sample_rate_mix: int, pcm_enroll: torch.Tensor, sample_rate_enroll: int,): | ||
if self.apply_vad: | ||
# TODO(Binbin Zhang): Refine the segments logic, here we just | ||
# suppose there is only silence at the start/end of the speech | ||
# Only do vad on the enrollment | ||
vad_sample_rate = 16000 | ||
wav = pcm_enroll | ||
if wav.size(0) > 1: | ||
wav = wav.mean(dim=0, keepdim=True) | ||
if sample_rate_enroll != vad_sample_rate: | ||
transform = torchaudio.transforms.Resample( | ||
orig_freq=sample_rate_enroll, new_freq=vad_sample_rate | ||
) | ||
wav = transform(wav) | ||
|
||
segments = get_speech_timestamps(wav, self.vad, return_seconds=True) | ||
pcmTotal = torch.Tensor() | ||
if len(segments) > 0: # remove all the silence | ||
for segment in segments: | ||
start = int(segment["start"] * sample_rate_enroll) | ||
end = int(segment["end"] * sample_rate_enroll) | ||
pcmTemp = pcm_enroll[0, start:end] | ||
pcmTotal = torch.cat([pcmTotal, pcmTemp], 0) | ||
pcm_enroll = pcmTotal.unsqueeze(0) | ||
else: # all silence, nospeech | ||
return None | ||
|
||
pcm_mix = pcm_mix.to(torch.float) | ||
if sample_rate_mix != self.resample_rate: | ||
pcm_mix = torchaudio.transforms.Resample( | ||
orig_freq=sample_rate_mix, new_freq=self.resample_rate | ||
)(pcm_mix) | ||
pcm_enroll = pcm_enroll.to(torch.float) | ||
if sample_rate_enroll != self.resample_rate: | ||
pcm_enroll = torchaudio.transforms.Resample( | ||
orig_freq=sample_rate_enroll, new_freq=self.resample_rate | ||
)(pcm_enroll) | ||
|
||
if self.joint_training: | ||
if self.speaker_feat: | ||
feats = self.compute_fbank( | ||
pcm_enroll, sample_rate=self.resample_rate, cmn=True | ||
) | ||
feats = feats.unsqueeze(0) | ||
feats = feats.to(self.device) | ||
else: | ||
feats = pcm_enroll | ||
|
||
with torch.no_grad(): | ||
outputs = self.model(pcm_mix, feats) | ||
outputs = outputs[0] if isinstance(outputs, (list, tuple)) else outputs | ||
target_speech = outputs.to(torch.device("cpu")) | ||
return target_speech | ||
else: | ||
return None | ||
|
||
|
||
def load_model(language: str) -> Extractor: | ||
model_path = Hub.get_model(language) | ||
return Extractor(model_path) | ||
|
||
|
||
def load_model_local(model_dir: str) -> Extractor: | ||
return Extractor(model_dir) | ||
|
||
|
||
def main(): | ||
args = get_args() | ||
if args.pretrain == "": | ||
if args.bsrnn: | ||
model = load_model("bsrnn") | ||
else: | ||
model = load_model(args.language) | ||
model.set_wavform_norm(True) | ||
else: | ||
model = load_model_local(args.pretrain) | ||
model.set_resample_rate(args.resample_rate) | ||
model.set_vad(args.vad) | ||
model.set_device(args.device) | ||
if args.task == "extraction": | ||
speech = model.extract_speech(args.audio_file, args.audio_file2) | ||
if speech is not None: | ||
soundfile.write(args.output_file, speech[0], args.resample_rate) | ||
print("Succeed, see {}".format(args.output_file)) | ||
else: | ||
print("Fails to extract the target speech") | ||
else: | ||
print("Unsupported task {}".format(args.task)) | ||
sys.exit(-1) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# Copyright (c) 2022 Mddct(hamddct@gmail.com) | ||
# 2023 Binbin Zhang(binbzha@qq.com) | ||
# 2024 Shuai Wang(wsstriving@gmail.com) | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
import requests | ||
import sys | ||
from pathlib import Path | ||
import tarfile | ||
import zipfile | ||
from urllib.request import urlretrieve | ||
|
||
import tqdm | ||
|
||
|
||
def download(url: str, dest: str, only_child=True): | ||
"""download from url to dest""" | ||
assert os.path.exists(dest) | ||
print("Downloading {} to {}".format(url, dest)) | ||
|
||
def progress_hook(t): | ||
last_b = [0] | ||
|
||
def update_to(b=1, bsize=1, tsize=None): | ||
if tsize not in (None, -1): | ||
t.total = tsize | ||
displayed = t.update((b - last_b[0]) * bsize) | ||
last_b[0] = b | ||
return displayed | ||
|
||
return update_to | ||
|
||
# *.tar.gz | ||
name = url.split("?")[0].split("/")[-1] | ||
file_path = os.path.join(dest, name) | ||
with tqdm.tqdm( | ||
unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=(name) | ||
) as t: | ||
urlretrieve( | ||
url, filename=file_path, reporthook=progress_hook(t), data=None | ||
) | ||
t.total = t.n | ||
|
||
if name.endswith((".tar.gz", ".tar")): | ||
with tarfile.open(file_path) as f: | ||
if not only_child: | ||
f.extractall(dest) | ||
else: | ||
for tarinfo in f: | ||
if "/" not in tarinfo.name: | ||
continue | ||
name = os.path.basename(tarinfo.name) | ||
fileobj = f.extractfile(tarinfo) | ||
with open(os.path.join(dest, name), "wb") as writer: | ||
writer.write(fileobj.read()) | ||
|
||
elif name.endswith(".zip"): | ||
with zipfile.ZipFile(file_path, "r") as zip_ref: | ||
if not only_child: | ||
zip_ref.extractall(dest) | ||
else: | ||
for member in zip_ref.namelist(): | ||
member_path = os.path.relpath( | ||
member, start=os.path.commonpath(zip_ref.namelist()) | ||
) | ||
print(member_path) | ||
if "/" not in member_path: | ||
continue | ||
name = os.path.basename(member_path) | ||
with zip_ref.open(member_path) as source, open( | ||
os.path.join(dest, name), "wb" | ||
) as target: | ||
target.write(source.read()) | ||
|
||
|
||
class Hub(object): | ||
Assets = { | ||
"english": "bsrnn_ecapa_vox1.tar.gz", | ||
} | ||
# Hard coding of the URL | ||
ModelURLs = { | ||
"bsrnn_ecapa_vox1.tar.gz": "https://www.modelscope.cn/datasets/wenet/wesep_pretrained_models/resolve/master/bsrnn_ecapa_vox1.tar.gz", | ||
} | ||
|
||
def __init__(self) -> None: | ||
pass | ||
|
||
@staticmethod | ||
def get_model(lang: str) -> str: | ||
if lang not in Hub.Assets.keys(): | ||
print("ERROR: Unsupported lang {} !!!".format(lang)) | ||
sys.exit(1) | ||
# model = Hub.Assets[lang] | ||
model_name = Hub.Assets[lang] | ||
model_dir = os.path.join(Path.home(), ".wesep", lang) | ||
if not os.path.exists(model_dir): | ||
os.makedirs(model_dir) | ||
if set(["avg_model.pt", "config.yaml"]).issubset( | ||
set(os.listdir(model_dir)) | ||
): | ||
return model_dir | ||
else: | ||
if model_name in Hub.ModelURLs: | ||
model_url = Hub.ModelURLs[model_name] | ||
download(model_url, model_dir) | ||
return model_dir | ||
else: | ||
print(f"ERROR: No URL found for model {model_name}") | ||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import argparse | ||
|
||
|
||
def get_args(): | ||
parser = argparse.ArgumentParser(description="") | ||
parser.add_argument( | ||
"-t", | ||
"--task", | ||
choices=[ | ||
"extraction", | ||
], | ||
default="extraction", | ||
help="task type", | ||
) | ||
parser.add_argument( | ||
"-l", | ||
"--language", | ||
choices=[ | ||
# "chinese", | ||
"english", | ||
], | ||
default="english", | ||
help="language type", | ||
) | ||
parser.add_argument( | ||
"--bsrnn", | ||
action="store_true", | ||
help="whether to use the bsrnn model", | ||
) | ||
parser.add_argument( | ||
"-p", "--pretrain", type=str, default="", help="model directory" | ||
) | ||
parser.add_argument( | ||
"--device", | ||
type=str, | ||
default="cpu", | ||
help="device type (most commonly cpu or cuda," | ||
"but also potentially mps, xpu, xla or meta)" | ||
"and optional device ordinal for the device type.", | ||
) | ||
parser.add_argument("--audio_file", help="mixture's audio file") | ||
parser.add_argument("--audio_file2", help="enroll's audio file") | ||
parser.add_argument( | ||
"--resample_rate", type=int, default=16000, help="resampling rate" | ||
) | ||
parser.add_argument( | ||
"--vad", action="store_true", help="whether to do VAD or not" | ||
) | ||
parser.add_argument( | ||
"--output_file", | ||
default='./extracted_speech.wav', | ||
help="extracted speech saved in .wav" | ||
) | ||
args = parser.parse_args() | ||
return args |