refactor: Black format
maayan25 committed Nov 24, 2024
1 parent 4482710 commit 7f0dd55
Showing 21 changed files with 224 additions and 147 deletions.
12 changes: 6 additions & 6 deletions common/speech/lasr_speech_recognition_interfaces/README.md
@@ -3,17 +3,18 @@
Common messages used for speech recognition

This package is maintained by:

- [Maayan Armony](mailto:maayan.armony@gmail.com)
- [Paul Makles](mailto:me@insrt.uk) (ROS1)

## Prerequisites

This package depends on the following ROS packages:

- colcon (buildtool)
- message_generation (build)
- message_runtime (exec)


## Usage

Ask the package maintainer to write a `doc/USAGE.md` for their package!
@@ -36,11 +37,10 @@ This package has no launch files.

#### `Transcription`

| Field | Type | Description |
|:-:|:-:|---|
| phrase | string | |
| finished | bool | |

| Field | Type | Description |
|:--------:|:------:|-------------|
| phrase | string | |
| finished | bool | |
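
For context, a minimal rclpy subscriber for this message might look like the sketch below; the `transcription` topic name matches the Whisper worker node later in this diff, while the node name is illustrative.

```python
# Minimal subscriber sketch for the Transcription message (rclpy).
import rclpy
from rclpy.node import Node

from lasr_speech_recognition_interfaces.msg import Transcription


class TranscriptionListener(Node):
    def __init__(self):
        super().__init__("transcription_listener")  # node name is illustrative
        # "transcription" is the topic the Whisper worker in this repo publishes to.
        self.create_subscription(
            Transcription, "transcription", self.on_transcription, 10
        )

    def on_transcription(self, msg: Transcription):
        # phrase carries the recognised text; finished flags the end of a phrase.
        self.get_logger().info(f"{msg.phrase!r} (finished={msg.finished})")


def main(args=None):
    rclpy.init(args=args)
    rclpy.spin(TranscriptionListener())


if __name__ == "__main__":
    main()
```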

### Services

Expand Down
32 changes: 16 additions & 16 deletions common/speech/lasr_speech_recognition_interfaces/package.xml
@@ -1,23 +1,23 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>lasr_speech_recognition_interfaces</name>
<version>0.0.0</version>
<description>Common messages used for speech recognition</description>
<maintainer email="maayan.armony@gmail.com">maayan</maintainer>
<license>MIT</license>
<name>lasr_speech_recognition_interfaces</name>
<version>0.0.0</version>
<description>Common messages used for speech recognition</description>
<maintainer email="maayan.armony@gmail.com">maayan</maintainer>
<license>MIT</license>

<buildtool_depend>ament_cmake</buildtool_depend>
<!-- Required for actions, messages, and services -->
<buildtool_depend>rosidl_default_generators</buildtool_depend>
<depend>action_msgs</depend>
<exec_depend>rosidl_default_runtime</exec_depend>
<member_of_group>rosidl_interface_packages</member_of_group>
<buildtool_depend>ament_cmake</buildtool_depend>
<!-- Required for actions, messages, and services -->
<buildtool_depend>rosidl_default_generators</buildtool_depend>
<depend>action_msgs</depend>
<exec_depend>rosidl_default_runtime</exec_depend>
<member_of_group>rosidl_interface_packages</member_of_group>

<test_depend>ament_lint_auto</test_depend>
<test_depend>ament_lint_common</test_depend>
<test_depend>ament_lint_auto</test_depend>
<test_depend>ament_lint_common</test_depend>

<export>
<build_type>ament_cmake</build_type>
</export>
<export>
<build_type>ament_cmake</build_type>
</export>
</package>
12 changes: 9 additions & 3 deletions common/speech/lasr_speech_recognition_whisper/README.md
@@ -3,18 +3,21 @@
Speech recognition implemented using OpenAI Whisper

This package is maintained by:

- [Maayan Armony](mailto:maayan.armony@gmail.com)
- [Paul Makles](mailto:me@insrt.uk) (ROS1)

## Prerequisites

This package depends on the following ROS packages:

- colcon (buildtool)
- lasr_speech_recognition_interfaces

This package requires Python 3.10 to be present.

This package has 48 Python dependencies:

- [SpeechRecognition](https://pypi.org/project/SpeechRecognition)==3.10.0
- [openai-whisper](https://pypi.org/project/openai-whisper)==20230314
- [PyAudio](https://pypi.org/project/PyAudio)==0.2.13
@@ -64,15 +67,18 @@ This package does speech recognition in three parts:

- Adjusting for background noise

We wait for a set period of time monitoring the audio stream to determine what we should ignore when collecting voice data.
We wait for a set period of time monitoring the audio stream to determine what we should ignore when collecting voice
data.

- Collecting appropriate voice data for phrases

We use the `SpeechRecognition` package to monitor the input audio stream and determine when a person is actually speaking with enough energy that we would consider them to be speaking to the robot.
We use the `SpeechRecognition` package to monitor the input audio stream and determine when a person is actually
speaking with enough energy that we would consider them to be speaking to the robot.

- Running inference on phrases

We continuously combine segments of the spoken phrase to form a sample until a certain timeout or threshold after which the phrase ends. This sample is sent to a local OpenAI Whisper model to transcribe.
We continuously combine segments of the spoken phrase to form a sample until a certain timeout or threshold after
which the phrase ends. This sample is sent to a local OpenAI Whisper model to transcribe.
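
Taken together, the three stages map onto roughly the following flow. This is a condensed sketch using the SpeechRecognition and openai-whisper APIs listed under prerequisites, mirroring the node code in this diff; the 16 kHz microphone setting is an assumption, not shown here.

```python
# Condensed sketch of the three stages; mirrors the node code in this diff.
import numpy as np
import speech_recognition as sr
import whisper

recognizer = sr.Recognizer()
microphone = sr.Microphone(sample_rate=16000)  # 16 kHz is an assumption
model = whisper.load_model("medium.en")

with microphone as source:
    # 1. Adjust for background noise: sample the stream to set the energy threshold.
    recognizer.adjust_for_ambient_noise(source)

    # 2. Collect voice data: block until a sufficiently energetic phrase is heard.
    audio = recognizer.listen(source, timeout=5.0, phrase_time_limit=None)

# 3. Run inference: convert 16-bit PCM to float32 in [-1, 1] and transcribe.
wav_data = audio.get_wav_data()
samples = np.frombuffer(wav_data, dtype=np.int16).astype(np.float32) / 32768.0
print(model.transcribe(samples)["text"])
```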

The package can input from the following sources:

@@ -11,55 +11,64 @@

import sounddevice # needed to remove ALSA error messages
from lasr_speech_recognition_interfaces.srv import TranscribeAudio
from src import ModelCache # type: ignore
from src import ModelCache # type: ignore

MODEL = "medium.en" # Whisper model
TIMEOUT = 5.0 # Timeout for listening for the start of a phrase
PHRASE_TIME_LIMIT = None # Timeout for listening for the end of a phrase
MODEL = "medium.en" # Whisper model
TIMEOUT = 5.0 # Timeout for listening for the start of a phrase
PHRASE_TIME_LIMIT = None # Timeout for listening for the end of a phrase

WHISPER_CACHE = os.path.join(str(Path.home()), '.cache', 'whisper')
WHISPER_CACHE = os.path.join(str(Path.home()), ".cache", "whisper")
os.makedirs(WHISPER_CACHE, exist_ok=True)
os.environ["TIKTOKEN_CACHE_DIR"] = WHISPER_CACHE

if len(sys.argv) < 3:
print('Usage:')
print('ros2 run lasr_speech_recognition transcribe_microphone by-index <device_index>')
print('ros2 run lasr_speech_recognition transcribe_microphone by-name <substring>')
print("Usage:")
print(
"ros2 run lasr_speech_recognition transcribe_microphone by-index <device_index>"
)
print("ros2 run lasr_speech_recognition transcribe_microphone by-name <substring>")
exit(1)
else:
matcher = sys.argv[1]
device_index = None
if matcher == 'by-index':
if matcher == "by-index":
device_index = int(sys.argv[2])
elif matcher == 'by-name':
elif matcher == "by-name":
import speech_recognition as sr

microphones = enumerate(sr.Microphone.list_microphone_names())

target_name = sys.argv[2]
for index, name in microphones:
if target_name in name:
device_index = index
break

if device_index is None:
print('Could not find device!')
print("Could not find device!")
exit(1)
else:
print('Invalid matcher')
print("Invalid matcher")
exit(1)

rclpy.init(args=sys.argv)
node = rclpy.create_node('transcribe_mic')
node = rclpy.create_node("transcribe_mic")

device = "cuda" if torch.cuda.is_available() else "cpu"
model_cache = ModelCache()
model = model_cache.load_model("medium.en", device=device)

# try to run inference on the example file
package_install = packages.get_package_prefix("lasr_speech_recognition_whisper")
package_root = os.path.abspath(os.path.join(package_install, os.pardir, os.pardir, "lasr_speech_recognition_whisper"))
package_root = os.path.abspath(
os.path.join(
package_install, os.pardir, os.pardir, "lasr_speech_recognition_whisper"
)
)
example_fp = os.path.join(package_root, "test.m4a")
node.get_logger().info("Running transcription on example file to ensure model is loaded...")
node.get_logger().info(
"Running transcription on example file to ensure model is loaded..."
)
transcription = model.transcribe(example_fp, fp16=torch.cuda.is_available())
node.get_logger().info(str(transcription))

@@ -68,16 +77,25 @@
with microphone as source:
r.adjust_for_ambient_noise(source)


def handle_transcribe_audio(_):
with microphone as source:

wav_data = r.listen(source, timeout=TIMEOUT, phrase_time_limit=PHRASE_TIME_LIMIT).get_wav_data()
float_data = np.frombuffer(wav_data, dtype=np.int16).astype(np.float32, order='C') / 32768.0
wav_data = r.listen(
source, timeout=TIMEOUT, phrase_time_limit=PHRASE_TIME_LIMIT
).get_wav_data()
float_data = (
np.frombuffer(wav_data, dtype=np.int16).astype(np.float32, order="C")
/ 32768.0
)

phrase = model.transcribe(float_data, fp16=device == "cuda")["text"]
return TranscribeAudio.Response(phrase=phrase)

node.create_service(TranscribeAudio, '/whisper/transcribe_audio', handle_transcribe_audio)

node.create_service(
TranscribeAudio, "/whisper/transcribe_audio", handle_transcribe_audio
)

node.get_logger().info("Whisper service ready")
rclpy.spin(node)
rclpy.spin(node)
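
For reference, a one-shot client for the `/whisper/transcribe_audio` service this script advertises might look like the following sketch; the handler above ignores its request, so the request type is assumed to carry no fields.

```python
# Hypothetical one-shot client for /whisper/transcribe_audio.
import rclpy

from lasr_speech_recognition_interfaces.srv import TranscribeAudio

rclpy.init()
node = rclpy.create_node("transcribe_audio_client")  # node name is illustrative

client = node.create_client(TranscribeAudio, "/whisper/transcribe_audio")
client.wait_for_service()

# The handler above ignores the request, so it is assumed to carry no fields.
future = client.call_async(TranscribeAudio.Request())
rclpy.spin_until_future_complete(node, future)
node.get_logger().info(f"Heard: {future.result().phrase}")
```
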
@@ -10,36 +10,42 @@
from std_srvs.srv import Empty
from src import SpeechRecognitionToTopic, MicrophonePhraseCollector, ModelCache

WHISPER_CACHE = os.path.join(str(Path.home()), '.cache', 'whisper')
WHISPER_CACHE = os.path.join(str(Path.home()), ".cache", "whisper")
os.makedirs(WHISPER_CACHE, exist_ok=True)
os.environ["TIKTOKEN_CACHE_DIR"] = WHISPER_CACHE


class TranscribeMicrophone(Node):
def __init__(self):
Node.__init__(self, 'transcribe_microphone')
Node.__init__(self, "transcribe_microphone")
self.worker = None
self.collector = None

self.create_service(Empty, '/whisper/adjust_for_noise', self.adjust_for_noise)
self.create_service(Empty, '/whisper/start_listening', self.start_listening)
self.create_service(Empty, '/whisper/stop_listening', self.stop_listening)
self.create_service(Empty, "/whisper/adjust_for_noise", self.adjust_for_noise)
self.create_service(Empty, "/whisper/start_listening", self.start_listening)
self.create_service(Empty, "/whisper/stop_listening", self.stop_listening)

self.get_logger().info("Starting the Whisper worker!")
self.run_transcription()

def run_transcription(self):
if len(sys.argv) < 3:
print('Usage:')
print('rosrun lasr_speech_recognition transcribe_microphone by-index <device_index>')
print('rosrun lasr_speech_recognition transcribe_microphone by-name <substring>')
print("Usage:")
print(
"rosrun lasr_speech_recognition transcribe_microphone by-index <device_index>"
)
print(
"rosrun lasr_speech_recognition transcribe_microphone by-name <substring>"
)
exit(1)
else:
matcher = sys.argv[1]
device_index = None
if matcher == 'by-index':
if matcher == "by-index":
device_index = int(sys.argv[2])
elif matcher == 'by-name':
elif matcher == "by-name":
import speech_recognition as sr

microphones = enumerate(sr.Microphone.list_microphone_names())

target_name = sys.argv[2]
@@ -49,13 +55,12 @@ def run_transcription(self):
break

if device_index is None:
print('Could not find device!')
print("Could not find device!")
exit(1)
else:
print('Invalid matcher')
print("Invalid matcher")
exit(1)


self.collector = MicrophonePhraseCollector(device_index=device_index)
self.collector.adjust_for_noise()

@@ -64,14 +69,24 @@ def run_transcription(self):

# try to run inference on the example file
package_install = packages.get_package_prefix("lasr_speech_recognition_whisper")
package_root = os.path.abspath(os.path.join(package_install, os.pardir, os.pardir, "lasr_speech_recognition_whisper"))
package_root = os.path.abspath(
os.path.join(
package_install, os.pardir, os.pardir, "lasr_speech_recognition_whisper"
)
)
example_fp = os.path.join(package_root, "test.m4a")

self.get_logger().info("Running transcription on example file to ensure model is loaded...")
model_transcription = model.transcribe(example_fp, fp16=torch.cuda.is_available())
self.get_logger().info(
"Running transcription on example file to ensure model is loaded..."
)
model_transcription = model.transcribe(
example_fp, fp16=torch.cuda.is_available()
)
self.get_logger().info(str(model_transcription))

self.worker = SpeechRecognitionToTopic(self.collector, model, "transcription", infer_partial = False)
self.worker = SpeechRecognitionToTopic(
self.collector, model, "transcription", infer_partial=False
)

def adjust_for_noise(self, request, response):
self.collector.adjust_for_noise()
Expand All @@ -85,11 +100,12 @@ def stop_listening(self, request, response):
self.worker.stop()
return response


def main(args=None):
rclpy.init(args=args)
transcribe_microphone = TranscribeMicrophone()
rclpy.spin(transcribe_microphone)


if __name__ == '__main__':
main()
if __name__ == "__main__":
main()
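
A sketch of driving this node from another process, using the service names registered in `__init__` above; the control-node name and the call order are illustrative. Once listening starts, the `SpeechRecognitionToTopic` worker publishes `Transcription` messages on the `transcription` topic.

```python
# Illustrative control sequence for the services registered above.
import rclpy
from std_srvs.srv import Empty

rclpy.init()
node = rclpy.create_node("whisper_control")  # node name is illustrative


def call(service_name: str) -> None:
    # All three services use std_srvs/Empty, so requests carry no data.
    client = node.create_client(Empty, service_name)
    client.wait_for_service()
    future = client.call_async(Empty.Request())
    rclpy.spin_until_future_complete(node, future)


call("/whisper/adjust_for_noise")
call("/whisper/start_listening")
# ... Transcription messages now arrive on the "transcription" topic ...
call("/whisper/stop_listening")
```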