diff --git a/common/speech/lasr_speech_recognition_interfaces/README.md b/common/speech/lasr_speech_recognition_interfaces/README.md
index c96279cb..8e7aab96 100644
--- a/common/speech/lasr_speech_recognition_interfaces/README.md
+++ b/common/speech/lasr_speech_recognition_interfaces/README.md
@@ -3,17 +3,18 @@
Common messages used for speech recognition
This package is maintained by:
+
- [Maayan Armony](mailto:maayan.armony@gmail.com)
- [Paul Makles](mailto:me@insrt.uk) (ROS1)
## Prerequisites
This package depends on the following ROS packages:
+
- colcon (buildtool)
- message_generation (build)
- message_runtime (exec)
-
## Usage
Ask the package maintainer to write a `doc/USAGE.md` for their package!
@@ -36,11 +37,10 @@ This package has no launch files.
#### `Transcription`
-| Field | Type | Description |
-|:-:|:-:|---|
-| phrase | string | |
-| finished | bool | |
-
+| Field | Type | Description |
+|:--------:|:------:|-------------|
+| phrase | string | |
+| finished | bool | |
### Services
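For reference, a minimal sketch (not part of this diff) of how a node might consume the `Transcription` message documented above; the `transcription` topic name is taken from `transcribe_microphone.py` later in this diff.

```python
# Hypothetical subscriber for the Transcription message documented above.
import rclpy
from rclpy.node import Node
from lasr_speech_recognition_interfaces.msg import Transcription


class TranscriptionListener(Node):
    def __init__(self):
        super().__init__("transcription_listener")
        # "transcription" matches the topic published by the Whisper worker below.
        self.create_subscription(Transcription, "transcription", self.on_msg, 10)

    def on_msg(self, msg: Transcription):
        # phrase holds the transcribed text; finished flags a completed phrase.
        self.get_logger().info(f"{msg.phrase!r} (finished={msg.finished})")


def main():
    rclpy.init()
    rclpy.spin(TranscriptionListener())


if __name__ == "__main__":
    main()
```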
diff --git a/common/speech/lasr_speech_recognition_interfaces/package.xml b/common/speech/lasr_speech_recognition_interfaces/package.xml
index b15638eb..fd72011b 100644
--- a/common/speech/lasr_speech_recognition_interfaces/package.xml
+++ b/common/speech/lasr_speech_recognition_interfaces/package.xml
@@ -1,23 +1,23 @@
-    <name>lasr_speech_recognition_interfaces</name>
-    <version>0.0.0</version>
-    <description>Common messages used for speech recognition</description>
-    <maintainer email="maayan.armony@gmail.com">maayan</maintainer>
-    <license>MIT</license>
+  <name>lasr_speech_recognition_interfaces</name>
+  <version>0.0.0</version>
+  <description>Common messages used for speech recognition</description>
+  <maintainer email="maayan.armony@gmail.com">maayan</maintainer>
+  <license>MIT</license>
-    <buildtool_depend>ament_cmake</buildtool_depend>
-
-    <buildtool_depend>rosidl_default_generators</buildtool_depend>
-    <depend>action_msgs</depend>
-    <exec_depend>rosidl_default_runtime</exec_depend>
-    <member_of_group>rosidl_interface_packages</member_of_group>
+  <buildtool_depend>ament_cmake</buildtool_depend>
+
+  <buildtool_depend>rosidl_default_generators</buildtool_depend>
+  <depend>action_msgs</depend>
+  <exec_depend>rosidl_default_runtime</exec_depend>
+  <member_of_group>rosidl_interface_packages</member_of_group>
-    <test_depend>ament_lint_auto</test_depend>
-    <test_depend>ament_lint_common</test_depend>
+  <test_depend>ament_lint_auto</test_depend>
+  <test_depend>ament_lint_common</test_depend>
-    <export>
-        <build_type>ament_cmake</build_type>
-    </export>
+  <export>
+    <build_type>ament_cmake</build_type>
+  </export>
diff --git a/common/speech/lasr_speech_recognition_whisper/README.md b/common/speech/lasr_speech_recognition_whisper/README.md
index 9954290e..c9f58557 100644
--- a/common/speech/lasr_speech_recognition_whisper/README.md
+++ b/common/speech/lasr_speech_recognition_whisper/README.md
@@ -3,18 +3,21 @@
Speech recognition implemented using OpenAI Whisper
This package is maintained by:
+
- [Maayan Armony](mailto:maayan.armony@gmail.com)
- [Paul Makles](mailto:me@insrt.uk) (ROS1)
## Prerequisites
This package depends on the following ROS packages:
+
- colcon (buildtool)
- lasr_speech_recognition_interfaces
This package requires Python 3.10 to be present.
This package has 48 Python dependencies:
+
- [SpeechRecognition](https://pypi.org/project/SpeechRecognition)==3.10.0
- [openai-whisper](https://pypi.org/project/openai-whisper)==20230314
- [PyAudio](https://pypi.org/project/PyAudio)==0.2.13
@@ -64,15 +67,18 @@ This package does speech recognition in three parts:
- Adjusting for background noise
- We wait for a set period of time monitoring the audio stream to determine what we should ignore when collecting voice data.
+ We wait for a set period of time monitoring the audio stream to determine what we should ignore when collecting voice
+ data.
- Collecting appropriate voice data for phrases
- We use the `SpeechRecognition` package to monitor the input audio stream and determine when a person is actually speaking with enough energy that we would consider them to be speaking to the robot.
+ We use the `SpeechRecognition` package to monitor the input audio stream and determine when a person is actually
+ speaking with enough energy that we would consider them to be speaking to the robot.
- Running inference on phrases
- We continuously combine segments of the spoken phrase to form a sample until a certain timeout or threshold after which the phrase ends. This sample is sent to a local OpenAI Whisper model to transcribe.
+ We continuously combine segments of the spoken phrase to form a sample until a certain timeout or threshold after
+ which the phrase ends. This sample is sent to a local OpenAI Whisper model to transcribe.
The package can input from the following sources:
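The three phases above map closely onto the `SpeechRecognition` and `openai-whisper` APIs this package wraps; a rough standalone sketch (illustrative only, not taken verbatim from the package):

```python
# Illustrative end-to-end sketch of the three phases described above.
import numpy as np
import speech_recognition as sr
import whisper

r = sr.Recognizer()
model = whisper.load_model("medium.en")

with sr.Microphone() as source:
    # 1. Adjust for background noise by sampling the ambient energy level.
    r.adjust_for_ambient_noise(source, duration=1.0)
    # 2. Collect a phrase: listen() returns once the speaker's energy drops off.
    audio = r.listen(source, timeout=5.0)

# 3. Run inference: convert 16-bit PCM into the [-1, 1) float range Whisper expects,
#    mirroring the conversion used by the nodes later in this diff.
pcm = np.frombuffer(audio.get_wav_data(), dtype=np.int16)
print(model.transcribe(pcm.astype(np.float32) / 32768.0, fp16=False)["text"])
```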
diff --git a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/simple_transcribe_microphone.py b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/simple_transcribe_microphone.py
index 960bd599..7b3b1f8a 100644
--- a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/simple_transcribe_microphone.py
+++ b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/simple_transcribe_microphone.py
@@ -11,28 +11,31 @@
import sounddevice # needed to remove ALSA error messages
from lasr_speech_recognition_interfaces.srv import TranscribeAudio
-from src import ModelCache # type: ignore
+from src import ModelCache # type: ignore
-MODEL = "medium.en" # Whisper model
-TIMEOUT = 5.0 # Timeout for listening for the start of a phrase
-PHRASE_TIME_LIMIT = None # Timeout for listening for the end of a phrase
+MODEL = "medium.en" # Whisper model
+TIMEOUT = 5.0 # Timeout for listening for the start of a phrase
+PHRASE_TIME_LIMIT = None # Timeout for listening for the end of a phrase
-WHISPER_CACHE = os.path.join(str(Path.home()), '.cache', 'whisper')
+WHISPER_CACHE = os.path.join(str(Path.home()), ".cache", "whisper")
os.makedirs(WHISPER_CACHE, exist_ok=True)
os.environ["TIKTOKEN_CACHE_DIR"] = WHISPER_CACHE
if len(sys.argv) < 3:
- print('Usage:')
- print('ros2 run lasr_speech_recognition transcribe_microphone by-index <device_index>')
- print('ros2 run lasr_speech_recognition transcribe_microphone by-name <device_name>')
+ print("Usage:")
+ print(
+ "ros2 run lasr_speech_recognition transcribe_microphone by-index <device_index>"
+ )
+ print("ros2 run lasr_speech_recognition transcribe_microphone by-name <device_name>")
exit(1)
else:
matcher = sys.argv[1]
device_index = None
- if matcher == 'by-index':
+ if matcher == "by-index":
device_index = int(sys.argv[2])
- elif matcher == 'by-name':
+ elif matcher == "by-name":
import speech_recognition as sr
+
microphones = enumerate(sr.Microphone.list_microphone_names())
target_name = sys.argv[2]
@@ -40,16 +43,16 @@
if target_name in name:
device_index = index
break
-
+
if device_index is None:
- print('Could not find device!')
+ print("Could not find device!")
exit(1)
else:
- print('Invalid matcher')
+ print("Invalid matcher")
exit(1)
rclpy.init(args=sys.argv)
-node = rclpy.create_node('transcribe_mic')
+node = rclpy.create_node("transcribe_mic")
device = "cuda" if torch.cuda.is_available() else "cpu"
model_cache = ModelCache()
@@ -57,9 +60,15 @@
# try to run inference on the example file
package_install = packages.get_package_prefix("lasr_speech_recognition_whisper")
-package_root = os.path.abspath(os.path.join(package_install, os.pardir, os.pardir, "lasr_speech_recognition_whisper"))
+package_root = os.path.abspath(
+ os.path.join(
+ package_install, os.pardir, os.pardir, "lasr_speech_recognition_whisper"
+ )
+)
example_fp = os.path.join(package_root, "test.m4a")
-node.get_logger().info("Running transcription on example file to ensure model is loaded...")
+node.get_logger().info(
+ "Running transcription on example file to ensure model is loaded..."
+)
transcription = model.transcribe(example_fp, fp16=torch.cuda.is_available())
node.get_logger().info(str(transcription))
@@ -68,16 +77,25 @@
with microphone as source:
r.adjust_for_ambient_noise(source)
+
def handle_transcribe_audio(_):
with microphone as source:
- wav_data = r.listen(source, timeout=TIMEOUT, phrase_time_limit=PHRASE_TIME_LIMIT).get_wav_data()
- float_data = np.frombuffer(wav_data, dtype=np.int16).astype(np.float32, order='C') / 32768.0
+ wav_data = r.listen(
+ source, timeout=TIMEOUT, phrase_time_limit=PHRASE_TIME_LIMIT
+ ).get_wav_data()
+ float_data = (
+ np.frombuffer(wav_data, dtype=np.int16).astype(np.float32, order="C")
+ / 32768.0
+ )
phrase = model.transcribe(float_data, fp16=device == "cuda")["text"]
return TranscribeAudio.Response(phrase=phrase)
-node.create_service(TranscribeAudio, '/whisper/transcribe_audio', handle_transcribe_audio)
+
+node.create_service(
+ TranscribeAudio, "/whisper/transcribe_audio", handle_transcribe_audio
+)
node.get_logger().info("Whisper service ready")
-rclpy.spin(node)
\ No newline at end of file
+rclpy.spin(node)
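For completeness, a hedged sketch of a client for the `/whisper/transcribe_audio` service registered above; names come from this file, and the empty request mirrors the unused request argument in `handle_transcribe_audio`.

```python
# Minimal client sketch for the /whisper/transcribe_audio service above.
import rclpy
from lasr_speech_recognition_interfaces.srv import TranscribeAudio


def main():
    rclpy.init()
    node = rclpy.create_node("transcribe_audio_client")
    client = node.create_client(TranscribeAudio, "/whisper/transcribe_audio")
    client.wait_for_service()
    # The server ignores the request body, so an empty request suffices.
    future = client.call_async(TranscribeAudio.Request())
    rclpy.spin_until_future_complete(node, future)
    node.get_logger().info(f"Heard: {future.result().phrase}")
    rclpy.shutdown()


if __name__ == "__main__":
    main()
```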
diff --git a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone.py b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone.py
index f8553b0b..3225072c 100644
--- a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone.py
+++ b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone.py
@@ -10,36 +10,42 @@
from std_srvs.srv import Empty
from src import SpeechRecognitionToTopic, MicrophonePhraseCollector, ModelCache
-WHISPER_CACHE = os.path.join(str(Path.home()), '.cache', 'whisper')
+WHISPER_CACHE = os.path.join(str(Path.home()), ".cache", "whisper")
os.makedirs(WHISPER_CACHE, exist_ok=True)
os.environ["TIKTOKEN_CACHE_DIR"] = WHISPER_CACHE
+
class TranscribeMicrophone(Node):
def __init__(self):
- Node.__init__(self, 'transcribe_microphone')
+ Node.__init__(self, "transcribe_microphone")
self.worker = None
self.collector = None
- self.create_service(Empty, '/whisper/adjust_for_noise', self.adjust_for_noise)
- self.create_service(Empty, '/whisper/start_listening', self.start_listening)
- self.create_service(Empty, '/whisper/stop_listening', self.stop_listening)
+ self.create_service(Empty, "/whisper/adjust_for_noise", self.adjust_for_noise)
+ self.create_service(Empty, "/whisper/start_listening", self.start_listening)
+ self.create_service(Empty, "/whisper/stop_listening", self.stop_listening)
self.get_logger().info("Starting the Whisper worker!")
self.run_transcription()
def run_transcription(self):
if len(sys.argv) < 3:
- print('Usage:')
- print('rosrun lasr_speech_recognition transcribe_microphone by-index <device_index>')
- print('rosrun lasr_speech_recognition transcribe_microphone by-name <device_name>')
+ print("Usage:")
+ print(
+ "rosrun lasr_speech_recognition transcribe_microphone by-index <device_index>"
+ )
+ print(
+ "rosrun lasr_speech_recognition transcribe_microphone by-name <device_name>"
+ )
exit(1)
else:
matcher = sys.argv[1]
device_index = None
- if matcher == 'by-index':
+ if matcher == "by-index":
device_index = int(sys.argv[2])
- elif matcher == 'by-name':
+ elif matcher == "by-name":
import speech_recognition as sr
+
microphones = enumerate(sr.Microphone.list_microphone_names())
target_name = sys.argv[2]
@@ -49,13 +55,12 @@ def run_transcription(self):
break
if device_index is None:
- print('Could not find device!')
+ print("Could not find device!")
exit(1)
else:
- print('Invalid matcher')
+ print("Invalid matcher")
exit(1)
-
self.collector = MicrophonePhraseCollector(device_index=device_index)
self.collector.adjust_for_noise()
@@ -64,14 +69,24 @@ def run_transcription(self):
# try to run inference on the example file
package_install = packages.get_package_prefix("lasr_speech_recognition_whisper")
- package_root = os.path.abspath(os.path.join(package_install, os.pardir, os.pardir, "lasr_speech_recognition_whisper"))
+ package_root = os.path.abspath(
+ os.path.join(
+ package_install, os.pardir, os.pardir, "lasr_speech_recognition_whisper"
+ )
+ )
example_fp = os.path.join(package_root, "test.m4a")
- self.get_logger().info("Running transcription on example file to ensure model is loaded...")
- model_transcription = model.transcribe(example_fp, fp16=torch.cuda.is_available())
+ self.get_logger().info(
+ "Running transcription on example file to ensure model is loaded..."
+ )
+ model_transcription = model.transcribe(
+ example_fp, fp16=torch.cuda.is_available()
+ )
self.get_logger().info(str(model_transcription))
- self.worker = SpeechRecognitionToTopic(self.collector, model, "transcription", infer_partial = False)
+ self.worker = SpeechRecognitionToTopic(
+ self.collector, model, "transcription", infer_partial=False
+ )
def adjust_for_noise(self, request, response):
self.collector.adjust_for_noise()
@@ -85,11 +100,12 @@ def stop_listening(self, request, response):
self.worker.stop()
return response
+
def main(args=None):
rclpy.init(args=args)
transcribe_microphone = TranscribeMicrophone()
rclpy.spin(transcribe_microphone)
-if __name__ == '__main__':
- main()
\ No newline at end of file
+if __name__ == "__main__":
+ main()
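A sketch of driving the three `Empty` services this node registers; the service names are from `__init__` above, while the control flow is illustrative.

```python
# Illustrative controller for the Empty services registered by this node.
import rclpy
from std_srvs.srv import Empty


def call(node, name: str) -> None:
    client = node.create_client(Empty, name)
    client.wait_for_service()
    future = client.call_async(Empty.Request())
    rclpy.spin_until_future_complete(node, future)


def main():
    rclpy.init()
    node = rclpy.create_node("whisper_control")
    call(node, "/whisper/adjust_for_noise")  # re-sample ambient noise first
    call(node, "/whisper/start_listening")   # worker begins publishing transcriptions
    # ... consume the "transcription" topic for as long as needed ...
    call(node, "/whisper/stop_listening")
    rclpy.shutdown()


if __name__ == "__main__":
    main()
```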
diff --git a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py
index e7c307f5..8adf3bd8 100644
--- a/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py
+++ b/common/speech/lasr_speech_recognition_whisper/lasr_speech_recognition_whisper/transcribe_microphone_server.py
@@ -18,10 +18,11 @@
from lasr_speech_recognition_interfaces.action import TranscribeSpeech # type: ignore
from rclpy.executors import ExternalShutdownException
from std_msgs.msg import String # type: ignore
-from src import ModelCache # type: ignore
+from src import ModelCache # type: ignore
# TODO: argparse -> ROS2 params, test behaviour of preemption
+
@dataclass
class speech_model_params:
"""Class for storing speech recognition model parameters.
@@ -58,9 +59,9 @@ class TranscribeSpeechAction(Node):
_result = TranscribeSpeech.Result()
def __init__(
- self,
- action_name: str,
- model_params: speech_model_params,
+ self,
+ action_name: str,
+ model_params: speech_model_params,
) -> None:
"""Starts an action server for transcribing speech.
@@ -126,9 +127,9 @@ def _configure_microphone(self) -> sr.Microphone:
)
def _configure_recogniser(
- self,
- energy_threshold: Optional[float] = None,
- pause_threshold: Optional[float] = None,
+ self,
+ energy_threshold: Optional[float] = None,
+ pause_threshold: Optional[float] = None,
) -> sr.Recognizer:
"""Configures the speech recogniser object.
@@ -212,8 +213,8 @@ async def execute_cb(self, goal_handle) -> None:
).get_wav_data()
# Magic number 32768.0 is the maximum value of a 16-bit signed integer
float_data = (
- np.frombuffer(wav_data, dtype=np.int16).astype(np.float32, order="C")
- / 32768.0
+ np.frombuffer(wav_data, dtype=np.int16).astype(np.float32, order="C")
+ / 32768.0
)
if goal_handle.is_cancel_requested():
@@ -265,7 +266,6 @@ def parse_args() -> dict:
# port = node.declare_parameter('port', '/dev/ttyUSB0').value
# assert isinstance(port, str), 'port parameter must be a str'
-
parser.add_argument(
"--action_name",
type=str,
@@ -372,6 +372,7 @@ def configure_whisper_cache() -> None:
# Environmental variable required to run whisper locally
os.environ["TIKTOKEN_CACHE_DIR"] = whisper_cache
+
def main(args=None):
rclpy.init(args=args)
@@ -383,4 +384,4 @@ def main(args=None):
try:
rclpy.spin(server)
except (KeyboardInterrupt, ExternalShutdownException):
- pass
\ No newline at end of file
+ pass
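The reflowed conversion above divides by 32768.0, the magnitude of the most negative 16-bit signed value. A standalone illustration (the synthetic input is an assumption standing in for real microphone data):

```python
# Standalone illustration of the int16 -> float32 normalisation reflowed above.
import numpy as np

# Synthetic stand-in for r.listen(...).get_wav_data(): one second of a sine tone.
wav_data = (np.sin(np.linspace(0, 2 * np.pi, 16000)) * 32767).astype(np.int16).tobytes()

# 32768.0 = 2**15, so every sample lands in [-1.0, 1.0), the range Whisper expects.
float_data = np.frombuffer(wav_data, dtype=np.int16).astype(np.float32, order="C") / 32768.0
assert -1.0 <= float_data.min() and float_data.max() < 1.0
```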
diff --git a/common/speech/lasr_speech_recognition_whisper/package.xml b/common/speech/lasr_speech_recognition_whisper/package.xml
index 825aae03..1cac4761 100644
--- a/common/speech/lasr_speech_recognition_whisper/package.xml
+++ b/common/speech/lasr_speech_recognition_whisper/package.xml
@@ -1,30 +1,30 @@
-    <name>lasr_speech_recognition_whisper</name>
-    <version>0.0.0</version>
-    <description>Speech recognition implemented using OpenAI Whisper</description>
-    <maintainer email="maayan.armony@gmail.com">maayan</maintainer>
-    <license>MIT</license>
+  <name>lasr_speech_recognition_whisper</name>
+  <version>0.0.0</version>
+  <description>Speech recognition implemented using OpenAI Whisper</description>
+  <maintainer email="maayan.armony@gmail.com">maayan</maintainer>
+  <license>MIT</license>
-    <test_depend>ament_copyright</test_depend>
-    <test_depend>ament_flake8</test_depend>
-    <test_depend>ament_pep257</test_depend>
-    <test_depend>python3-pytest</test_depend>
+  <test_depend>ament_copyright</test_depend>
+  <test_depend>ament_flake8</test_depend>
+  <test_depend>ament_pep257</test_depend>
+  <test_depend>python3-pytest</test_depend>
-
-
-    <depend>lasr_speech_recognition_interfaces</depend>
-    <build_depend>actionlib</build_depend>
-    <build_depend>actionlib_msgs</build_depend>
-    <exec_depend>actionlib</exec_depend>
-    <exec_depend>actionlib_msgs</exec_depend>
+
+
+  <depend>lasr_speech_recognition_interfaces</depend>
+  <build_depend>actionlib</build_depend>
+  <build_depend>actionlib_msgs</build_depend>
+  <exec_depend>actionlib</exec_depend>
+  <exec_depend>actionlib_msgs</exec_depend>
-    <export>
-        <build_type>ament_python</build_type>
-        <pip_requirements>requirements.txt</pip_requirements>
-    </export>
+  <export>
+    <build_type>ament_python</build_type>
+    <pip_requirements>requirements.txt</pip_requirements>
+  </export>
diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/list_microphones.py b/common/speech/lasr_speech_recognition_whisper/scripts/list_microphones.py
index 16ca35d2..a3ce2190 100755
--- a/common/speech/lasr_speech_recognition_whisper/scripts/list_microphones.py
+++ b/common/speech/lasr_speech_recognition_whisper/scripts/list_microphones.py
@@ -14,5 +14,6 @@ def main():
# print("Available microphone devices (sounddevice):")
# print(sounddevice.query_devices())
-if __name__ == '__main__':
- main()
\ No newline at end of file
+
+if __name__ == "__main__":
+ main()
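The `by-index`/`by-name` matchers elsewhere in this diff rely on the same enumeration this script prints; the core of it is just:

```python
# Microphone enumeration, as used by the by-name matcher elsewhere in this diff.
import speech_recognition as sr

for index, name in enumerate(sr.Microphone.list_microphone_names()):
    print(f"[{index}] {name}")
```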
diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py b/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py
index c7739646..026ab287 100755
--- a/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py
+++ b/common/speech/lasr_speech_recognition_whisper/scripts/microphone_tuning_test.py
@@ -5,16 +5,19 @@
import numpy as np
from pathlib import Path
import speech_recognition as sr
-from src import ModelCache # type: ignore
+from src import ModelCache # type: ignore
import sounddevice # needed to remove ALSA error messages
from typing import Dict
import rclpy
# TODO argparse -> ROS params
+
def parse_args() -> Dict:
parser = argparse.ArgumentParser()
- parser.add_argument("--device_index", help="Microphone index", type=int, default=None)
+ parser.add_argument(
+ "--device_index", help="Microphone index", type=int, default=None
+ )
return vars(parser.parse_args())
@@ -67,6 +70,7 @@ def main(args=None):
threshold += 100
recognizer.energy_threshold = threshold
+
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
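The tuning loop above raises `energy_threshold` in steps of 100; a hedged sketch of that idea (the stop condition and the 4000 ceiling are assumptions, not the script's exact logic):

```python
# Sketch of threshold tuning: raise energy_threshold until ambient noise no
# longer triggers listen(). Limits and stop condition are illustrative only.
import speech_recognition as sr

recognizer = sr.Recognizer()
recognizer.dynamic_energy_threshold = False  # hold the threshold fixed while tuning
threshold = 100

with sr.Microphone() as source:
    while threshold < 4000:
        recognizer.energy_threshold = threshold
        try:
            recognizer.listen(source, timeout=2.0)
            print(f"audio still triggers at threshold {threshold}")
        except sr.WaitTimeoutError:
            print(f"silence at threshold {threshold}; use this value")
            break
        threshold += 100
```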
diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/test_microphones.py b/common/speech/lasr_speech_recognition_whisper/scripts/test_microphones.py
index e0c94c23..d14144e2 100755
--- a/common/speech/lasr_speech_recognition_whisper/scripts/test_microphones.py
+++ b/common/speech/lasr_speech_recognition_whisper/scripts/test_microphones.py
@@ -8,6 +8,7 @@
# TODO argparse -> ROS params
+
def parse_args() -> dict:
"""Parse command line arguments into a dictionary.
@@ -16,7 +17,9 @@ def parse_args() -> dict:
"""
parser = argparse.ArgumentParser(description="Test microphones")
- parser.add_argument("-m", "--microphone", type=int, help="Microphone index", default=None)
+ parser.add_argument(
+ "-m", "--microphone", type=int, help="Microphone index", default=None
+ )
parser.add_argument(
"-o", "--output_dir", type=str, help="Directory to save audio files"
)
@@ -64,5 +67,6 @@ def main(args: dict = None) -> None:
rclpy.shutdown()
+
if __name__ == "__main__":
main()
diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py b/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py
index d0670845..2448e73e 100755
--- a/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py
+++ b/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py
@@ -7,6 +7,7 @@
# https://docs.ros2.org/latest/api/rclpy/api/actions.html
+
class TestSpeechServerClient(Node):
def __init__(self):
Node.__init__(self, "listen_action_client")
@@ -20,8 +21,12 @@ def send_goal(self, goal):
self.client.wait_for_server()
self.get_logger().info("Server activated, sending goal...")
- self.goal_future = self.client.send_goal_async(goal, feedback_callback=self.feedback_cb) # Returns a Future instance when the goal request has been accepted or rejected.
- self.goal_future.add_done_callback(self.response_cb) # When received get response
+ self.goal_future = self.client.send_goal_async(
+ goal, feedback_callback=self.feedback_cb
+ ) # Returns a Future instance when the goal request has been accepted or rejected.
+ self.goal_future.add_done_callback(
+ self.response_cb
+ ) # When received get response
def feedback_cb(self, msg):
self.get_logger().info(f"Received feedback: {msg.feedback}")
@@ -33,13 +38,16 @@ def response_cb(self, future):
return
self.get_logger().info("Goal was accepted")
- self.result_future = handle.get_result_async() # Not using get_result() in cb, as can cause deadlock according to docs
+ self.result_future = (
+ handle.get_result_async()
+ ) # Not using get_result() in cb, as can cause deadlock according to docs
self.result_future.add_done_callback(self.result_cb)
def result_cb(self, future):
result = future.result().result
self.get_logger().info(f"Transcribed Speech: {result.sequence}")
+
def main(args=None):
rclpy.init(args=args)
while rclpy.ok():
@@ -54,5 +62,6 @@ def main(args=None):
client.destroy_node()
rclpy.shutdown()
+
if __name__ == "__main__":
main()
diff --git a/common/speech/lasr_speech_recognition_whisper/setup.cfg b/common/speech/lasr_speech_recognition_whisper/setup.cfg
index 5ec86217..1f6a5440 100644
--- a/common/speech/lasr_speech_recognition_whisper/setup.cfg
+++ b/common/speech/lasr_speech_recognition_whisper/setup.cfg
@@ -1,4 +1,4 @@
[develop]
-script_dir=$base/lib/lasr_speech_recognition_whisper
+script_dir = $base/lib/lasr_speech_recognition_whisper
[install]
-install_scripts=$base/lib/lasr_speech_recognition_whisper
+install_scripts = $base/lib/lasr_speech_recognition_whisper
diff --git a/common/speech/lasr_speech_recognition_whisper/setup.py b/common/speech/lasr_speech_recognition_whisper/setup.py
index 3fbac464..c6a80148 100644
--- a/common/speech/lasr_speech_recognition_whisper/setup.py
+++ b/common/speech/lasr_speech_recognition_whisper/setup.py
@@ -1,11 +1,11 @@
from setuptools import find_packages, setup
-package_name = 'lasr_speech_recognition_whisper'
+package_name = "lasr_speech_recognition_whisper"
setup(
name=package_name,
- version='0.0.0',
- packages=find_packages(exclude=['test']),
+ version="0.0.0",
+ packages=find_packages(exclude=["test"]),
# packages=[package_name, f"{package_name}.lasr_speech_recognition_whisper", f"{package_name}.src"],
# package_dir={
# '': '.',
@@ -14,26 +14,25 @@
# f"{package_name}.src": os.path.join(package_name, 'src'),
# },
data_files=[
- ('share/ament_index/resource_index/packages',
- ['resource/' + package_name]),
- ('share/' + package_name, ['package.xml']),
+ ("share/ament_index/resource_index/packages", ["resource/" + package_name]),
+ ("share/" + package_name, ["package.xml"]),
],
- install_requires=['setuptools'],
+ install_requires=["setuptools"],
zip_safe=True,
- maintainer='maayan',
- maintainer_email='maayan.armony@gmail.com',
- description='Speech recognition implemented using OpenAI Whisper',
- license='MIT',
- tests_require=['pytest'],
+ maintainer="maayan",
+ maintainer_email="maayan.armony@gmail.com",
+ description="Speech recognition implemented using OpenAI Whisper",
+ license="MIT",
+ tests_require=["pytest"],
entry_points={
- 'console_scripts': [
- 'transcribe_microphone_server = lasr_speech_recognition_whisper.transcribe_microphone_server:main',
- 'transcribe_microphone = lasr_speech_recognition_whisper.transcribe_microphone:main',
- 'simple_transcribe_microphone = lasr_speech_recognition_whisper.simple_transcribe_microphone:main',
- 'list_microphones = scripts.list_microphones:main',
- 'microphone_tuning_test = scripts.microphone_tuning_test:main',
- 'test_microphones = scripts.test_microphones:main',
- 'test_speech_server = scripts.test_speech_server:main',
+ "console_scripts": [
+ "transcribe_microphone_server = lasr_speech_recognition_whisper.transcribe_microphone_server:main",
+ "transcribe_microphone = lasr_speech_recognition_whisper.transcribe_microphone:main",
+ "simple_transcribe_microphone = lasr_speech_recognition_whisper.simple_transcribe_microphone:main",
+ "list_microphones = scripts.list_microphones:main",
+ "microphone_tuning_test = scripts.microphone_tuning_test:main",
+ "test_microphones = scripts.test_microphones:main",
+ "test_speech_server = scripts.test_speech_server:main",
],
},
)
diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py
index 69327473..372e2647 100644
--- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py
+++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/__init__.py
@@ -9,4 +9,4 @@
SpeechRecognitionToStdout,
SpeechRecognitionToTopic,
)
-from .cache import ModelCache
\ No newline at end of file
+from .cache import ModelCache
diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py
index 7a86f38e..259ffffa 100644
--- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py
+++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py
@@ -6,13 +6,13 @@
# Keep all loaded models in memory
MODEL_CACHE = {}
+
class ModelCache(Node):
def __init__(self):
- super().__init__('lasr_speech_recognition_whisper_cache')
+ super().__init__("lasr_speech_recognition_whisper_cache")
def load_model(
- self,
- name: str, device: str = "cpu", load_test_file: bool = False
+ self, name: str, device: str = "cpu", load_test_file: bool = False
) -> whisper.Whisper:
"""Loads a whisper model from disk, or from cache if it has already been loaded.
@@ -34,8 +34,17 @@ def load_model(
MODEL_CACHE[name] = whisper.load_model(name, device=device)
self.get_logger().info(f"Successfully loaded model {name} on {device}")
if load_test_file:
- package_install = packages.get_package_prefix("lasr_speech_recognition_whisper")
- package_root = os.path.abspath(os.path.join(package_install, os.pardir, os.pardir, "lasr_speech_recognition_whisper"))
+ package_install = packages.get_package_prefix(
+ "lasr_speech_recognition_whisper"
+ )
+ package_root = os.path.abspath(
+ os.path.join(
+ package_install,
+ os.pardir,
+ os.pardir,
+ "lasr_speech_recognition_whisper",
+ )
+ )
example_fp = os.path.join(package_root, "test.m4a")
self.get_logger().info(
"Running transcription on example file to ensure model is loaded..."
diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py
index 9edbc313..d8c5fbea 100644
--- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py
+++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/collector.py
@@ -5,6 +5,7 @@
from queue import Queue
from abc import ABC, abstractmethod
+
class AbstractPhraseCollector(ABC):
"""
Supertype holding a queue of audio data representing a phrase
diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/source.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/source.py
index abcd0fd1..e405ca8c 100644
--- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/source.py
+++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/source.py
@@ -9,6 +9,7 @@
# TODO rospy.wait_for_message()
+
class AudioTopic(sr.AudioSource, Node):
"""
Use a ROS topic as an AudioSource
@@ -21,7 +22,9 @@ def __init__(self, topic: str, chunk_size=1024) -> None:
Node.__init__(self, "source")
self._topic = topic
- self.subscription = self.create_subscription(AudioInfo, f"{topic}/audio_info", self.callback, 10)
+ self.subscription = self.create_subscription(
+ AudioInfo, f"{topic}/audio_info", self.callback, 10
+ )
# config: AudioInfo = rospy.wait_for_message(f"{topic}/audio_info", AudioInfo)
self.config = None # TODO test that this works
if self.config is not None:
@@ -48,7 +51,9 @@ def __enter__(self):
self.stream is None
), "This audio source is already inside a context manager"
self.stream = BytesFIFO(1024 * 10) # 10 kB buffer
- self._sub = self.node.create_subscription(AudioData, f"{self._topic}/audio", self._read)
+ self._sub = self.node.create_subscription(
+ AudioData, f"{self._topic}/audio", self._read
+ )
return self
def __exit__(self, exc_type, exc_value, traceback):
@@ -57,7 +62,9 @@ def __exit__(self, exc_type, exc_value, traceback):
"""
self.stream = None
- self.destroy_subscription(self._sub) # TODO behaviour, was self._sub.unregister()
+ self.destroy_subscription(
+ self._sub
+ ) # TODO behaviour, was self._sub.unregister()
def _read(self, msg: AudioData) -> None:
"""
diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/worker.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/worker.py
index 99847557..43eac780 100644
--- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/worker.py
+++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/worker.py
@@ -39,7 +39,7 @@ def __init__(
maximum_phrase_length=timedelta(seconds=3),
infer_partial=True,
) -> None:
- Node.__init__(self, 'worker')
+ Node.__init__(self, "worker")
self._collector = collector
self._tmp_file = NamedTemporaryFile().name
self._model = model
diff --git a/common/speech/lasr_speech_recognition_whisper/test/test_copyright.py b/common/speech/lasr_speech_recognition_whisper/test/test_copyright.py
index 97a39196..ceffe896 100644
--- a/common/speech/lasr_speech_recognition_whisper/test/test_copyright.py
+++ b/common/speech/lasr_speech_recognition_whisper/test/test_copyright.py
@@ -17,9 +17,11 @@
# Remove the `skip` decorator once the source file(s) have a copyright header
-@pytest.mark.skip(reason='No copyright header has been placed in the generated source file.')
+@pytest.mark.skip(
+ reason="No copyright header has been placed in the generated source file."
+)
@pytest.mark.copyright
@pytest.mark.linter
def test_copyright():
- rc = main(argv=['.', 'test'])
- assert rc == 0, 'Found errors'
+ rc = main(argv=[".", "test"])
+ assert rc == 0, "Found errors"
diff --git a/common/speech/lasr_speech_recognition_whisper/test/test_flake8.py b/common/speech/lasr_speech_recognition_whisper/test/test_flake8.py
index 27ee1078..ee79f31a 100644
--- a/common/speech/lasr_speech_recognition_whisper/test/test_flake8.py
+++ b/common/speech/lasr_speech_recognition_whisper/test/test_flake8.py
@@ -20,6 +20,6 @@
@pytest.mark.linter
def test_flake8():
rc, errors = main_with_errors(argv=[])
- assert rc == 0, \
- 'Found %d code style errors / warnings:\n' % len(errors) + \
- '\n'.join(errors)
+ assert rc == 0, "Found %d code style errors / warnings:\n" % len(
+ errors
+ ) + "\n".join(errors)
diff --git a/common/speech/lasr_speech_recognition_whisper/test/test_pep257.py b/common/speech/lasr_speech_recognition_whisper/test/test_pep257.py
index b234a384..a2c3deb8 100644
--- a/common/speech/lasr_speech_recognition_whisper/test/test_pep257.py
+++ b/common/speech/lasr_speech_recognition_whisper/test/test_pep257.py
@@ -19,5 +19,5 @@
@pytest.mark.linter
@pytest.mark.pep257
def test_pep257():
- rc = main(argv=['.', 'test'])
- assert rc == 0, 'Found code style errors / warnings'
+ rc = main(argv=[".", "test"])
+ assert rc == 0, "Found code style errors / warnings"