From a9363e889fefaee71544138d7730a7971abb4921 Mon Sep 17 00:00:00 2001
From: Matt Barker <105945282+m-barker@users.noreply.github.com>
Date: Sun, 9 Jun 2024 16:57:19 +0100
Subject: [PATCH] Microphone testing for noisy environments (#214)

* fix: incorrect remapping keys for handover

* fix: remove speech server in launch file

* feat: add script to test speech energy thresholds

* Update common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py

Co-authored-by: Jared Swift

---------

Co-authored-by: Jared Swift
---
 .../CMakeLists.txt                            |   1 +
 .../scripts/microphone_tuning_test.py         | 101 ++++++++++++++++++
 2 files changed, 102 insertions(+)
 create mode 100644 common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py

diff --git a/common/speech/lasr_speech_recognition_whisper/CMakeLists.txt b/common/speech/lasr_speech_recognition_whisper/CMakeLists.txt
index a11465954..a62469d83 100644
--- a/common/speech/lasr_speech_recognition_whisper/CMakeLists.txt
+++ b/common/speech/lasr_speech_recognition_whisper/CMakeLists.txt
@@ -172,6 +172,7 @@ catkin_install_python(PROGRAMS
   scripts/test_microphones.py
   scripts/repeat_after_me.py
   scripts/test_speech_server.py
+  scripts/microphone_tuning_test.py
   DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
 )
 
diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py b/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py
new file mode 100644
index 000000000..806e801d0
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py
@@ -0,0 +1,101 @@
+#!/usr/bin/env python3
+"""Sweep the speech recogniser's energy threshold to tune a noisy microphone.
+
+Each pass records one phrase at a fixed energy threshold, transcribes it with
+Whisper and prints the result, then raises the threshold. The sweep ends at the
+first pass whose transcription is empty.
+"""
+import argparse
+import os
+import torch
+import numpy as np
+from pathlib import Path
+import speech_recognition as sr
+from lasr_speech_recognition_whisper import load_model  # type: ignore
+import sounddevice  # needed to remove ALSA error messages
+from typing import Dict
+
+
+def parse_args() -> Dict:
+    """Parse the command-line options.
+
+    Returns:
+        Dict with the keys ``device_index`` (None unless given),
+        ``start_threshold`` and ``threshold_step``.
+    """
+    parser = argparse.ArgumentParser(
+        description="Transcribe speech at increasing energy thresholds."
+    )
+    parser.add_argument(
+        "--device_index",
+        type=int,
+        default=None,
+        help="Index of the microphone passed to speech_recognition.",
+    )
+    parser.add_argument(
+        "--start_threshold",
+        type=int,
+        default=100,
+        help="Energy threshold used for the first recording.",
+    )
+    parser.add_argument(
+        "--threshold_step",
+        type=int,
+        default=100,
+        help="Amount added to the energy threshold after each recording.",
+    )
+    return vars(parser.parse_args())
+
+
+def configure_whisper_cache() -> None:
+    """Create ~/.cache/whisper and point TIKTOKEN_CACHE_DIR at it."""
+    whisper_cache = Path.home() / ".cache" / "whisper"
+    whisper_cache.mkdir(parents=True, exist_ok=True)
+    # Environment variable required to run whisper locally
+    os.environ["TIKTOKEN_CACHE_DIR"] = str(whisper_cache)
+
+
+def main() -> None:
+    """Record and transcribe one phrase per energy threshold until one is empty."""
+    args = parse_args()
+
+    recognizer = sr.Recognizer()
+    microphone = sr.Microphone(device_index=args["device_index"], sample_rate=16000)
+    # Disable automatic adjustment so each pass uses exactly the printed value
+    recognizer.dynamic_energy_threshold = False
+    recognizer.energy_threshold = args["start_threshold"]
+
+    use_cuda = torch.cuda.is_available()
+    transcription_model = load_model("medium.en", "cuda" if use_cuda else "cpu", True)
+
+    while True:
+        print("Listening...")
+        with microphone as source:
+            audio = recognizer.listen(source)
+        print("Processing...")
+        # Use the raw PCM frames: get_wav_data() prepends a RIFF/WAV header that
+        # would otherwise be decoded as audio samples.
+        pcm_data = audio.get_raw_data(convert_rate=16000, convert_width=2)
+        # Scale signed 16-bit samples to floats in [-1.0, 1.0); 32768 == 2**15
+        float_data = (
+            np.frombuffer(pcm_data, dtype=np.int16).astype(np.float32, order="C")
+            / 32768.0
+        )
+
+        # Cast to fp16 if using GPU
+        transcription_result = transcription_model.transcribe(
+            float_data, fp16=use_cuda
+        )["text"].strip()
+
+        print(
+            f"Transcription: {transcription_result} at energy threshold {recognizer.energy_threshold}"
+        )
+        # An empty transcription ends the sweep
+        if not transcription_result:
+            break
+        recognizer.energy_threshold += args["threshold_step"]
+
+
+if __name__ == "__main__":
+    configure_whisper_cache()
+    main()