diff --git a/README.md b/README.md
index 2acafc4..1b356c2 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
 # WhisperS2T ⚡
 
-WhisperS2T is an optimized lightning-fast speech-to-text pipeline tailored for the whisper model! It's designed to be exceptionally fast, boasting a 1.5X speed improvement over WhisperX and a 2X speed boost compared to HuggingFace Pipeline with FlashAttention 2 (Insanely Fast Whisper). Moreover, it includes several heuristics to enhance transcription accuracy.
+WhisperS2T is an optimized lightning-fast speech-to-text pipeline tailored for the whisper model! It's designed to be exceptionally fast, boasting a 1.5X speed improvement over WhisperX and a 2X speed boost compared to the HuggingFace pipeline with FlashAttention 2 (Insanely Fast Whisper). Moreover, it includes several heuristics to enhance transcription accuracy.
 
 [**Whisper**](https://github.com/openai/whisper) is a general-purpose speech recognition model developed by OpenAI. It is trained on a large dataset of diverse audio and is also a multitasking model that can perform multilingual speech recognition, speech translation, and language identification.
 
@@ -10,7 +10,8 @@
 Stay tuned for a technical report comparing WhisperS2T against other whisper pipelines.
 
 ![A30 Benchmark](files/benchmarks.png)
 
-**NOTE:** I ran all the benchmarks with `without_timestamps` parameter as `True`. Setting `without_timestamps` as `False` may improve the WER of HuggingFace pipiline at the expense of additional inference time.
+**NOTE:** I conducted all the benchmarks with the `without_timestamps` parameter set to `True`. Setting this parameter to `False` may improve the Word Error Rate (WER) of the HuggingFace pipeline, at the expense of increased inference time. Notably, the improvements in inference speed were achieved solely through a **superior pipeline design**, without any specific optimizations to the backend inference engines (CTranslate2, FlashAttention2, etc.). For instance, WhisperS2T (with FlashAttention2) demonstrates significantly better inference speed than the HuggingFace pipeline (also with FlashAttention2), despite both leveraging the same inference engine: the HuggingFace whisper model with FlashAttention2. There is a noticeable difference in WER as well.
+
 
 ## Features
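For context on the benchmark note above: in the HuggingFace pipeline, the corresponding switch is `return_timestamps`. A minimal sketch of the trade-off, assuming a placeholder `audio.wav` (not part of this patch):

```python
# Minimal sketch of the timestamp trade-off described in the README note.
# "audio.wav" is a placeholder. The benchmark scripts in this patch pass
# return_timestamps=False, which corresponds to without_timestamps=True.
import torch
from transformers import pipeline

asr = pipeline("automatic-speech-recognition",
               "distil-whisper/distil-large-v2",
               torch_dtype=torch.float16,
               device="cuda")

fast = asr("audio.wav", return_timestamps=False)    # benchmark setting
slower = asr("audio.wav", return_timestamps=True)   # may improve WER, costs time
```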
diff --git a/scripts/benchmark_huggingface_distil.py b/scripts/benchmark_huggingface_distil.py
new file mode 100644
index 0000000..405816b
--- /dev/null
+++ b/scripts/benchmark_huggingface_distil.py
@@ -0,0 +1,159 @@
+import argparse
+from rich.console import Console
+console = Console()
+
+def parse_arguments():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--repo_path', default="", type=str)
+    parser.add_argument('--batch_size', default=16, type=int)
+    parser.add_argument('--flash_attention', default="yes", type=str)
+    parser.add_argument('--better_transformer', default="no", type=str)
+    parser.add_argument('--eval_mp3', default="no", type=str)
+    parser.add_argument('--eval_multilingual', default="no", type=str)
+    args = parser.parse_args()
+    return args
+
+
+def run(repo_path, flash_attention=False, better_transformer=False, batch_size=16, eval_mp3=False, eval_multilingual=True):
+    import torch
+    import time, os
+    import pandas as pd
+    from transformers import pipeline
+
+    # Load Model >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
+    model_kwargs = {
+        "use_safetensors": True,
+        "low_cpu_mem_usage": True
+    }
+
+    results_dir = f"{repo_path}/results/HuggingFaceDistilWhisper-bs_{batch_size}"
+
+    if flash_attention:
+        results_dir = f"{results_dir}-fa"
+        model_kwargs["use_flash_attention_2"] = True
+
+    ASR = pipeline("automatic-speech-recognition",
+                   "distil-whisper/distil-large-v2",
+                   num_workers=1,
+                   torch_dtype=torch.float16,
+                   device="cuda",
+                   model_kwargs=model_kwargs)
+
+    if (not flash_attention) and better_transformer:
+        ASR.model = ASR.model.to_bettertransformer()
+        results_dir = f"{results_dir}-bt"
+
+    os.makedirs(results_dir, exist_ok=True)
+
+    # KINCAID46 WAV >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
+    data = pd.read_csv(f'{repo_path}/data/KINCAID46/manifest_wav.tsv', sep="\t")
+    files = [f"{repo_path}/{fn}" for fn in data['audio_path']]
+
+    with console.status("Warming"):
+        st = time.time()
+        _ = ASR(files,
+                batch_size=batch_size,
+                chunk_length_s=15,
+                generate_kwargs={'num_beams': 1, 'language': 'en'},
+                return_timestamps=False)
+
+    print(f"[Warming Time]: {time.time()-st}")
+
+    with console.status("KINCAID WAV"):
+        st = time.time()
+        outputs = ASR(files,
+                      batch_size=batch_size,
+                      chunk_length_s=15,
+                      generate_kwargs={'num_beams': 1, 'language': 'en'},
+                      return_timestamps=False)
+
+    time_kincaid46_wav = time.time()-st
+    print(f"[KINCAID WAV Time]: {time_kincaid46_wav}")
+
+    data['pred_text'] = [_['text'].strip() for _ in outputs]
+    data.to_csv(f"{results_dir}/KINCAID46_WAV.tsv", sep="\t", index=False)
+
+
+    # KINCAID46 MP3 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
+    if eval_mp3:
+        data = pd.read_csv(f'{repo_path}/data/KINCAID46/manifest_mp3.tsv', sep="\t")
+        files = [f"{repo_path}/{fn}" for fn in data['audio_path']]
+
+        with console.status("KINCAID MP3"):
+            st = time.time()
+            outputs = ASR(files,
+                          batch_size=batch_size,
+                          chunk_length_s=30,
+                          generate_kwargs={'num_beams': 1, 'language': 'en'},
+                          return_timestamps=False)
+
+        time_kincaid46_mp3 = time.time()-st
+        print(f"[KINCAID MP3 Time]: {time_kincaid46_mp3}")
+
+        data['pred_text'] = [_['text'].strip() for _ in outputs]
+        data.to_csv(f"{results_dir}/KINCAID46_MP3.tsv", sep="\t", index=False)
+    else:
+        time_kincaid46_mp3 = 0.0
+
+    # MultiLingualLongform >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
+    if eval_multilingual:
+        data = pd.read_csv(f'{repo_path}/data/MultiLingualLongform/manifest.tsv', sep="\t")
+        files = [f"{repo_path}/{fn}" for fn in data['audio_path']]
+        lang_codes = data['lang_code'].to_list()
+
+        with console.status("MultiLingualLongform"):
+            st = time.time()
+
+            # Group consecutive files sharing the same language code, since the
+            # pipeline is invoked with a single `language` per call.
+            curr_files = [files[0]]
+            curr_lang = lang_codes[0]
+            outputs = []
+            for fn, lang in zip(files[1:], lang_codes[1:]):
+                if lang != curr_lang:
+                    _outputs = ASR(curr_files,
+                                   batch_size=batch_size,
+                                   chunk_length_s=30,
+                                   generate_kwargs={'num_beams': 1, 'language': curr_lang},
+                                   return_timestamps=False)
+                    outputs.extend(_outputs)
+
+                    curr_files = [fn]
+                    curr_lang = lang
+                else:
+                    curr_files.append(fn)
+
+            # Flush the final language group.
+            _outputs = ASR(curr_files,
+                           batch_size=batch_size,
+                           chunk_length_s=30,
+                           generate_kwargs={'num_beams': 1, 'language': curr_lang},
+                           return_timestamps=False)
+            outputs.extend(_outputs)
+
+        time_multilingual = time.time()-st
+        print(f"[MultiLingualLongform Time]: {time_multilingual}")
+
+        data['pred_text'] = [_['text'].strip() for _ in outputs]
+        data.to_csv(f"{results_dir}/MultiLingualLongform.tsv", sep="\t", index=False)
+    else:
+        time_multilingual = 0.0
+
+    infer_time = [
+        ["Dataset", "Time"],
+        ["KINCAID46 WAV", time_kincaid46_wav],
+        ["KINCAID46 MP3", time_kincaid46_mp3],
+        ["MultiLingualLongform", time_multilingual]
+    ]
+
+    infer_time = pd.DataFrame(infer_time[1:], columns=infer_time[0])
+    infer_time.to_csv(f"{results_dir}/infer_time.tsv", sep="\t", index=False)
+
+
+if __name__ == '__main__':
+    args = parse_arguments()
+    eval_mp3 = (args.eval_mp3 == "yes")
+    eval_multilingual = (args.eval_multilingual == "yes")
+    flash_attention = (args.flash_attention == "yes")
+    better_transformer = (args.better_transformer == "yes")
+
+    run(args.repo_path, flash_attention=flash_attention, better_transformer=better_transformer, batch_size=args.batch_size, eval_mp3=eval_mp3, eval_multilingual=eval_multilingual)
\ No newline at end of file
diff --git a/scripts/benchmark_whisper_s2t_distil.py b/scripts/benchmark_whisper_s2t_distil.py
new file mode 100644
index 0000000..de413e8
--- /dev/null
+++ b/scripts/benchmark_whisper_s2t_distil.py
@@ -0,0 +1,129 @@
+import argparse
+
+def parse_arguments():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--repo_path', default="", type=str)
+    parser.add_argument('--backend', default="HuggingFace", type=str)
+    parser.add_argument('--batch_size', default=16, type=int)
+    parser.add_argument('--flash_attention', default="yes", type=str)
+    parser.add_argument('--better_transformer', default="no", type=str)
+    parser.add_argument('--eval_mp3', default="no", type=str)
+    parser.add_argument('--eval_multilingual', default="no", type=str)
+    args = parser.parse_args()
+    return args
+
+def run(repo_path, backend, flash_attention=False, better_transformer=False, batch_size=16, eval_mp3=False, eval_multilingual=True):
+    import sys, time, os
+
+    if len(repo_path):
+        sys.path.append(repo_path)
+
+    import whisper_s2t
+    import pandas as pd
+
+    if backend.lower() in ["huggingface", "hf"]:
+        asr_options = {
+            "use_flash_attention": flash_attention,
+            "use_better_transformer": better_transformer
+        }
+
+        if flash_attention:
+            results_dir = f"{repo_path}/results/WhisperS2T-{backend}DistilWhisper-bs_{batch_size}-fa"
+        elif better_transformer:
+            results_dir = f"{repo_path}/results/WhisperS2T-{backend}DistilWhisper-bs_{batch_size}-bt"
+        else:
+            results_dir = f"{repo_path}/results/WhisperS2T-{backend}DistilWhisper-bs_{batch_size}"
+    else:
+        asr_options = {}
+        results_dir = f"{repo_path}/results/WhisperS2T-{backend}-bs_{batch_size}"
+
+    os.makedirs(results_dir, exist_ok=True)
+
+    model = whisper_s2t.load_model("distil-large-v2", backend=backend, asr_options=asr_options)
+
+    # KINCAID46 WAV >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
+    data = pd.read_csv(f'{repo_path}/data/KINCAID46/manifest_wav.tsv', sep="\t")
+    files = [f"{repo_path}/{fn}" for fn in data['audio_path']]
+    lang_codes = len(files)*['en']
+    tasks = len(files)*['transcribe']
+    initial_prompts = len(files)*[None]
+
+    # Warmup run (excluded from timing).
+    _ = model.transcribe_with_vad(files,
+                                  lang_codes=lang_codes,
+                                  tasks=tasks,
+                                  initial_prompts=initial_prompts,
+                                  batch_size=batch_size)
+
+    st = time.time()
+    out = model.transcribe_with_vad(files,
+                                    lang_codes=lang_codes,
+                                    tasks=tasks,
+                                    initial_prompts=initial_prompts,
+                                    batch_size=batch_size)
+    time_kincaid46_wav = time.time()-st
+
+    data['pred_text'] = [" ".join([_['text'] for _ in _transcript]).strip() for _transcript in out]
+    data.to_csv(f"{results_dir}/KINCAID46_WAV.tsv", sep="\t", index=False)
+
+
+    # KINCAID46 MP3 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
+    if eval_mp3:
+        data = pd.read_csv(f'{repo_path}/data/KINCAID46/manifest_mp3.tsv', sep="\t")
+        files = [f"{repo_path}/{fn}" for fn in data['audio_path']]
+        lang_codes = len(files)*['en']
+        tasks = len(files)*['transcribe']
+        initial_prompts = len(files)*[None]
+
+        st = time.time()
+        out = model.transcribe_with_vad(files,
+                                        lang_codes=lang_codes,
+                                        tasks=tasks,
+                                        initial_prompts=initial_prompts,
+                                        batch_size=batch_size)
+        time_kincaid46_mp3 = time.time()-st
+
+        data['pred_text'] = [" ".join([_['text'] for _ in _transcript]).strip() for _transcript in out]
+        data.to_csv(f"{results_dir}/KINCAID46_MP3.tsv", sep="\t", index=False)
+    else:
+        time_kincaid46_mp3 = 0.0
+
+
+    # MultiLingualLongform
+    if eval_multilingual:
+        data = pd.read_csv(f'{repo_path}/data/MultiLingualLongform/manifest.tsv', sep="\t")
+        files = [f"{repo_path}/{fn}" for fn in data['audio_path']]
+        lang_codes = data['lang_code'].to_list()
+        tasks = len(files)*['transcribe']
+        initial_prompts = len(files)*[None]
+
+        st = time.time()
+        out = model.transcribe_with_vad(files,
+                                        lang_codes=lang_codes,
+                                        tasks=tasks,
+                                        initial_prompts=initial_prompts,
+                                        batch_size=batch_size)
+        time_multilingual = time.time()-st
+
+        data['pred_text'] = [" ".join([_['text'] for _ in _transcript]).strip() for _transcript in out]
+        data.to_csv(f"{results_dir}/MultiLingualLongform.tsv", sep="\t", index=False)
+    else:
+        time_multilingual = 0.0
+
+    infer_time = [
+        ["Dataset", "Time"],
+        ["KINCAID46 WAV", time_kincaid46_wav],
+        ["KINCAID46 MP3", time_kincaid46_mp3],
+        ["MultiLingualLongform", time_multilingual]
+    ]
+    infer_time = pd.DataFrame(infer_time[1:], columns=infer_time[0])
+    infer_time.to_csv(f"{results_dir}/infer_time.tsv", sep="\t", index=False)
+
+
+if __name__ == '__main__':
+    args = parse_arguments()
+    eval_mp3 = (args.eval_mp3 == "yes")
+    eval_multilingual = (args.eval_multilingual == "yes")
+    flash_attention = (args.flash_attention == "yes")
+    better_transformer = (args.better_transformer == "yes")
+
+    run(args.repo_path, args.backend, flash_attention=flash_attention, better_transformer=better_transformer, batch_size=args.batch_size, eval_mp3=eval_mp3, eval_multilingual=eval_multilingual)
\ No newline at end of file
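For reference, both new benchmark scripts are driven by the argparse flags defined above. A minimal sketch of one way to invoke them; the repo path and flag values here are assumptions, not part of this patch:

```python
# Hypothetical driver for the two new benchmark scripts; flag values are
# assumptions. Equivalent to running the scripts directly from a shell.
import subprocess

for script in ("scripts/benchmark_huggingface_distil.py",
               "scripts/benchmark_whisper_s2t_distil.py"):
    subprocess.run(["python", script,
                    "--repo_path", ".",          # root containing data/ and results/
                    "--batch_size", "16",
                    "--flash_attention", "yes",  # scripts parse "yes"/"no" strings
                    "--eval_multilingual", "yes"],
                   check=True)
```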
results_dir = f"{repo_path}/results/WhisperS2T-{backend}-bs_{batch_size}" + + os.makedirs(results_dir, exist_ok=True) + + model = whisper_s2t.load_model("distil-large-v2", backend=backend, asr_options=asr_options) + + # KINCAID46 WAV >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + data = pd.read_csv(f'{repo_path}/data/KINCAID46/manifest_wav.tsv', sep="\t") + files = [f"{repo_path}/{fn}" for fn in data['audio_path']] + lang_codes = len(files)*['en'] + tasks = len(files)*['transcribe'] + initial_prompts = len(files)*[None] + + _ = model.transcribe_with_vad(files, + lang_codes=lang_codes, + tasks=tasks, + initial_prompts=initial_prompts, + batch_size=batch_size) + + st = time.time() + out = model.transcribe_with_vad(files, + lang_codes=lang_codes, + tasks=tasks, + initial_prompts=initial_prompts, + batch_size=batch_size) + time_kincaid46_wav = time.time()-st + + data['pred_text'] = [" ".join([_['text'] for _ in _transcript]).strip() for _transcript in out] + data.to_csv(f"{results_dir}/KINCAID46_WAV.tsv", sep="\t", index=False) + + + # KINCAID46 MP3 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + if eval_mp3: + data = pd.read_csv(f'{repo_path}/data/KINCAID46/manifest_mp3.tsv', sep="\t") + files = [f"{repo_path}/{fn}" for fn in data['audio_path']] + lang_codes = len(files)*['en'] + tasks = len(files)*['transcribe'] + initial_prompts = len(files)*[None] + + st = time.time() + out = model.transcribe_with_vad(files, + lang_codes=lang_codes, + tasks=tasks, + initial_prompts=initial_prompts, + batch_size=batch_size) + time_kincaid46_mp3 = time.time()-st + + data['pred_text'] = [" ".join([_['text'] for _ in _transcript]).strip() for _transcript in out] + data.to_csv(f"{results_dir}/KINCAID46_MP3.tsv", sep="\t", index=False) + else: + time_kincaid46_mp3 = 0.0 + + + # MultiLingualLongform + if eval_multilingual: + data = pd.read_csv(f'{repo_path}/data/MultiLingualLongform/manifest.tsv', sep="\t") + files = [f"{repo_path}/{fn}" for fn in data['audio_path']] + lang_codes = data['lang_code'].to_list() + tasks = len(files)*['transcribe'] + initial_prompts = len(files)*[None] + + st = time.time() + out = model.transcribe_with_vad(files, + lang_codes=lang_codes, + tasks=tasks, + initial_prompts=initial_prompts, + batch_size=batch_size) + time_multilingual = time.time()-st + + data['pred_text'] = [" ".join([_['text'] for _ in _transcript]).strip() for _transcript in out] + data.to_csv(f"{results_dir}/MultiLingualLongform.tsv", sep="\t", index=False) + else: + time_multilingual = 0.0 + + infer_time = [ + ["Dataset", "Time"], + ["KINCAID46 WAV", time_kincaid46_wav], + ["KINCAID46 MP3", time_kincaid46_mp3], + ["MultiLingualLongform", time_multilingual] + ] + infer_time = pd.DataFrame(infer_time[1:], columns=infer_time[0]) + infer_time.to_csv(f"{results_dir}/infer_time.tsv", sep="\t", index=False) + + +if __name__ == '__main__': + args = parse_arguments() + eval_mp3 = True if args.eval_mp3 == "yes" else False + eval_multilingual = True if args.eval_multilingual == "yes" else False + flash_attention = True if args.flash_attention == "yes" else False + better_transformer = True if args.better_transformer == "yes" else False + + run(args.repo_path, args.backend, flash_attention=flash_attention, better_transformer=better_transformer, batch_size=args.batch_size, eval_mp3=eval_mp3, eval_multilingual=eval_multilingual) \ No newline at end of file diff --git a/whisper_s2t/__init__.py b/whisper_s2t/__init__.py index 42ab2ad..5845334 100644 --- 
diff --git a/whisper_s2t/backends/__init__.py b/whisper_s2t/backends/__init__.py
index d59cd71..df5ebba 100644
--- a/whisper_s2t/backends/__init__.py
+++ b/whisper_s2t/backends/__init__.py
@@ -55,6 +55,7 @@ def __init__(self,
 
         self.vad_model = vad_model
         self.speech_segmenter_options = speech_segmenter_options
+        self.speech_segmenter_options['max_seg_len'] = self.max_speech_len
 
         # Tokenizer
         if tokenizer is None:
diff --git a/whisper_s2t/backends/ctranslate2/hf_utils.py b/whisper_s2t/backends/ctranslate2/hf_utils.py
index 6d03feb..f494387 100644
--- a/whisper_s2t/backends/ctranslate2/hf_utils.py
+++ b/whisper_s2t/backends/ctranslate2/hf_utils.py
@@ -14,17 +14,18 @@
 
 
 _MODELS = {
-    "tiny.en": "guillaumekln/faster-whisper-tiny.en",
-    "tiny": "guillaumekln/faster-whisper-tiny",
-    "base.en": "guillaumekln/faster-whisper-base.en",
-    "base": "guillaumekln/faster-whisper-base",
-    "small.en": "guillaumekln/faster-whisper-small.en",
-    "small": "guillaumekln/faster-whisper-small",
-    "medium.en": "guillaumekln/faster-whisper-medium.en",
-    "medium": "guillaumekln/faster-whisper-medium",
-    "large-v1": "guillaumekln/faster-whisper-large-v1",
-    "large-v2": "guillaumekln/faster-whisper-large-v2",
-    "large": "guillaumekln/faster-whisper-large-v2",
+    "tiny.en": "Systran/faster-whisper-tiny.en",
+    "tiny": "Systran/faster-whisper-tiny",
+    "base.en": "Systran/faster-whisper-base.en",
+    "base": "Systran/faster-whisper-base",
+    "small.en": "Systran/faster-whisper-small.en",
+    "small": "Systran/faster-whisper-small",
+    "medium.en": "Systran/faster-whisper-medium.en",
+    "medium": "Systran/faster-whisper-medium",
+    "large-v1": "Systran/faster-whisper-large-v1",
+    "large-v2": "Systran/faster-whisper-large-v2",
+    "large-v3": "Systran/faster-whisper-large-v3",
+    "large": "Systran/faster-whisper-large-v3",
 }