support command line mode #10

Merged: 5 commits, Oct 12, 2024
Changes from all commits
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
@@ -80,7 +80,7 @@ jobs:
- name: Run cpplint
run: |
set -eux
pip install cpplint
pip install cpplint==1.6.1
cpplint --version
cpplint --recursive .
if [ $? != 0 ]; then exit 1; fi
20 changes: 20 additions & 0 deletions setup.py
@@ -0,0 +1,20 @@
from setuptools import setup, find_packages

requirements = [
"tqdm",
"kaldiio",
"torch>=1.12.0",
"torchaudio>=0.12.0",
"silero-vad",
]

setup(
name="wesep",
install_requires=requirements,
packages=find_packages(),
entry_points={
"console_scripts": [
"wesep = wesep.cli.extractor:main",
],
},
)
2 changes: 2 additions & 0 deletions wesep/__init__.py
@@ -0,0 +1,2 @@
from wesep.cli.extractor import load_model # noqa
from wesep.cli.extractor import load_model_local # noqa
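These two re-exports make the extractor importable at package level, so downstream code only needs "from wesep import load_model". A minimal usage sketch, assuming the package has been installed with the setup.py above and that the pretrained English model can be fetched through the hub (wesep/cli/hub.py, further down in this diff); the audio paths are placeholders, not part of this PR:

import soundfile
from wesep import load_model

# "english" resolves to the bsrnn_ecapa_vox1 checkpoint via wesep/cli/hub.py
model = load_model("english")
model.set_wavform_norm(True)    # the CLI enables this for hub-downloaded models
model.set_resample_rate(16000)
model.set_vad(True)             # silero-vad is applied to the enrollment audio only
model.set_device("cpu")

# mixture.wav, enrollment.wav and extracted.wav are placeholder file names
speech = model.extract_speech("mixture.wav", "enrollment.wav")
if speech is not None:
    soundfile.write("extracted.wav", speech[0], 16000)

This mirrors the extraction branch of main() in wesep/cli/extractor.py below.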
2 changes: 2 additions & 0 deletions wesep/bin/infer.py
@@ -47,6 +47,8 @@ def infer(config="confs/conf.yaml", **kwargs):
else:
sample_rate = 8000

if 'spk_model_init' in configs['model_args']['tse_model']:
configs['model_args']['tse_model']['spk_model_init'] = False
model = get_model(
configs["model"]["tse_model"])(**configs["model_args"]["tse_model"])
model_path = os.path.join(configs["checkpoint"])
Empty file added wesep/cli/__init__.py
187 changes: 187 additions & 0 deletions wesep/cli/extractor.py
@@ -0,0 +1,187 @@
import os
import sys

from silero_vad import load_silero_vad, get_speech_timestamps
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import yaml
import soundfile

from wesep.cli.hub import Hub
from wesep.cli.utils import get_args
from wesep.models import get_model
from wesep.utils.checkpoint import load_pretrained_model
from wesep.utils.utils import set_seed


class Extractor:

def __init__(self, model_dir: str):
set_seed()

config_path = os.path.join(model_dir, "config.yaml")
model_path = os.path.join(model_dir, "avg_model.pt")
with open(config_path, "r") as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)
if 'spk_model_init' in configs['model_args']['tse_model']:
configs['model_args']['tse_model']['spk_model_init'] = False
self.model = get_model(configs["model"]["tse_model"])(
**configs["model_args"]["tse_model"]
)
load_pretrained_model(self.model, model_path)
self.model.eval()
self.vad = load_silero_vad()
self.table = {}
self.resample_rate = configs["dataset_args"].get("resample_rate", 16000)
self.apply_vad = False
self.device = torch.device("cpu")
self.wavform_norm = False

self.speaker_feat = configs["model_args"]["tse_model"].get("spk_feat", False)
self.joint_training = configs["model_args"]["tse_model"].get(
"joint_training", False
)

def set_wavform_norm(self, wavform_norm: bool):
self.wavform_norm = wavform_norm

def set_resample_rate(self, resample_rate: int):
self.resample_rate = resample_rate

def set_vad(self, apply_vad: bool):
self.apply_vad = apply_vad

def set_device(self, device: str):
self.device = torch.device(device)
self.model = self.model.to(self.device)

def compute_fbank(
self,
wavform,
sample_rate=16000,
num_mel_bins=80,
frame_length=25,
frame_shift=10,
cmn=True,
):
feat = kaldi.fbank(
wavform,
num_mel_bins=num_mel_bins,
frame_length=frame_length,
frame_shift=frame_shift,
sample_frequency=sample_rate,
)
if cmn:
feat = feat - torch.mean(feat, 0)
return feat

def extract_speech(self, audio_path: str, audio_path_2: str):
pcm_mix, sample_rate_mix = torchaudio.load(
audio_path, normalize=self.wavform_norm
)
pcm_enroll, sample_rate_enroll = torchaudio.load(
audio_path_2, normalize=self.wavform_norm
)
return self.extract_speech_from_pcm(pcm_mix,
sample_rate_mix,
pcm_enroll,
sample_rate_enroll)

def extract_speech_from_pcm(self,
pcm_mix: torch.Tensor,
sample_rate_mix: int,
pcm_enroll: torch.Tensor,
sample_rate_enroll: int):
if self.apply_vad:
# TODO(Binbin Zhang): Refine the segments logic, here we just
# suppose there is only silence at the start/end of the speech
# Only do vad on the enrollment
vad_sample_rate = 16000
wav = pcm_enroll
if wav.size(0) > 1:
wav = wav.mean(dim=0, keepdim=True)
if sample_rate_enroll != vad_sample_rate:
transform = torchaudio.transforms.Resample(
orig_freq=sample_rate_enroll, new_freq=vad_sample_rate
)
wav = transform(wav)

segments = get_speech_timestamps(wav, self.vad, return_seconds=True)
pcmTotal = torch.Tensor()
if len(segments) > 0: # remove all the silence
for segment in segments:
start = int(segment["start"] * sample_rate_enroll)
end = int(segment["end"] * sample_rate_enroll)
pcmTemp = pcm_enroll[0, start:end]
pcmTotal = torch.cat([pcmTotal, pcmTemp], 0)
pcm_enroll = pcmTotal.unsqueeze(0)
            else: # all silence, no speech
return None

pcm_mix = pcm_mix.to(torch.float)
if sample_rate_mix != self.resample_rate:
pcm_mix = torchaudio.transforms.Resample(
orig_freq=sample_rate_mix, new_freq=self.resample_rate
)(pcm_mix)
pcm_enroll = pcm_enroll.to(torch.float)
if sample_rate_enroll != self.resample_rate:
pcm_enroll = torchaudio.transforms.Resample(
orig_freq=sample_rate_enroll, new_freq=self.resample_rate
)(pcm_enroll)

if self.joint_training:
if self.speaker_feat:
feats = self.compute_fbank(
pcm_enroll, sample_rate=self.resample_rate, cmn=True
)
feats = feats.unsqueeze(0)
feats = feats.to(self.device)
else:
feats = pcm_enroll

with torch.no_grad():
outputs = self.model(pcm_mix, feats)
outputs = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
target_speech = outputs.to(torch.device("cpu"))
return target_speech
else:
return None


def load_model(language: str) -> Extractor:
model_path = Hub.get_model(language)
return Extractor(model_path)


def load_model_local(model_dir: str) -> Extractor:
return Extractor(model_dir)


def main():
args = get_args()
if args.pretrain == "":
if args.bsrnn:
model = load_model("bsrnn")
else:
model = load_model(args.language)
model.set_wavform_norm(True)
else:
model = load_model_local(args.pretrain)
model.set_resample_rate(args.resample_rate)
model.set_vad(args.vad)
model.set_device(args.device)
if args.task == "extraction":
speech = model.extract_speech(args.audio_file, args.audio_file2)
if speech is not None:
soundfile.write(args.output_file, speech[0], args.resample_rate)
            print("Succeeded, see {}".format(args.output_file))
else:
            print("Failed to extract the target speech")
else:
print("Unsupported task {}".format(args.task))
sys.exit(-1)


if __name__ == "__main__":
main()
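Because extract_speech_from_pcm accepts raw tensors, the extractor can presumably also be fed audio that is already in memory. A rough sketch under stated assumptions: the model directory is a placeholder, the random tensors only stand in for real (channels, samples) waveforms, and whether a target signal comes back still depends on the joint_training flag in the loaded config:

import torch
from wesep import load_model_local

# Placeholder directory; it must contain config.yaml and avg_model.pt,
# which is what Extractor.__init__ reads.
model = load_model_local("exp/my_model_dir")

sr = 16000
mixture = torch.randn(1, sr * 4)     # 4 s of dummy mixture audio, shape (1, T)
enrollment = torch.randn(1, sr * 3)  # 3 s of dummy enrollment audio

speech = model.extract_speech_from_pcm(mixture, sr, enrollment, sr)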
123 changes: 123 additions & 0 deletions wesep/cli/hub.py
@@ -0,0 +1,123 @@
# Copyright (c) 2022 Mddct(hamddct@gmail.com)
# 2023 Binbin Zhang(binbzha@qq.com)
# 2024 Shuai Wang(wsstriving@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
from pathlib import Path
import tarfile
import zipfile
from urllib.request import urlretrieve

import tqdm


def download(url: str, dest: str, only_child=True):
"""download from url to dest"""
assert os.path.exists(dest)
print("Downloading {} to {}".format(url, dest))

def progress_hook(t):
last_b = [0]

def update_to(b=1, bsize=1, tsize=None):
if tsize not in (None, -1):
t.total = tsize
displayed = t.update((b - last_b[0]) * bsize)
last_b[0] = b
return displayed

return update_to

# *.tar.gz
name = url.split("?")[0].split("/")[-1]
file_path = os.path.join(dest, name)
with tqdm.tqdm(
unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=(name)
) as t:
urlretrieve(
url, filename=file_path, reporthook=progress_hook(t), data=None
)
t.total = t.n

if name.endswith((".tar.gz", ".tar")):
with tarfile.open(file_path) as f:
if not only_child:
f.extractall(dest)
else:
for tarinfo in f:
if "/" not in tarinfo.name:
continue
name = os.path.basename(tarinfo.name)
fileobj = f.extractfile(tarinfo)
with open(os.path.join(dest, name), "wb") as writer:
writer.write(fileobj.read())

elif name.endswith(".zip"):
with zipfile.ZipFile(file_path, "r") as zip_ref:
if not only_child:
zip_ref.extractall(dest)
else:
for member in zip_ref.namelist():
member_path = os.path.relpath(
member, start=os.path.commonpath(zip_ref.namelist())
)
print(member_path)
if "/" not in member_path:
continue
name = os.path.basename(member_path)
with zip_ref.open(member_path) as source, open(
os.path.join(dest, name), "wb"
) as target:
target.write(source.read())


class Hub(object):
Assets = {
"english": "bsrnn_ecapa_vox1.tar.gz",
}
    # Hard-coded model URLs
ModelURLs = {
"bsrnn_ecapa_vox1.tar.gz": (
"https://www.modelscope.cn/datasets/wenet/wesep_pretrained_models/"
"resolve/master/bsrnn_ecapa_vox1.tar.gz"
),
}

def __init__(self) -> None:
pass

@staticmethod
def get_model(lang: str) -> str:
if lang not in Hub.Assets.keys():
print("ERROR: Unsupported lang {} !!!".format(lang))
sys.exit(1)
# model = Hub.Assets[lang]
model_name = Hub.Assets[lang]
model_dir = os.path.join(Path.home(), ".wesep", lang)
if not os.path.exists(model_dir):
os.makedirs(model_dir)
if set(["avg_model.pt", "config.yaml"]).issubset(
set(os.listdir(model_dir))
):
return model_dir
else:
if model_name in Hub.ModelURLs:
model_url = Hub.ModelURLs[model_name]
download(model_url, model_dir)
return model_dir
else:
print(f"ERROR: No URL found for model {model_name}")
return None
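For completeness, a small sketch of how this cache is expected to behave, assuming network access to the modelscope URL above; the printed path is only illustrative:

from wesep.cli.hub import Hub

# The first call downloads and unpacks bsrnn_ecapa_vox1.tar.gz into ~/.wesep/english;
# once avg_model.pt and config.yaml are present, later calls return the cached
# directory without downloading again.
model_dir = Hub.get_model("english")
print(model_dir)  # e.g. /home/<user>/.wesep/english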