diff --git a/common/speech/lasr_speech_recognition_interfaces/README.md b/common/speech/lasr_speech_recognition_interfaces/README.md
index c96279cb..8e7aab96 100644
--- a/common/speech/lasr_speech_recognition_interfaces/README.md
+++ b/common/speech/lasr_speech_recognition_interfaces/README.md
@@ -3,17 +3,18 @@
 Common messages used for speech recognition
 
 This package is maintained by:
+
 - [Maayan Armony](mailto:maayan.armony@gmail.com)
 - [Paul Makles](mailto:me@insrt.uk) (ROS1)
 
 ## Prerequisites
 
 This package depends on the following ROS packages:
+
 - colcon (buildtool)
 - message_generation (build)
 - message_runtime (exec)
 
-
 ## Usage
 
 Ask the package maintainer to write a `doc/USAGE.md` for their package!
@@ -36,11 +37,10 @@
 This package has no launch files.
 
 #### `Transcription`
 
-| Field | Type | Description |
-|:-:|:-:|---|
-| phrase | string | |
-| finished | bool | |
-
+| Field    | Type   | Description |
+|:--------:|:------:|-------------|
+| phrase   | string |             |
+| finished | bool   |             |
 
 ### Services
diff --git a/common/speech/lasr_speech_recognition_interfaces/package.xml b/common/speech/lasr_speech_recognition_interfaces/package.xml
index b15638eb..fd72011b 100644
--- a/common/speech/lasr_speech_recognition_interfaces/package.xml
+++ b/common/speech/lasr_speech_recognition_interfaces/package.xml
@@ -1,23 +1,23 @@
 <?xml version="1.0"?>
 <?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
 <package format="3">
-    <name>lasr_speech_recognition_interfaces</name>
-    <version>0.0.0</version>
-    <description>Common messages used for speech recognition</description>
-    <maintainer email="maayan.armony@gmail.com">maayan</maintainer>
-    <license>MIT</license>
+  <name>lasr_speech_recognition_interfaces</name>
+  <version>0.0.0</version>
+  <description>Common messages used for speech recognition</description>
+  <maintainer email="maayan.armony@gmail.com">maayan</maintainer>
+  <license>MIT</license>
 
-    <buildtool_depend>ament_cmake</buildtool_depend>
-
-    <buildtool_depend>rosidl_default_generators</buildtool_depend>
-    <depend>action_msgs</depend>
-    <exec_depend>rosidl_default_runtime</exec_depend>
-    <member_of_group>rosidl_interface_packages</member_of_group>
+  <buildtool_depend>ament_cmake</buildtool_depend>
+
+  <buildtool_depend>rosidl_default_generators</buildtool_depend>
+  <depend>action_msgs</depend>
+  <exec_depend>rosidl_default_runtime</exec_depend>
+  <member_of_group>rosidl_interface_packages</member_of_group>
 
-    <test_depend>ament_lint_auto</test_depend>
-    <test_depend>ament_lint_common</test_depend>
+  <test_depend>ament_lint_auto</test_depend>
+  <test_depend>ament_lint_common</test_depend>
 
-    <export>
-        <build_type>ament_cmake</build_type>
-    </export>
+  <export>
+    <build_type>ament_cmake</build_type>
+  </export>
 </package>
diff --git a/common/speech/lasr_speech_recognition_whisper/README.md b/common/speech/lasr_speech_recognition_whisper/README.md
index 9954290e..c9f58557 100644
--- a/common/speech/lasr_speech_recognition_whisper/README.md
+++ b/common/speech/lasr_speech_recognition_whisper/README.md
@@ -3,18 +3,21 @@
 Speech recognition implemented using OpenAI Whisper
 
 This package is maintained by:
+
 - [Maayan Armony](mailto:maayan.armony@gmail.com)
 - [Paul Makles](mailto:me@insrt.uk) (ROS1)
 
 ## Prerequisites
 
 This package depends on the following ROS packages:
+
 - colcon (buildtool)
 - lasr_speech_recognition_interfaces
 
 This packages requires Python 3.10 to be present.
 
 This package has 48 Python dependencies:
+
 - [SpeechRecognition](https://pypi.org/project/SpeechRecognition)==3.10.0
 - [openai-whisper](https://pypi.org/project/openai-whisper)==20230314
 - [PyAudio](https://pypi.org/project/PyAudio)==0.2.13
@@ -64,15 +67,18 @@
 This package does speech recognition in three parts:
 
 - Adjusting for background noise
 
-  We wait for a set period of time monitoring the audio stream to determine what we should ignore when collecting voice data.
+  We wait for a set period of time monitoring the audio stream to determine what we should ignore when collecting voice
+  data.
 
 - Collecting appropriate voice data for phrases
 
-  We use the `SpeechRecognition` package to monitor the input audio stream and determine when a person is actually speaking with enough energy that we would consider them to be speaking to the robot.
+  We use the `SpeechRecognition` package to monitor the input audio stream and determine when a person is actually
+  speaking with enough energy that we would consider them to be speaking to the robot.
 
 - Running inference on phrases
 
-  We continuously combine segments of the spoken phrase to form a sample until a certain timeout or threshold after which the phrase ends. This sample is sent to a local OpenAI Whisper model to transcribe.
+  We continuously combine segments of the spoken phrase to form a sample until a certain timeout or threshold after
+  which the phrase ends. This sample is sent to a local OpenAI Whisper model to transcribe.
 
 The package can input from the following sources:
diff --git a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/simple_transcribe_microphone.py b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/simple_transcribe_microphone.py
index 960bd599..7b3b1f8a 100644
--- a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/simple_transcribe_microphone.py
+++ b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/simple_transcribe_microphone.py
@@ -11,28 +11,31 @@
 import sounddevice  # needed to remove ALSA error messages
 from lasr_speech_recognition_interfaces.srv import TranscribeAudio
 
-from src import ModelCache # type: ignore
+from src import ModelCache  # type: ignore
 
-MODEL = "medium.en" # Whisper model
-TIMEOUT = 5.0 # Timeout for listening for the start of a phrase
-PHRASE_TIME_LIMIT = None # Timeout for listening for the end of a phrase
+MODEL = "medium.en"  # Whisper model
+TIMEOUT = 5.0  # Timeout for listening for the start of a phrase
+PHRASE_TIME_LIMIT = None  # Timeout for listening for the end of a phrase
 
-WHISPER_CACHE = os.path.join(str(Path.home()), '.cache', 'whisper')
+WHISPER_CACHE = os.path.join(str(Path.home()), ".cache", "whisper")
 os.makedirs(WHISPER_CACHE, exist_ok=True)
 os.environ["TIKTOKEN_CACHE_DIR"] = WHISPER_CACHE
 
 if len(sys.argv) < 3:
-    print('Usage:')
-    print('ros2 run lasr_speech_recognition transcribe_microphone by-index <device_index>')
-    print('ros2 run lasr_speech_recognition transcribe_microphone by-name <substring>')
+    print("Usage:")
+    print(
+        "ros2 run lasr_speech_recognition transcribe_microphone by-index <device_index>"
+    )
+    print("ros2 run lasr_speech_recognition transcribe_microphone by-name <substring>")
     exit(1)
 else:
     matcher = sys.argv[1]
     device_index = None
-    if matcher == 'by-index':
+    if matcher == "by-index":
         device_index = int(sys.argv[2])
-    elif matcher == 'by-name':
+    elif matcher == "by-name":
         import speech_recognition as sr
+
         microphones = enumerate(sr.Microphone.list_microphone_names())
 
         target_name = sys.argv[2]
@@ -40,16 +43,16 @@
             if target_name in name:
                 device_index = index
                 break
-
+
         if device_index is None:
-            print('Could not find device!')
+            print("Could not find device!")
             exit(1)
     else:
-        print('Invalid matcher')
+        print("Invalid matcher")
         exit(1)
 
 rclpy.init(args=sys.argv)
-node = rclpy.create_node('transcribe_mic')
+node = rclpy.create_node("transcribe_mic")
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model_cache = ModelCache()
@@ -57,9 +60,15 @@
 
 # try to run inference on the example file
 package_install = packages.get_package_prefix("lasr_speech_recognition_whisper")
-package_root = os.path.abspath(os.path.join(package_install, os.pardir, os.pardir, "lasr_speech_recognition_whisper"))
+package_root = os.path.abspath(
+    os.path.join(
+        package_install, os.pardir, os.pardir, "lasr_speech_recognition_whisper"
+    )
+)
 example_fp = os.path.join(package_root, "test.m4a")
 
-node.get_logger().info("Running transcription on example file to ensure model is loaded...")
+node.get_logger().info(
+    "Running transcription on example file to ensure model is loaded..."
+)
 transcription = model.transcribe(example_fp, fp16=torch.cuda.is_available())
 node.get_logger().info(str(transcription))
@@ -68,16 +77,25 @@
 
 
 with microphone as source:
     r.adjust_for_ambient_noise(source)
 
+
 def handle_transcribe_audio(_):
     with microphone as source:
-        wav_data = r.listen(source, timeout=TIMEOUT, phrase_time_limit=PHRASE_TIME_LIMIT).get_wav_data()
-        float_data = np.frombuffer(wav_data, dtype=np.int16).astype(np.float32, order='C') / 32768.0
+        wav_data = r.listen(
+            source, timeout=TIMEOUT, phrase_time_limit=PHRASE_TIME_LIMIT
+        ).get_wav_data()
+        float_data = (
+            np.frombuffer(wav_data, dtype=np.int16).astype(np.float32, order="C")
+            / 32768.0
+        )
 
         phrase = model.transcribe(float_data, fp16=device == "cuda")["text"]
         return TranscribeAudio.Response(phrase=phrase)
 
-node.create_service(TranscribeAudio, '/whisper/transcribe_audio', handle_transcribe_audio)
+
+node.create_service(
+    TranscribeAudio, "/whisper/transcribe_audio", handle_transcribe_audio
+)
 node.get_logger().info("Whisper service ready")
-rclpy.spin(node)
\ No newline at end of file
+rclpy.spin(node)
diff --git a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone.py b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone.py
index f8553b0b..3225072c 100644
--- a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone.py
+++ b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone.py
@@ -10,36 +10,42 @@
 from std_srvs.srv import Empty
 from src import SpeechRecognitionToTopic, MicrophonePhraseCollector, ModelCache
 
-WHISPER_CACHE = os.path.join(str(Path.home()), '.cache', 'whisper')
+WHISPER_CACHE = os.path.join(str(Path.home()), ".cache", "whisper")
 os.makedirs(WHISPER_CACHE, exist_ok=True)
 os.environ["TIKTOKEN_CACHE_DIR"] = WHISPER_CACHE
 
+
 class TranscribeMicrophone(Node):
     def __init__(self):
-        Node.__init__(self, 'transcribe_microphone')
+        Node.__init__(self, "transcribe_microphone")
         self.worker = None
         self.collector = None
 
-        self.create_service(Empty, '/whisper/adjust_for_noise', self.adjust_for_noise)
-        self.create_service(Empty, '/whisper/start_listening', self.start_listening)
-        self.create_service(Empty, '/whisper/stop_listening', self.stop_listening)
+        self.create_service(Empty, "/whisper/adjust_for_noise", self.adjust_for_noise)
+        self.create_service(Empty, "/whisper/start_listening", self.start_listening)
+        self.create_service(Empty, "/whisper/stop_listening", self.stop_listening)
 
         self.get_logger().info("Starting the Whisper worker!")
         self.run_transcription()
 
     def run_transcription(self):
         if len(sys.argv) < 3:
-            print('Usage:')
-            print('rosrun lasr_speech_recognition transcribe_microphone by-index <device_index>')
-            print('rosrun lasr_speech_recognition transcribe_microphone by-name <substring>')
+            print("Usage:")
+            print(
+                "rosrun lasr_speech_recognition transcribe_microphone by-index <device_index>"
+            )
+            print(
+                "rosrun lasr_speech_recognition transcribe_microphone by-name <substring>"
+            )
             exit(1)
         else:
             matcher = sys.argv[1]
             device_index = None
-            if matcher == 'by-index':
+            if matcher == "by-index":
                 device_index = int(sys.argv[2])
-            elif matcher == 'by-name':
+            elif matcher == "by-name":
                 import speech_recognition as sr
+
                 microphones = enumerate(sr.Microphone.list_microphone_names())
 
                 target_name = sys.argv[2]
@@ -49,13 +55,12 @@ def run_transcription(self):
                 if target_name in name:
                     device_index = index
                     break
 
             if device_index is None:
-                print('Could not find device!')
+                print("Could not find device!")
                 exit(1)
         else:
-            print('Invalid matcher')
+            print("Invalid matcher")
             exit(1)
-
         self.collector = MicrophonePhraseCollector(device_index=device_index)
         self.collector.adjust_for_noise()
@@ -64,14 +69,24 @@ def run_transcription(self):
 
         # try to run inference on the example file
         package_install = packages.get_package_prefix("lasr_speech_recognition_whisper")
-        package_root = os.path.abspath(os.path.join(package_install, os.pardir, os.pardir, "lasr_speech_recognition_whisper"))
+        package_root = os.path.abspath(
+            os.path.join(
+                package_install, os.pardir, os.pardir, "lasr_speech_recognition_whisper"
+            )
+        )
         example_fp = os.path.join(package_root, "test.m4a")
 
-        self.get_logger().info("Running transcription on example file to ensure model is loaded...")
-        model_transcription = model.transcribe(example_fp, fp16=torch.cuda.is_available())
+        self.get_logger().info(
+            "Running transcription on example file to ensure model is loaded..."
+        )
+        model_transcription = model.transcribe(
+            example_fp, fp16=torch.cuda.is_available()
+        )
         self.get_logger().info(str(model_transcription))
 
-        self.worker = SpeechRecognitionToTopic(self.collector, model, "transcription", infer_partial = False)
+        self.worker = SpeechRecognitionToTopic(
+            self.collector, model, "transcription", infer_partial=False
+        )
 
     def adjust_for_noise(self, request, response):
         self.collector.adjust_for_noise()
@@ -85,11 +100,12 @@ def stop_listening(self, request, response):
         self.worker.stop()
         return response
 
+
 def main(args=None):
     rclpy.init(args=args)
     transcribe_microphone = TranscribeMicrophone()
     rclpy.spin(transcribe_microphone)
 
 
-if __name__ == '__main__':
-    main()
\ No newline at end of file
+if __name__ == "__main__":
+    main()
diff --git a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py
index e7c307f5..8adf3bd8 100644
--- a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py
+++ b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py
@@ -18,10 +18,11 @@
 from lasr_speech_recognition_interfaces.action import TranscribeSpeech  # type: ignore
 from rclpy.executors import ExternalShutdownException
 from std_msgs.msg import String  # type: ignore
-from src import ModelCache # type: ignore
+from src import ModelCache  # type: ignore
 
 # TODO: argpars -> ROS2 params, test behaviour of preemption
 
+
 @dataclass
 class speech_model_params:
     """Class for storing speech recognition model parameters.
@@ -58,9 +59,9 @@ class TranscribeSpeechAction(Node):
     _result = TranscribeSpeech.Result()
 
     def __init__(
-            self,
-            action_name: str,
-            model_params: speech_model_params,
+        self,
+        action_name: str,
+        model_params: speech_model_params,
     ) -> None:
         """Starts an action server for transcribing speech.
 
@@ -126,9 +127,9 @@ def _configure_microphone(self) -> sr.Microphone:
         )
 
     def _configure_recogniser(
-            self,
-            energy_threshold: Optional[float] = None,
-            pause_threshold: Optional[float] = None,
+        self,
+        energy_threshold: Optional[float] = None,
+        pause_threshold: Optional[float] = None,
     ) -> sr.Recognizer:
         """Configures the speech recogniser object.
 
@@ -212,8 +213,8 @@ async def execute_cb(self, goal_handle) -> None:
         ).get_wav_data()
         # Magic number 32768.0 is the maximum value of a 16-bit signed integer
         float_data = (
-                np.frombuffer(wav_data, dtype=np.int16).astype(np.float32, order="C")
-                / 32768.0
+            np.frombuffer(wav_data, dtype=np.int16).astype(np.float32, order="C")
+            / 32768.0
         )
 
         if goal_handle.is_cancel_requested():
@@ -265,7 +266,6 @@ def parse_args() -> dict:
 
     # port = node.declare_parameter('port', '/dev/ttyUSB0').value
     # assert isinstance(port, str), 'port parameter must be a str'
-
     parser.add_argument(
         "--action_name",
         type=str,
@@ -372,6 +372,7 @@ def configure_whisper_cache() -> None:
     # Environmental variable required to run whisper locally
     os.environ["TIKTOKEN_CACHE_DIR"] = whisper_cache
 
+
 def main(args=None):
     rclpy.init(args=args)
 
@@ -383,4 +384,4 @@ def main(args=None):
     try:
         rclpy.spin(server)
     except (KeyboardInterrupt, ExternalShutdownException):
-        pass
\ No newline at end of file
+        pass
diff --git a/common/speech/lasr_speech_recognition_whisper/package.xml b/common/speech/lasr_speech_recognition_whisper/package.xml
index 825aae03..1cac4761 100644
--- a/common/speech/lasr_speech_recognition_whisper/package.xml
+++ b/common/speech/lasr_speech_recognition_whisper/package.xml
@@ -1,30 +1,30 @@
 <?xml version="1.0"?>
 <?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
 <package format="3">
-    <name>lasr_speech_recognition_whisper</name>
-    <version>0.0.0</version>
-    <description>Speech recognition implemented using OpenAI Whisper</description>
-    <maintainer email="maayan.armony@gmail.com">maayan</maintainer>
-    <license>MIT</license>
+  <name>lasr_speech_recognition_whisper</name>
+  <version>0.0.0</version>
+  <description>Speech recognition implemented using OpenAI Whisper</description>
+  <maintainer email="maayan.armony@gmail.com">maayan</maintainer>
+  <license>MIT</license>
 
-    <test_depend>ament_copyright</test_depend>
-    <test_depend>ament_flake8</test_depend>
-    <test_depend>ament_pep257</test_depend>
-    <test_depend>python3-pytest</test_depend>
+  <test_depend>ament_copyright</test_depend>
+  <test_depend>ament_flake8</test_depend>
+  <test_depend>ament_pep257</test_depend>
+  <test_depend>python3-pytest</test_depend>
 
-
-    <depend>lasr_speech_recognition_interfaces</depend>
-    <build_depend>actionlib</build_depend>
-    <build_depend>actionlib_msgs</build_depend>
-    <exec_depend>actionlib</exec_depend>
-    <exec_depend>actionlib_msgs</exec_depend>
+
+
+  <depend>lasr_speech_recognition_interfaces</depend>
+  <build_depend>actionlib</build_depend>
+  <build_depend>actionlib_msgs</build_depend>
+  <exec_depend>actionlib</exec_depend>
+  <exec_depend>actionlib_msgs</exec_depend>
 
-    <export>
-        <build_type>ament_python</build_type>
-        <pip_requirements>requirements.txt</pip_requirements>
-    </export>
+  <export>
+    <build_type>ament_python</build_type>
+    <pip_requirements>requirements.txt</pip_requirements>
+  </export>
 </package>
diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/list_microphones.py b/common/speech/lasr_speech_recognition_whisper/scripts/list_microphones.py
index 16ca35d2..a3ce2190 100755
--- a/common/speech/lasr_speech_recognition_whisper/scripts/list_microphones.py
+++ b/common/speech/lasr_speech_recognition_whisper/scripts/list_microphones.py
@@ -14,5 +14,6 @@ def main():
     # print("Available microphone devices (sounddevice):")
     # print(sounddevice.query_devices())
 
-if __name__ == '__main__':
-    main()
\ No newline at end of file
+
+if __name__ == "__main__":
+    main()
diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py b/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py
index c7739646..026ab287 100755
--- a/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py
+++ b/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py
@@ -5,16 +5,19 @@
 import numpy as np
 from pathlib import Path
 import speech_recognition as sr
-from src import ModelCache # type: ignore
+from src import ModelCache  # type: ignore
 import sounddevice  # needed to remove ALSA error messages
 from typing import Dict
 import rclpy
 
 
 # TODO argparse -> ROS params
 
+
 def parse_args() -> Dict:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--device_index", help="Microphone index", type=int, default=None)
+    parser.add_argument(
+        "--device_index", help="Microphone index", type=int, default=None
+    )
 
     return vars(parser.parse_args())
@@ -67,6 +70,7 @@ def main(args=None):
 
             threshold += 100
             recognizer.energy_threshold = threshold
 
+
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/test_microphones.py b/common/speech/lasr_speech_recognition_whisper/scripts/test_microphones.py
index e0c94c23..d14144e2 100755
--- a/common/speech/lasr_speech_recognition_whisper/scripts/test_microphones.py
+++ b/common/speech/lasr_speech_recognition_whisper/scripts/test_microphones.py
@@ -8,6 +8,7 @@
 
 # TODO argparse -> ROS params
 
+
 def parse_args() -> dict:
     """Parse command line arguments into a dictionary.
 
@@ -16,7 +17,9 @@
     """
 
     parser = argparse.ArgumentParser(description="Test microphones")
-    parser.add_argument("-m", "--microphone", type=int, help="Microphone index", default=None)
+    parser.add_argument(
+        "-m", "--microphone", type=int, help="Microphone index", default=None
+    )
     parser.add_argument(
         "-o", "--output_dir", type=str, help="Directory to save audio files"
     )
@@ -64,5 +67,6 @@ def main(args: dict = None) -> None:
 
     rclpy.shutdown()
 
+
 if __name__ == "__main__":
     main()
diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py b/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py
index d0670845..2448e73e 100755
--- a/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py
+++ b/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py
@@ -7,6 +7,7 @@
 
 # https://docs.ros2.org/latest/api/rclpy/api/actions.html
 
+
 class TestSpeechServerClient(Node):
     def __init__(self):
         Node.__init__(self, "listen_action_client")
@@ -20,8 +21,12 @@ def send_goal(self, goal):
         self.client.wait_for_server()
         self.get_logger().info("Server activated, sending goal...")
 
-        self.goal_future = self.client.send_goal_async(goal, feedback_callback=self.feedback_cb) # Returns a Future instance when the goal request has been accepted or rejected.
-        self.goal_future.add_done_callback(self.response_cb) # When received get response
+        self.goal_future = self.client.send_goal_async(
+            goal, feedback_callback=self.feedback_cb
+        )  # Returns a Future instance when the goal request has been accepted or rejected.
+        self.goal_future.add_done_callback(
+            self.response_cb
+        )  # When received get response
 
     def feedback_cb(self, msg):
         self.get_logger().info(f"Received feedback: {msg.feedback}")
@@ -33,13 +38,16 @@ def response_cb(self, future):
             return
 
         self.get_logger().info("Goal was accepted")
-        self.result_future = handle.get_result_async() # Not using get_result() in cb, as can cause deadlock according to docs
+        self.result_future = (
+            handle.get_result_async()
+        )  # Not using get_result() in cb, as can cause deadlock according to docs
         self.result_future.add_done_callback(self.result_cb)
 
     def result_cb(self, future):
         result = future.result().result
         self.get_logger().info(f"Transcribed Speech: {result.sequence}")
 
+
 def main(args=None):
     rclpy.init(args=args)
     while rclpy.ok():
@@ -54,5 +62,6 @@ def main(args=None):
     client.destroy_node()
     rclpy.shutdown()
 
+
 if __name__ == "__main__":
     main()
diff --git a/common/speech/lasr_speech_recognition_whisper/setup.cfg b/common/speech/lasr_speech_recognition_whisper/setup.cfg
index 5ec86217..1f6a5440 100644
--- a/common/speech/lasr_speech_recognition_whisper/setup.cfg
+++ b/common/speech/lasr_speech_recognition_whisper/setup.cfg
@@ -1,4 +1,4 @@
 [develop]
-script_dir=$base/lib/lasr_speech_recognition_whisper
+script_dir = $base/lib/lasr_speech_recognition_whisper
 [install]
-install_scripts=$base/lib/lasr_speech_recognition_whisper
+install_scripts = $base/lib/lasr_speech_recognition_whisper
diff --git a/common/speech/lasr_speech_recognition_whisper/setup.py b/common/speech/lasr_speech_recognition_whisper/setup.py
index 3fbac464..c6a80148 100644
--- a/common/speech/lasr_speech_recognition_whisper/setup.py
+++ b/common/speech/lasr_speech_recognition_whisper/setup.py
@@ -1,11 +1,11 @@
 from setuptools import find_packages, setup
 
-package_name = 'lasr_speech_recognition_whisper'
+package_name = "lasr_speech_recognition_whisper"
 
 setup(
     name=package_name,
-    version='0.0.0',
-    packages=find_packages(exclude=['test']),
+    version="0.0.0",
+    packages=find_packages(exclude=["test"]),
     # packages=[package_name, f"{package_name}.lasr_speech_recognition_whisper", f"{package_name}.src"],
     # package_dir={
     #     '': '.',
@@ -14,26 +14,25 @@
     #     f"{package_name}.src": os.path.join(package_name, 'src'),
     # },
     data_files=[
-        ('share/ament_index/resource_index/packages',
-            ['resource/' + package_name]),
-        ('share/' + package_name, ['package.xml']),
+        ("share/ament_index/resource_index/packages", ["resource/" + package_name]),
+        ("share/" + package_name, ["package.xml"]),
     ],
-    install_requires=['setuptools'],
+    install_requires=["setuptools"],
     zip_safe=True,
-    maintainer='maayan',
-    maintainer_email='maayan.armony@gmail.com',
-    description='Speech recognition implemented using OpenAI Whisper',
-    license='MIT',
-    tests_require=['pytest'],
+    maintainer="maayan",
+    maintainer_email="maayan.armony@gmail.com",
+    description="Speech recognition implemented using OpenAI Whisper",
+    license="MIT",
+    tests_require=["pytest"],
     entry_points={
-        'console_scripts': [
-            'transcribe_microphone_server = lasr_speech_recognition_whisper.transcribe_microphone_server:main',
-            'transcribe_microphone = lasr_speech_recognition_whisper.transcribe_microphone:main',
-            'simple_transcribe_microphone = lasr_speech_recognition_whisper.simple_transcribe_microphone:main',
-            'list_microphones = scripts.list_microphones:main',
-            'microphone_tuning_test = scripts.microphone_tuning_test:main',
-            'test_microphones = scripts.test_microphones:main',
-            'test_speech_server = scripts.test_speech_server:main',
+        "console_scripts": [
"transcribe_microphone_server = lasr_speech_recognition_whisper.transcribe_microphone_server:main", + "transcribe_microphone = lasr_speech_recognition_whisper.transcribe_microphone:main", + "simple_transcribe_microphone = lasr_speech_recognition_whisper.simple_transcribe_microphone:main", + "list_microphones = scripts.list_microphones:main", + "microphone_tuning_test = scripts.microphone_tuning_test:main", + "test_microphones = scripts.test_microphones:main", + "test_speech_server = scripts.test_speech_server:main", ], }, ) diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py index 69327473..372e2647 100644 --- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py +++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py @@ -9,4 +9,4 @@ SpeechRecognitionToStdout, SpeechRecognitionToTopic, ) -from .cache import ModelCache \ No newline at end of file +from .cache import ModelCache diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py index 7a86f38e..259ffffa 100644 --- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py +++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py @@ -6,13 +6,13 @@ # Keep all loaded models in memory MODEL_CACHE = {} + class ModelCache(Node): def __init__(self): - super().__init__('lasr_speech_recognition_whisper_cache') + super().__init__("lasr_speech_recognition_whisper_cache") def load_model( - self, - name: str, device: str = "cpu", load_test_file: bool = False + self, name: str, device: str = "cpu", load_test_file: bool = False ) -> whisper.Whisper: """Loads a whisper model from disk, or from cache if it has already been loaded. @@ -34,8 +34,17 @@ def load_model( MODEL_CACHE[name] = whisper.load_model(name, device=device) self.get_logger().info(f"Sucessfully loaded model {name} on {device}") if load_test_file: - package_install = packages.get_package_prefix("lasr_speech_recognition_whisper") - package_root = os.path.abspath(os.path.join(package_install, os.pardir, os.pardir, "lasr_speech_recognition_whisper")) + package_install = packages.get_package_prefix( + "lasr_speech_recognition_whisper" + ) + package_root = os.path.abspath( + os.path.join( + package_install, + os.pardir, + os.pardir, + "lasr_speech_recognition_whisper", + ) + ) example_fp = os.path.join(package_root, "test.m4a") self.get_logger().info( "Running transcription on example file to ensure model is loaded..." 
diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py
index 9edbc313..d8c5fbea 100644
--- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py
+++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py
@@ -5,6 +5,7 @@
 from queue import Queue
 from abc import ABC, abstractmethod
 
+
 class AbstractPhraseCollector(ABC):
     """
     Supertype holding a queue of audio data representing a phrase
diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/source.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/source.py
index abcd0fd1..e405ca8c 100644
--- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/source.py
+++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/source.py
@@ -9,6 +9,7 @@
 
 # TODO rospy.wait_for_message()
 
+
 class AudioTopic(sr.AudioSource, Node):
     """
     Use a ROS topic as an AudioSource
@@ -21,7 +22,9 @@ def __init__(self, topic: str, chunk_size=1024) -> None:
         Node.__init__(self, "source")
         self._topic = topic
 
-        self.subscription = self.create_subscription(AudioInfo, f"{topic}/audio_info", self.callback, 10)
+        self.subscription = self.create_subscription(
+            AudioInfo, f"{topic}/audio_info", self.callback, 10
+        )
         # config: AudioInfo = rospy.wait_for_message(f"{topic}/audio_info", AudioInfo)
         self.config = None  # TODO test that this works
         if self.config is not None:
@@ -48,7 +51,9 @@ def __enter__(self):
             self.stream is None
         ), "This audio source is already inside a context manager"
         self.stream = BytesFIFO(1024 * 10)  # 10 kB buffer
-        self._sub = self.node.create_subscription(AudioData, f"{self._topic}/audio", self._read)
+        self._sub = self.node.create_subscription(
+            AudioData, f"{self._topic}/audio", self._read
+        )
         return self
 
     def __exit__(self, exc_type, exc_value, traceback):
@@ -57,7 +62,9 @@ def __exit__(self, exc_type, exc_value, traceback):
         """
 
         self.stream = None
-        self.destroy_subscription(self._sub) # TODO behaviour, was self._sub.unregister()
+        self.destroy_subscription(
+            self._sub
+        )  # TODO behaviour, was self._sub.unregister()
 
     def _read(self, msg: AudioData) -> None:
         """
diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/worker.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/worker.py
index 99847557..43eac780 100644
--- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/worker.py
+++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/worker.py
@@ -39,7 +39,7 @@ def __init__(
         maximum_phrase_length=timedelta(seconds=3),
         infer_partial=True,
     ) -> None:
-        Node.__init__(self, 'worker')
+        Node.__init__(self, "worker")
         self._collector = collector
         self._tmp_file = NamedTemporaryFile().name
         self._model = model
diff --git a/common/speech/lasr_speech_recognition_whisper/test/test_copyright.py b/common/speech/lasr_speech_recognition_whisper/test/test_copyright.py
index 97a39196..ceffe896 100644
--- a/common/speech/lasr_speech_recognition_whisper/test/test_copyright.py
+++ b/common/speech/lasr_speech_recognition_whisper/test/test_copyright.py
@@ -17,9 +17,11 @@
 
 
 # Remove the `skip` decorator once the source file(s) have a copyright header
-@pytest.mark.skip(reason='No copyright header has been placed in the generated source file.')
+@pytest.mark.skip(
+    reason="No copyright header has been placed in the generated source file."
+)
 @pytest.mark.copyright
 @pytest.mark.linter
 def test_copyright():
-    rc = main(argv=['.', 'test'])
-    assert rc == 0, 'Found errors'
+    rc = main(argv=[".", "test"])
+    assert rc == 0, "Found errors"
diff --git a/common/speech/lasr_speech_recognition_whisper/test/test_flake8.py b/common/speech/lasr_speech_recognition_whisper/test/test_flake8.py
index 27ee1078..ee79f31a 100644
--- a/common/speech/lasr_speech_recognition_whisper/test/test_flake8.py
+++ b/common/speech/lasr_speech_recognition_whisper/test/test_flake8.py
@@ -20,6 +20,6 @@
 @pytest.mark.linter
 def test_flake8():
     rc, errors = main_with_errors(argv=[])
-    assert rc == 0, \
-        'Found %d code style errors / warnings:\n' % len(errors) + \
-        '\n'.join(errors)
+    assert rc == 0, "Found %d code style errors / warnings:\n" % len(
+        errors
+    ) + "\n".join(errors)
diff --git a/common/speech/lasr_speech_recognition_whisper/test/test_pep257.py b/common/speech/lasr_speech_recognition_whisper/test/test_pep257.py
index b234a384..a2c3deb8 100644
--- a/common/speech/lasr_speech_recognition_whisper/test/test_pep257.py
+++ b/common/speech/lasr_speech_recognition_whisper/test/test_pep257.py
@@ -19,5 +19,5 @@
 @pytest.mark.linter
 @pytest.mark.pep257
 def test_pep257():
-    rc = main(argv=['.', 'test'])
-    assert rc == 0, 'Found code style errors / warnings'
+    rc = main(argv=[".", "test"])
+    assert rc == 0, "Found code style errors / warnings"
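
For reference, a minimal client for the simple transcription service touched above (/whisper/transcribe_audio). This is a sketch under the assumption that TranscribeAudio takes an empty request and returns a string field `phrase`, as suggested by `TranscribeAudio.Response(phrase=phrase)` in simple_transcribe_microphone.py; the node name `whisper_client` is made up for the example:

    import rclpy
    from lasr_speech_recognition_interfaces.srv import TranscribeAudio


    def main() -> None:
        rclpy.init()
        node = rclpy.create_node("whisper_client")
        client = node.create_client(TranscribeAudio, "/whisper/transcribe_audio")
        client.wait_for_service()
        # The server listens to the microphone and runs Whisper before replying,
        # so the future completes only once a phrase has been transcribed.
        future = client.call_async(TranscribeAudio.Request())
        rclpy.spin_until_future_complete(node, future)
        node.get_logger().info(f"Transcription: {future.result().phrase}")
        node.destroy_node()
        rclpy.shutdown()


    if __name__ == "__main__":
        main()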