Skip to content

Commit

Permalink
support command line mode
Browse files Browse the repository at this point in the history
  • Loading branch information
KyleZhang1118 committed Oct 11, 2024
1 parent 7bf311f commit 9203816
Show file tree
Hide file tree
Showing 7 changed files with 382 additions and 0 deletions.
20 changes: 20 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from setuptools import setup, find_packages

# Runtime dependencies of the package and of its `wesep` command-line tool.
requirements = [
    "tqdm",
    "kaldiio",
    "torch>=1.12.0",
    "torchaudio>=0.12.0",
    "silero-vad",
    # Imported directly by wesep/cli/extractor.py (numpy, yaml, soundfile)
    # and wesep/cli/hub.py (requests); without them the console script
    # fails at import time on a clean install.
    "numpy",
    "pyyaml",
    "soundfile",
    "requests",
]

setup(
    name="wesep",
    install_requires=requirements,
    packages=find_packages(),
    entry_points={
        "console_scripts": [
            # Installs a `wesep` executable that calls the CLI entry point.
            "wesep = wesep.cli.extractor:main",
        ],
    },
)
2 changes: 2 additions & 0 deletions wesep/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from wesep.cli.extractor import load_model # noqa
from wesep.cli.extractor import load_model_local # noqa
2 changes: 2 additions & 0 deletions wesep/bin/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def infer(config="confs/conf.yaml", **kwargs):
else:
sample_rate = 8000

if 'spk_model_init' in configs['model_args']['tse_model']:
configs['model_args']['tse_model']['spk_model_init'] = False
model = get_model(
configs["model"]["tse_model"])(**configs["model_args"]["tse_model"])
model_path = os.path.join(configs["checkpoint"])
Expand Down
Empty file added wesep/cli/__init__.py
Empty file.
182 changes: 182 additions & 0 deletions wesep/cli/extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
import os
import sys

import numpy as np
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import yaml
from tqdm import tqdm
import soundfile

from wesep.cli.hub import Hub
from wesep.cli.utils import get_args
from wesep.models import get_model
from wesep.utils.checkpoint import load_pretrained_model
from wesep.utils.utils import set_seed


class Extractor:
    """Target speech extraction wrapper around a pretrained TSE model.

    The model is described by ``config.yaml`` and its weights are read from
    ``avg_model.pt``, both located in the model directory. Runtime options
    (device, VAD, resampling rate, waveform normalisation) are changed with
    the ``set_*`` methods.
    """

    def __init__(self, model_dir: str):
        """Build the model found in ``model_dir`` and load its checkpoint.

        Args:
            model_dir: Directory containing ``config.yaml`` and
                ``avg_model.pt``.
        """
        set_seed()

        config_path = os.path.join(model_dir, "config.yaml")
        model_path = os.path.join(model_dir, "avg_model.pt")
        with open(config_path, "r") as fin:
            # NOTE(review): yaml.load is not meant for untrusted input; only
            # point this at model directories you trust.
            configs = yaml.load(fin, Loader=yaml.FullLoader)

        tse_args = configs["model_args"]["tse_model"]
        if "spk_model_init" in tse_args:
            # Presumably skips a separate speaker-model initialisation since
            # the full checkpoint is loaded right below -- TODO confirm.
            tse_args["spk_model_init"] = False

        self.model = get_model(configs["model"]["tse_model"])(**tse_args)
        load_pretrained_model(self.model, model_path)
        self.model.eval()
        self.vad = load_silero_vad()
        self.table = {}
        self.resample_rate = configs["dataset_args"].get("resample_rate", 16000)
        self.apply_vad = False
        self.device = torch.device("cpu")
        self.wavform_norm = False

        self.speaker_feat = tse_args.get("spk_feat", False)
        self.joint_training = tse_args.get("joint_training", False)

    def set_wavform_norm(self, wavform_norm: bool):
        """Set the ``normalize`` flag passed to ``torchaudio.load``."""
        self.wavform_norm = wavform_norm

    def set_resample_rate(self, resample_rate: int):
        """Set the sample rate both inputs are resampled to before inference."""
        self.resample_rate = resample_rate

    def set_vad(self, apply_vad: bool):
        """Enable or disable VAD-based silence removal on the enrollment."""
        self.apply_vad = apply_vad

    def set_device(self, device: str):
        """Move the model to ``device`` and remember it for the inputs."""
        self.device = torch.device(device)
        self.model = self.model.to(self.device)

    def compute_fbank(
        self,
        wavform,
        sample_rate=16000,
        num_mel_bins=80,
        frame_length=25,
        frame_shift=10,
        cmn=True,
    ):
        """Compute fbank features with ``torchaudio.compliance.kaldi.fbank``.

        Args:
            wavform: Waveform tensor handed to ``kaldi.fbank``.
            sample_rate: Passed as ``sample_frequency``.
            num_mel_bins: Number of mel bins.
            frame_length: Frame length forwarded to ``kaldi.fbank``.
            frame_shift: Frame shift forwarded to ``kaldi.fbank``.
            cmn: When true, subtract the mean taken over axis 0.

        Returns:
            The (optionally mean-normalised) feature tensor.
        """
        feat = kaldi.fbank(
            wavform,
            num_mel_bins=num_mel_bins,
            frame_length=frame_length,
            frame_shift=frame_shift,
            sample_frequency=sample_rate,
        )
        if cmn:
            feat = feat - torch.mean(feat, 0)
        return feat

    def extract_speech(self, audio_path: str, audio_path_2: str):
        """Extract the enrolled speaker from a mixture, both given as files.

        Args:
            audio_path: Path of the mixture audio.
            audio_path_2: Path of the enrollment audio.

        Returns:
            Whatever ``extract_speech_from_pcm`` returns for the loaded audio.
        """
        pcm_mix, sample_rate_mix = torchaudio.load(
            audio_path, normalize=self.wavform_norm
        )
        pcm_enroll, sample_rate_enroll = torchaudio.load(
            audio_path_2, normalize=self.wavform_norm
        )
        return self.extract_speech_from_pcm(
            pcm_mix, sample_rate_mix, pcm_enroll, sample_rate_enroll
        )

    def _speech_only_enroll(self, pcm_enroll, sample_rate_enroll):
        """Return channel 0 of the enrollment with silence removed, or None.

        VAD runs on a mono, 16 kHz copy; the detected segments are then cut
        from the original-rate enrollment. ``None`` means no speech was found.
        """
        # TODO(Binbin Zhang): Refine the segments logic, here we just
        # suppose there is only silence at the start/end of the speech
        vad_sample_rate = 16000
        wav = pcm_enroll
        if wav.size(0) > 1:
            wav = wav.mean(dim=0, keepdim=True)
        if sample_rate_enroll != vad_sample_rate:
            transform = torchaudio.transforms.Resample(
                orig_freq=sample_rate_enroll, new_freq=vad_sample_rate
            )
            wav = transform(wav)

        segments = get_speech_timestamps(wav, self.vad, return_seconds=True)
        if not segments:  # all silence, no speech
            return None

        # Collect the slices first and concatenate once instead of growing a
        # tensor inside the loop.
        pieces = [
            pcm_enroll[
                0,
                int(segment["start"] * sample_rate_enroll):
                int(segment["end"] * sample_rate_enroll),
            ]
            for segment in segments
        ]
        return torch.cat(pieces, 0).unsqueeze(0)

    def extract_speech_from_pcm(self, pcm_mix: torch.Tensor, sample_rate_mix: int, pcm_enroll: torch.Tensor, sample_rate_enroll: int,):
        """Extract the enrolled speaker from an in-memory mixture.

        Args:
            pcm_mix: Mixture waveform tensor.
            sample_rate_mix: Sample rate of ``pcm_mix``.
            pcm_enroll: Enrollment waveform tensor.
            sample_rate_enroll: Sample rate of ``pcm_enroll``.

        Returns:
            The model output moved to the CPU (the first element when the
            model returns a list/tuple), or ``None`` when VAD finds no speech
            in the enrollment or the model config has no ``joint_training``.
        """
        if self.apply_vad:
            # Only do vad on the enrollment
            pcm_enroll = self._speech_only_enroll(pcm_enroll, sample_rate_enroll)
            if pcm_enroll is None:
                return None

        pcm_mix = pcm_mix.to(torch.float)
        if sample_rate_mix != self.resample_rate:
            pcm_mix = torchaudio.transforms.Resample(
                orig_freq=sample_rate_mix, new_freq=self.resample_rate
            )(pcm_mix)
        pcm_enroll = pcm_enroll.to(torch.float)
        if sample_rate_enroll != self.resample_rate:
            pcm_enroll = torchaudio.transforms.Resample(
                orig_freq=sample_rate_enroll, new_freq=self.resample_rate
            )(pcm_enroll)

        if not self.joint_training:
            return None

        if self.speaker_feat:
            feats = self.compute_fbank(
                pcm_enroll, sample_rate=self.resample_rate, cmn=True
            )
            feats = feats.unsqueeze(0)
        else:
            feats = pcm_enroll

        # set_device() moves the model, so both inputs must follow it;
        # otherwise any non-CPU device fails with a device mismatch.
        pcm_mix = pcm_mix.to(self.device)
        feats = feats.to(self.device)

        with torch.no_grad():
            outputs = self.model(pcm_mix, feats)
        outputs = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
        target_speech = outputs.to(torch.device("cpu"))
        return target_speech


def load_model(language: str) -> Extractor:
    """Return an ``Extractor`` for the hub model registered under ``language``.

    The model directory is resolved by ``Hub.get_model``.
    """
    return Extractor(Hub.get_model(language))


def load_model_local(model_dir: str) -> Extractor:
    """Return an ``Extractor`` built from the local directory ``model_dir``."""
    extractor = Extractor(model_dir)
    return extractor


def main():
    """Command-line entry point: extract target speech and save it to a file.

    Reads the options from ``get_args``, builds the extractor either from a
    local directory (``--pretrain``) or from the hub, runs the extraction and
    writes the result with ``soundfile``. Exits with status -1 when the
    requested task is not supported.
    """
    args = get_args()

    if args.pretrain != "":
        model = load_model_local(args.pretrain)
    else:
        hub_key = "bsrnn" if args.bsrnn else args.language
        model = load_model(hub_key)
        model.set_wavform_norm(True)

    model.set_resample_rate(args.resample_rate)
    model.set_vad(args.vad)
    model.set_device(args.device)

    if args.task != "extraction":
        print("Unsupported task {}".format(args.task))
        sys.exit(-1)

    speech = model.extract_speech(args.audio_file, args.audio_file2)
    if speech is None:
        print("Fails to extract the target speech")
        return

    soundfile.write(args.output_file, speech[0], args.resample_rate)
    print("Succeed, see {}".format(args.output_file))


if __name__ == "__main__":
    main()
121 changes: 121 additions & 0 deletions wesep/cli/hub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright (c) 2022 Mddct(hamddct@gmail.com)
# 2023 Binbin Zhang(binbzha@qq.com)
# 2024 Shuai Wang(wsstriving@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import requests
import sys
from pathlib import Path
import tarfile
import zipfile
from urllib.request import urlretrieve

import tqdm


def download(url: str, dest: str, only_child=True):
    """Download ``url`` into ``dest`` and unpack it when it is an archive.

    The file is stored in ``dest`` under the last path component of the URL
    (query string removed). ``.tar``/``.tar.gz`` and ``.zip`` files are then
    unpacked into ``dest``.

    Args:
        url: Location of the file to fetch.
        dest: Existing directory that receives the download.
        only_child: When true (default), every file stored inside a directory
            of the archive is written directly into ``dest`` under its base
            name and top-level entries are ignored; when false the archive is
            extracted as-is with ``extractall``.

    Raises:
        AssertionError: If ``dest`` does not exist.
    """
    assert os.path.exists(dest)
    print("Downloading {} to {}".format(url, dest))

    def progress_hook(t):
        """Adapt a tqdm bar to the ``reporthook`` protocol of urlretrieve."""
        last_b = [0]

        def update_to(b=1, bsize=1, tsize=None):
            if tsize not in (None, -1):
                t.total = tsize
            # urlretrieve reports the cumulative block count; tqdm wants the
            # increment since the previous call.
            displayed = t.update((b - last_b[0]) * bsize)
            last_b[0] = b
            return displayed

        return update_to

    # *.tar.gz
    name = url.split("?")[0].split("/")[-1]
    file_path = os.path.join(dest, name)
    with tqdm.tqdm(
        unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=(name)
    ) as t:
        urlretrieve(
            url, filename=file_path, reporthook=progress_hook(t), data=None
        )
        t.total = t.n

    if name.endswith((".tar.gz", ".tar")):
        with tarfile.open(file_path) as f:
            if not only_child:
                # NOTE(review): extractall trusts the member paths stored in
                # the archive; only use it with archives you trust.
                f.extractall(dest)
            else:
                for tarinfo in f:
                    if "/" not in tarinfo.name:
                        continue
                    fileobj = f.extractfile(tarinfo)
                    if fileobj is None:
                        # Directories and other non-file members carry no
                        # data to copy.
                        continue
                    member_name = os.path.basename(tarinfo.name)
                    with fileobj, open(
                        os.path.join(dest, member_name), "wb"
                    ) as writer:
                        writer.write(fileobj.read())

    elif name.endswith(".zip"):
        with zipfile.ZipFile(file_path, "r") as zip_ref:
            if not only_child:
                # NOTE(review): see the extractall remark in the tar branch.
                zip_ref.extractall(dest)
            else:
                # Same flattening rule as the tar branch. Members are opened
                # through their ZipInfo so the lookup always uses the real
                # archive name.
                for info in zip_ref.infolist():
                    if info.is_dir() or "/" not in info.filename:
                        continue
                    member_name = os.path.basename(info.filename)
                    with zip_ref.open(info) as source, open(
                        os.path.join(dest, member_name), "wb"
                    ) as target:
                        target.write(source.read())


class Hub(object):
    """Registry that maps a language to a locally cached pretrained model."""

    # language -> archive file name
    Assets = {"english": "bsrnn_ecapa_vox1.tar.gz"}
    # archive file name -> download URL (hard-coded)
    ModelURLs = {
        "bsrnn_ecapa_vox1.tar.gz": "https://www.modelscope.cn/datasets/wenet/wesep_pretrained_models/resolve/master/bsrnn_ecapa_vox1.tar.gz",
    }

    def __init__(self) -> None:
        pass

    @staticmethod
    def get_model(lang: str) -> str:
        """Return the directory holding the model registered for ``lang``.

        The directory is ``~/.wesep/<lang>`` and is created when missing. If
        it does not already contain ``avg_model.pt`` and ``config.yaml`` the
        archive is downloaded into it first.

        Exits the process with status 1 when ``lang`` is not registered, and
        returns ``None`` (after printing an error) when no URL is known for
        the registered archive.
        """
        if lang not in Hub.Assets:
            print("ERROR: Unsupported lang {} !!!".format(lang))
            sys.exit(1)

        archive = Hub.Assets[lang]
        cache_dir = os.path.join(Path.home(), ".wesep", lang)
        if not os.path.exists(cache_dir):
            os.makedirs(cache_dir)

        required = {"avg_model.pt", "config.yaml"}
        if required <= set(os.listdir(cache_dir)):
            # Already downloaded and unpacked.
            return cache_dir

        if archive not in Hub.ModelURLs:
            print(f"ERROR: No URL found for model {archive}")
            return None

        download(Hub.ModelURLs[archive], cache_dir)
        return cache_dir
55 changes: 55 additions & 0 deletions wesep/cli/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import argparse


def get_args(argv=None):
    """Parse the command-line options of the ``wesep`` tool.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read them from ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` holding the parsed options.
    """
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        "-t",
        "--task",
        choices=[
            "extraction",
        ],
        default="extraction",
        help="task type",
    )
    parser.add_argument(
        "-l",
        "--language",
        choices=[
            # "chinese",
            "english",
        ],
        default="english",
        help="language type",
    )
    parser.add_argument(
        "--bsrnn",
        action="store_true",
        help="whether to use the bsrnn model",
    )
    parser.add_argument(
        "-p", "--pretrain", type=str, default="", help="model directory"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        # Adjacent literals are concatenated verbatim, so each piece needs
        # its own trailing space.
        help="device type (most commonly cpu or cuda, "
        "but also potentially mps, xpu, xla or meta) "
        "and optional device ordinal for the device type.",
    )
    parser.add_argument("--audio_file", help="mixture's audio file")
    parser.add_argument("--audio_file2", help="enroll's audio file")
    parser.add_argument(
        "--resample_rate", type=int, default=16000, help="resampling rate"
    )
    parser.add_argument(
        "--vad", action="store_true", help="whether to do VAD or not"
    )
    parser.add_argument(
        "--output_file",
        default='./extracted_speech.wav',
        help="extracted speech saved in .wav"
    )
    args = parser.parse_args(argv)
    return args

0 comments on commit 9203816

Please sign in to comment.