Commit d032bfa: add support for large-v3 and distil-whisper
shashikg committed Dec 19, 2023 (1 parent: 88eec90)
Showing 6 changed files with 318 additions and 14 deletions.
README.md (3 additions, 2 deletions)
@@ -1,6 +1,6 @@
# WhisperS2T ⚡

-WhisperS2T is an optimized lightning-fast speech-to-text pipeline tailored for the whisper model! It's designed to be exceptionally fast, boasting a 1.5X speed improvement over WhisperX and a 2X speed boost compared to HuggingFace Pipeline with FlashAttention 2 (Insanely Fast Whisper). Moreover, it includes several heuristics to enhance transcription accuracy.
+WhisperS2T is an optimized lightning-fast speech-to-text pipeline tailored for the whisper model! It's designed to be exceptionally fast, boasting a 1.5X speed improvement over WhisperX and a 2X speed boost compared to HuggingFace Pipeline with FlashAttention 2 (Insanely Fast Whisper). Moreover, it includes several heuristics to enhance transcription accuracy.

[**Whisper**](https://github.com/openai/whisper) is a general-purpose speech recognition model developed by OpenAI. It is trained on a large dataset of diverse audio and is also a multitasking model that can perform multilingual speech recognition, speech translation, and language identification.

@@ -10,7 +10,8 @@ Stay tuned for a technical report comparing WhisperS2T against other whisper pipelines.

![A30 Benchmark](files/benchmarks.png)

-**NOTE:** I ran all the benchmarks with `without_timestamps` parameter as `True`. Setting `without_timestamps` as `False` may improve the WER of HuggingFace pipiline at the expense of additional inference time.
+**NOTE:** I conducted all the benchmarks using the `without_timestamps` parameter set as `True`. Adjusting this parameter to `False` may enhance the Word Error Rate (WER) of the HuggingFace pipeline but at the expense of increased inference time. Notably, the improvements in inference speed were achieved solely through a **superior pipeline design**, without any specific optimization made to the backend inference engines (such as CTranslate2, FlashAttention2, etc.). For instance, WhisperS2T (utilizing FlashAttention2) demonstrates significantly superior inference speed compared to the HuggingFace pipeline (also using FlashAttention2), despite both leveraging the same inference engine—HuggingFace whisper model with FlashAttention2. Additionally, there is a noticeable difference in the WER as well.


## Features

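The `without_timestamps` setting mentioned in the NOTE above corresponds to the `return_timestamps` argument of the HuggingFace pipeline. A minimal sketch of the benchmarked configuration, assuming a local `sample.wav` placeholder file (the model repo and generation arguments are taken from the benchmark script added in this commit):

import torch
from transformers import pipeline

# Sketch of the benchmarked HuggingFace setup; "sample.wav" is a placeholder path.
asr = pipeline("automatic-speech-recognition",
               "distil-whisper/distil-large-v2",
               torch_dtype=torch.float16,
               device="cuda")

out = asr("sample.wav",
          generate_kwargs={'num_beams': 1, 'language': 'en'},
          return_timestamps=False)  # True may lower WER at the cost of extra inference time
print(out['text'])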
scripts/benchmark_huggingface_distil.py (159 additions, 0 deletions)
@@ -0,0 +1,159 @@
import argparse
from rich.console import Console
console = Console()

def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--repo_path', default="", type=str)
    parser.add_argument('--batch_size', default=16, type=int)
    parser.add_argument('--flash_attention', default="yes", type=str)
    parser.add_argument('--better_transformer', default="no", type=str)
    parser.add_argument('--eval_mp3', default="no", type=str)
    parser.add_argument('--eval_multilingual', default="no", type=str)
    args = parser.parse_args()
    return args


def run(repo_path, flash_attention=False, better_transformer=False, batch_size=16, eval_mp3=False, eval_multilingual=True):
    import torch
    import time, os
    import pandas as pd
    from transformers import pipeline

    # Load Model >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    model_kwargs = {
        "use_safetensors": True,
        "low_cpu_mem_usage": True
    }

    results_dir = f"{repo_path}/results/HuggingFaceDistilWhisper-bs_{batch_size}"

    if flash_attention:
        results_dir = f"{results_dir}-fa"
        model_kwargs["use_flash_attention_2"] = True

    ASR = pipeline("automatic-speech-recognition",
                   "distil-whisper/distil-large-v2",
                   num_workers=1,
                   torch_dtype=torch.float16,
                   device="cuda",
                   model_kwargs=model_kwargs)

    if (not flash_attention) and better_transformer:
        ASR.model = ASR.model.to_bettertransformer()
        results_dir = f"{results_dir}-bt"

    os.makedirs(results_dir, exist_ok=True)

    # KINCAID46 WAV >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    data = pd.read_csv(f'{repo_path}/data/KINCAID46/manifest_wav.tsv', sep="\t")
    files = [f"{repo_path}/{fn}" for fn in data['audio_path']]

    with console.status("Warming"):
        st = time.time()
        _ = ASR(files,
                batch_size=batch_size,
                chunk_length_s=15,
                generate_kwargs={'num_beams': 1, 'language': 'en'},
                return_timestamps=False)

    print(f"[Warming Time]: {time.time()-st}")

    with console.status("KINCAID WAV"):
        st = time.time()
        outputs = ASR(files,
                      batch_size=batch_size,
                      chunk_length_s=15,
                      generate_kwargs={'num_beams': 1, 'language': 'en'},
                      return_timestamps=False)

    time_kincaid46_wav = time.time()-st
    print(f"[KINCAID WAV Time]: {time_kincaid46_wav}")

    data['pred_text'] = [_['text'].strip() for _ in outputs]
    data.to_csv(f"{results_dir}/KINCAID46_WAV.tsv", sep="\t", index=False)


    # KINCAID46 MP3 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    if eval_mp3:
        data = pd.read_csv(f'{repo_path}/data/KINCAID46/manifest_mp3.tsv', sep="\t")
        files = [f"{repo_path}/{fn}" for fn in data['audio_path']]

        with console.status("KINCAID MP3"):
            st = time.time()
            outputs = ASR(files,
                          batch_size=batch_size,
                          chunk_length_s=30,
                          generate_kwargs={'num_beams': 1, 'language': 'en'},
                          return_timestamps=False)

        time_kincaid46_mp3 = time.time()-st
        print(f"[KINCAID MP3 Time]: {time_kincaid46_mp3}")

        data['pred_text'] = [_['text'].strip() for _ in outputs]
        data.to_csv(f"{results_dir}/KINCAID46_MP3.tsv", sep="\t", index=False)
    else:
        time_kincaid46_mp3 = 0.0

    # MultiLingualLongform >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    if eval_multilingual:
        data = pd.read_csv(f'{repo_path}/data/MultiLingualLongform/manifest.tsv', sep="\t")
        files = [f"{repo_path}/{fn}" for fn in data['audio_path']]
        lang_codes = data['lang_code'].to_list()

        with console.status("MultiLingualLongform"):
            st = time.time()

            curr_files = [files[0]]
            curr_lang = lang_codes[0]
            outputs = []
            for fn, lang in zip(files[1:], lang_codes[1:]):
                if lang != curr_lang:
                    _outputs = ASR(curr_files,
                                   batch_size=batch_size,
                                   chunk_length_s=30,
                                   generate_kwargs={'num_beams': 1, 'language': curr_lang},
                                   return_timestamps=False)
                    outputs.extend(_outputs)

                    curr_files = [fn]
                    curr_lang = lang
                else:
                    curr_files.append(fn)

            _outputs = ASR(curr_files,
                           batch_size=batch_size,
                           chunk_length_s=30,
                           generate_kwargs={'num_beams': 1, 'language': curr_lang},
                           return_timestamps=False)

            outputs.extend(_outputs)

        time_multilingual = time.time()-st
        print(f"[MultiLingualLongform Time]: {time_multilingual}")

        data['pred_text'] = [_['text'].strip() for _ in outputs]
        data.to_csv(f"{results_dir}/MultiLingualLongform.tsv", sep="\t", index=False)
    else:
        time_multilingual = 0.0

    infer_time = [
        ["Dataset", "Time"],
        ["KINCAID46 WAV", time_kincaid46_wav],
        ["KINCAID46 MP3", time_kincaid46_mp3],
        ["MultiLingualLongform", time_multilingual]
    ]

    infer_time = pd.DataFrame(infer_time[1:], columns=infer_time[0])
    infer_time.to_csv(f"{results_dir}/infer_time.tsv", sep="\t", index=False)


if __name__ == '__main__':
    args = parse_arguments()
    eval_mp3 = args.eval_mp3 == "yes"
    eval_multilingual = args.eval_multilingual == "yes"
    flash_attention = args.flash_attention == "yes"
    better_transformer = args.better_transformer == "yes"

    run(args.repo_path, flash_attention=flash_attention, better_transformer=better_transformer, batch_size=args.batch_size, eval_mp3=eval_mp3, eval_multilingual=eval_multilingual)
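Given the `data/` layout assumed above, a typical invocation would be, e.g., `python scripts/benchmark_huggingface_distil.py --repo_path /path/to/WhisperS2T --batch_size 16 --flash_attention yes --eval_multilingual yes` (the path is a placeholder); all boolean-style flags take the literal strings `yes` or `no`.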
scripts/benchmark_whisper_s2t_distil.py (129 additions, 0 deletions)
@@ -0,0 +1,129 @@
import argparse

def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--repo_path', default="", type=str)
    parser.add_argument('--backend', default="HuggingFace", type=str)
    parser.add_argument('--batch_size', default=16, type=int)
    parser.add_argument('--flash_attention', default="yes", type=str)
    parser.add_argument('--better_transformer', default="no", type=str)
    parser.add_argument('--eval_mp3', default="no", type=str)
    parser.add_argument('--eval_multilingual', default="no", type=str)
    args = parser.parse_args()
    return args


def run(repo_path, backend, flash_attention=False, better_transformer=False, batch_size=16, eval_mp3=False, eval_multilingual=True):
    import sys, time, os

    if len(repo_path):
        sys.path.append(repo_path)

    import whisper_s2t
    import pandas as pd

    if backend.lower() in ["huggingface", "hf"]:
        asr_options = {
            "use_flash_attention": flash_attention,
            "use_better_transformer": better_transformer
        }

        if flash_attention:
            results_dir = f"{repo_path}/results/WhisperS2T-{backend}DistilWhisper-bs_{batch_size}-fa"
        elif better_transformer:
            results_dir = f"{repo_path}/results/WhisperS2T-{backend}DistilWhisper-bs_{batch_size}-bt"
        else:
            results_dir = f"{repo_path}/results/WhisperS2T-{backend}DistilWhisper-bs_{batch_size}"
    else:
        asr_options = {}
        results_dir = f"{repo_path}/results/WhisperS2T-{backend}-bs_{batch_size}"

    os.makedirs(results_dir, exist_ok=True)

    model = whisper_s2t.load_model("distil-large-v2", backend=backend, asr_options=asr_options)

    # KINCAID46 WAV >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    data = pd.read_csv(f'{repo_path}/data/KINCAID46/manifest_wav.tsv', sep="\t")
    files = [f"{repo_path}/{fn}" for fn in data['audio_path']]
    lang_codes = len(files)*['en']
    tasks = len(files)*['transcribe']
    initial_prompts = len(files)*[None]

    _ = model.transcribe_with_vad(files,
                                  lang_codes=lang_codes,
                                  tasks=tasks,
                                  initial_prompts=initial_prompts,
                                  batch_size=batch_size)

    st = time.time()
    out = model.transcribe_with_vad(files,
                                    lang_codes=lang_codes,
                                    tasks=tasks,
                                    initial_prompts=initial_prompts,
                                    batch_size=batch_size)
    time_kincaid46_wav = time.time()-st

    data['pred_text'] = [" ".join([_['text'] for _ in _transcript]).strip() for _transcript in out]
    data.to_csv(f"{results_dir}/KINCAID46_WAV.tsv", sep="\t", index=False)


    # KINCAID46 MP3 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    if eval_mp3:
        data = pd.read_csv(f'{repo_path}/data/KINCAID46/manifest_mp3.tsv', sep="\t")
        files = [f"{repo_path}/{fn}" for fn in data['audio_path']]
        lang_codes = len(files)*['en']
        tasks = len(files)*['transcribe']
        initial_prompts = len(files)*[None]

        st = time.time()
        out = model.transcribe_with_vad(files,
                                        lang_codes=lang_codes,
                                        tasks=tasks,
                                        initial_prompts=initial_prompts,
                                        batch_size=batch_size)
        time_kincaid46_mp3 = time.time()-st

        data['pred_text'] = [" ".join([_['text'] for _ in _transcript]).strip() for _transcript in out]
        data.to_csv(f"{results_dir}/KINCAID46_MP3.tsv", sep="\t", index=False)
    else:
        time_kincaid46_mp3 = 0.0


    # MultiLingualLongform
    if eval_multilingual:
        data = pd.read_csv(f'{repo_path}/data/MultiLingualLongform/manifest.tsv', sep="\t")
        files = [f"{repo_path}/{fn}" for fn in data['audio_path']]
        lang_codes = data['lang_code'].to_list()
        tasks = len(files)*['transcribe']
        initial_prompts = len(files)*[None]

        st = time.time()
        out = model.transcribe_with_vad(files,
                                        lang_codes=lang_codes,
                                        tasks=tasks,
                                        initial_prompts=initial_prompts,
                                        batch_size=batch_size)
        time_multilingual = time.time()-st

        data['pred_text'] = [" ".join([_['text'] for _ in _transcript]).strip() for _transcript in out]
        data.to_csv(f"{results_dir}/MultiLingualLongform.tsv", sep="\t", index=False)
    else:
        time_multilingual = 0.0

    infer_time = [
        ["Dataset", "Time"],
        ["KINCAID46 WAV", time_kincaid46_wav],
        ["KINCAID46 MP3", time_kincaid46_mp3],
        ["MultiLingualLongform", time_multilingual]
    ]
    infer_time = pd.DataFrame(infer_time[1:], columns=infer_time[0])
    infer_time.to_csv(f"{results_dir}/infer_time.tsv", sep="\t", index=False)


if __name__ == '__main__':
    args = parse_arguments()
    eval_mp3 = args.eval_mp3 == "yes"
    eval_multilingual = args.eval_multilingual == "yes"
    flash_attention = args.flash_attention == "yes"
    better_transformer = args.better_transformer == "yes"

    run(args.repo_path, args.backend, flash_attention=flash_attention, better_transformer=better_transformer, batch_size=args.batch_size, eval_mp3=eval_mp3, eval_multilingual=eval_multilingual)
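Likewise, the WhisperS2T variant can be run as, e.g., `python scripts/benchmark_whisper_s2t_distil.py --repo_path /path/to/WhisperS2T --backend hf --flash_attention yes`; the `--backend` string is matched case-insensitively against the aliases accepted by `load_model` (`hf`/`huggingface`, `ct2`/`ctranslate2`, `oai`/`openai`), as shown in the `whisper_s2t/__init__.py` diff below.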
whisper_s2t/__init__.py (14 additions, 1 deletion)
@@ -12,12 +12,25 @@ def load_model(model_identifier="large-v2",
               backend='CTranslate2',
               **model_kwargs):

+    if model_identifier in ['large-v3']:
+        model_kwargs['n_mels'] = 128
+    elif (model_identifier in ['distil-large-v2']) and (backend.lower() not in ["huggingface", "hf"]):
+        print("Switching backend to HuggingFace. Distil-Whisper is only supported with the HuggingFace backend.")
+        backend = "huggingface"
+
+        model_kwargs['max_speech_len'] = 15.0
+        model_kwargs['max_text_token_len'] = 128
+
    if backend.lower() in ["ctranslate2", "ct2"]:
        from .backends.ctranslate2.model import WhisperModelCT2 as WhisperModel

    elif backend.lower() in ["huggingface", "hf"]:
        from .backends.huggingface.model import WhisperModelHF as WhisperModel
-        model_identifier = f"openai/whisper-{model_identifier}"
+
+        if 'distil' in model_identifier:
+            model_identifier = f"distil-whisper/{model_identifier}"
+        else:
+            model_identifier = f"openai/whisper-{model_identifier}"

    elif backend.lower() in ["openai", "oai"]:
        from .backends.openai.model import WhisperModelOAI as WhisperModel
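Taken together, the new branches enable the following calls; a minimal usage sketch based only on the signatures visible in this diff (the file path and batch size are illustrative placeholders):

import whisper_s2t

# large-v3 on the default CTranslate2 backend; n_mels=128 is applied internally
model = whisper_s2t.load_model("large-v3")

# distil-large-v2 resolves to the "distil-whisper/distil-large-v2" HF repo;
# if a non-HuggingFace backend is requested, load_model switches it automatically
distil = whisper_s2t.load_model("distil-large-v2", backend="huggingface")

out = distil.transcribe_with_vad(["sample.wav"],
                                 lang_codes=["en"],
                                 tasks=["transcribe"],
                                 initial_prompts=[None],
                                 batch_size=16)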
whisper_s2t/backends/__init__.py (1 addition, 0 deletions)
@@ -55,6 +55,7 @@ def __init__(self,

self.vad_model = vad_model
self.speech_segmenter_options = speech_segmenter_options
self.speech_segmenter_options['max_seg_len'] = self.max_speech_len

# Tokenizer
if tokenizer is None:
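This one-liner keeps the VAD segmenter's maximum segment length in sync with the model's `max_speech_len`, which `load_model` above sets to 15.0 s for distil-large-v2, so, presumably, distil models are never fed speech segments longer than the 15-second windows they were trained on.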
whisper_s2t/backends/ctranslate2/hf_utils.py (12 additions, 11 deletions)
@@ -14,17 +14,18 @@


_MODELS = {
-    "tiny.en": "guillaumekln/faster-whisper-tiny.en",
-    "tiny": "guillaumekln/faster-whisper-tiny",
-    "base.en": "guillaumekln/faster-whisper-base.en",
-    "base": "guillaumekln/faster-whisper-base",
-    "small.en": "guillaumekln/faster-whisper-small.en",
-    "small": "guillaumekln/faster-whisper-small",
-    "medium.en": "guillaumekln/faster-whisper-medium.en",
-    "medium": "guillaumekln/faster-whisper-medium",
-    "large-v1": "guillaumekln/faster-whisper-large-v1",
-    "large-v2": "guillaumekln/faster-whisper-large-v2",
-    "large": "guillaumekln/faster-whisper-large-v2",
+    "tiny.en": "Systran/faster-whisper-tiny.en",
+    "tiny": "Systran/faster-whisper-tiny",
+    "base.en": "Systran/faster-whisper-base.en",
+    "base": "Systran/faster-whisper-base",
+    "small.en": "Systran/faster-whisper-small.en",
+    "small": "Systran/faster-whisper-small",
+    "medium.en": "Systran/faster-whisper-medium.en",
+    "medium": "Systran/faster-whisper-medium",
+    "large-v1": "Systran/faster-whisper-large-v1",
+    "large-v2": "Systran/faster-whisper-large-v2",
+    "large-v3": "Systran/faster-whisper-large-v3",
+    "large": "Systran/faster-whisper-large-v3",
}


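Note that besides moving all repos from `guillaumekln` to `Systran`, the `large` alias now resolves to large-v3 rather than large-v2. Assuming the mapping is consumed as a plain dictionary lookup (the downstream helper is outside this diff), the change behaves as:

repo_id = _MODELS["large"]   # "Systran/faster-whisper-large-v3" after this commit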
