From 270479b396cff14abcaac67ec622c53931e52752 Mon Sep 17 00:00:00 2001 From: Daniel McKnight Date: Thu, 3 Oct 2024 17:18:44 -0700 Subject: [PATCH] Add streaming dependencies to unit tests Separate streaming dependencies from basic WS Refactor streaming client code into a separate module --- .github/workflows/unit_tests.yml | 2 +- Dockerfile | 2 +- neon_hana/mq_websocket_api.py | 97 +++----------------------------- neon_hana/streaming_client.py | 89 +++++++++++++++++++++++++++++ requirements/streaming.txt | 5 ++ requirements/websocket.txt | 5 -- setup.py | 3 +- 7 files changed, 106 insertions(+), 97 deletions(-) create mode 100644 neon_hana/streaming_client.py create mode 100644 requirements/streaming.txt diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 068ca91..08a253b 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -23,7 +23,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install . -r requirements/test_requirements.txt + pip install .[streaming] -r requirements/test_requirements.txt - name: Run Tests run: | pytest tests diff --git a/Dockerfile b/Dockerfile index c3d62ca..6392e63 100644 --- a/Dockerfile +++ b/Dockerfile @@ -14,6 +14,6 @@ COPY docker_overlay/ / WORKDIR /app COPY . /app -RUN pip install /app[websocket] +RUN pip install /app[websocket,streaming] CMD ["python3", "/app/neon_hana/app/__main__.py"] \ No newline at end of file diff --git a/neon_hana/mq_websocket_api.py b/neon_hana/mq_websocket_api.py index 40d8531..41d4beb 100644 --- a/neon_hana/mq_websocket_api.py +++ b/neon_hana/mq_websocket_api.py @@ -25,26 +25,15 @@ # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. from asyncio import run, get_event_loop -from base64 import b64encode from os import makedirs from queue import Queue from time import time -from typing import Optional, Callable - +from typing import Optional from fastapi import WebSocket -from mock.mock import Mock from neon_iris.client import NeonAIClient from ovos_bus_client.message import Message -from threading import RLock, Thread - -from ovos_dinkum_listener.voice_loop import DinkumVoiceLoop -from ovos_dinkum_listener.voice_loop.hotwords import HotwordContainer -from ovos_dinkum_listener.voice_loop.voice_loop import ChunkInfo -from ovos_plugin_manager.templates.microphone import Microphone -from ovos_plugin_manager.vad import OVOSVADFactory +from threading import RLock from ovos_utils import LOG -from ovos_utils.fakebus import FakeBus -from speech_recognition import AudioData class MQWebsocketAPI(NeonAIClient): @@ -79,7 +68,7 @@ def end_session(self, session_id: str): if not session: LOG.error(f"Ended session is not established {session_id}") return - stream: RemoteStreamHandler = session.get('stream') + stream = session.get('stream') if stream: stream.shutdown() stream.join() @@ -142,6 +131,7 @@ def _update_session_data(self, message: Message): self._sessions[session_id]['user'] = user_config def handle_audio_stream(self, audio: bytes, session_id: str): + from neon_hana.streaming_client import RemoteStreamHandler, StreamMicrophone if not self._sessions[session_id].get('stream'): LOG.info(f"starting stream for session {session_id}") audio_queue = Queue() @@ -184,6 +174,10 @@ def handle_klat_response(self, message: Message): """ self._update_session_data(message) run(self.send_to_client(message)) + session_id = message.context.get('session', {}).get('session_id') + if self._sessions.get(session_id, {}).get('stream'): + # TODO: stream response audio to streaming socket + pass LOG.debug(message.context.get("timing")) def handle_complete_intent_failure(self, message: Message): @@ -249,78 +243,3 @@ def shutdown(self, *_, **__): loop.call_soon_threadsafe(loop.stop) LOG.info("Stopped Event Loop") super().shutdown() - - -class StreamMicrophone(Microphone): - def __init__(self, queue: Queue): - self.queue = queue - - def start(self): - pass - - def stop(self): - self.queue.put(None) - - def read_chunk(self) -> Optional[bytes]: - return self.queue.get() - - -class RemoteStreamHandler(Thread): - def __init__(self, mic: StreamMicrophone, session_id: str, - audio_callback: Callable, - ww_callback: Callable, lang: str = "en-us"): - Thread.__init__(self) - self.session_id = session_id - self.ww_callback = ww_callback - self.audio_callback = audio_callback - self.bus = FakeBus() - self.mic = mic - self.lang = lang - self.hotwords = HotwordContainer(self.bus) - self.hotwords.load_hotword_engines() - self.vad = OVOSVADFactory.create() - self.voice_loop = DinkumVoiceLoop(mic=self.mic, - vad=self.vad, - hotwords=self.hotwords, - listenword_audio_callback=self.on_hotword, - hotword_audio_callback=self.on_hotword, - stopword_audio_callback=self.on_hotword, - wakeupword_audio_callback=self.on_hotword, - stt_audio_callback=self.on_audio, - stt=Mock(transcribe=Mock(return_value=[])), - fallback_stt=Mock(transcribe=Mock(return_value=[])), - transformers=MockTransformers(), - chunk_callback=self.on_chunk, - speech_seconds=0.5, - num_hotword_keep_chunks=0, - num_stt_rewind_chunks=0) - - def run(self): - self.voice_loop.start() - self.voice_loop.run() - - def on_hotword(self, audio_bytes: bytes, context: dict): - self.lang = context.get("stt_lang") or self.lang - LOG.info(f"Hotword: {context}") - self.ww_callback(context, self.session_id) - - def on_audio(self, audio_bytes: bytes, context: dict): - LOG.info(f"Audio: {context}") - audio_data = AudioData(audio_bytes, self.mic.sample_rate, - self.mic.sample_width).get_wav_data() - audio_data = b64encode(audio_data).decode("utf-8") - callback_data = {"type": "neon.audio_input", - "data": {"audio_data": audio_data, "lang": self.lang}} - self.audio_callback(callback_data, self.session_id) - - def on_chunk(self, chunk: ChunkInfo): - LOG.debug(f"Chunk: {chunk}") - - def shutdown(self): - self.mic.stop() - self.voice_loop.stop() - - -class MockTransformers(Mock): - def transform(self, chunk): - return chunk, dict() diff --git a/neon_hana/streaming_client.py b/neon_hana/streaming_client.py new file mode 100644 index 0000000..67a5b7f --- /dev/null +++ b/neon_hana/streaming_client.py @@ -0,0 +1,89 @@ +from base64 import b64encode +from typing import Optional, Callable +from mock.mock import Mock +from threading import Thread +from queue import Queue + +from ovos_dinkum_listener.voice_loop import DinkumVoiceLoop +from ovos_dinkum_listener.voice_loop.hotwords import HotwordContainer +from ovos_dinkum_listener.voice_loop.voice_loop import ChunkInfo +from ovos_plugin_manager.templates.microphone import Microphone +from ovos_plugin_manager.vad import OVOSVADFactory +from ovos_utils.fakebus import FakeBus +from speech_recognition import AudioData +from ovos_utils import LOG + + +class StreamMicrophone(Microphone): + def __init__(self, queue: Queue): + self.queue = queue + + def start(self): + pass + + def stop(self): + self.queue.put(None) + + def read_chunk(self) -> Optional[bytes]: + return self.queue.get() + + +class RemoteStreamHandler(Thread): + def __init__(self, mic: StreamMicrophone, session_id: str, + audio_callback: Callable, + ww_callback: Callable, lang: str = "en-us"): + Thread.__init__(self) + self.session_id = session_id + self.ww_callback = ww_callback + self.audio_callback = audio_callback + self.bus = FakeBus() + self.mic = mic + self.lang = lang + self.hotwords = HotwordContainer(self.bus) + self.hotwords.load_hotword_engines() + self.vad = OVOSVADFactory.create() + self.voice_loop = DinkumVoiceLoop(mic=self.mic, + vad=self.vad, + hotwords=self.hotwords, + listenword_audio_callback=self.on_hotword, + hotword_audio_callback=self.on_hotword, + stopword_audio_callback=self.on_hotword, + wakeupword_audio_callback=self.on_hotword, + stt_audio_callback=self.on_audio, + stt=Mock(transcribe=Mock(return_value=[])), + fallback_stt=Mock(transcribe=Mock(return_value=[])), + transformers=MockTransformers(), + chunk_callback=self.on_chunk, + speech_seconds=0.5, + num_hotword_keep_chunks=0, + num_stt_rewind_chunks=0) + + def run(self): + self.voice_loop.start() + self.voice_loop.run() + + def on_hotword(self, audio_bytes: bytes, context: dict): + self.lang = context.get("stt_lang") or self.lang + LOG.info(f"Hotword: {context}") + self.ww_callback(context, self.session_id) + + def on_audio(self, audio_bytes: bytes, context: dict): + LOG.info(f"Audio: {context}") + audio_data = AudioData(audio_bytes, self.mic.sample_rate, + self.mic.sample_width).get_wav_data() + audio_data = b64encode(audio_data).decode("utf-8") + callback_data = {"type": "neon.audio_input", + "data": {"audio_data": audio_data, "lang": self.lang}} + self.audio_callback(callback_data, self.session_id) + + def on_chunk(self, chunk: ChunkInfo): + LOG.debug(f"Chunk: {chunk}") + + def shutdown(self): + self.mic.stop() + self.voice_loop.stop() + + +class MockTransformers(Mock): + def transform(self, chunk): + return chunk, dict() diff --git a/requirements/streaming.txt b/requirements/streaming.txt new file mode 100644 index 0000000..6a241e7 --- /dev/null +++ b/requirements/streaming.txt @@ -0,0 +1,5 @@ +mock~=5.0 +ovos-dinkum-listener~=0.1 +ovos-vad-plugin-webrtcvad +#ovos-ww-plugin-pocketsphinx +ovos-ww-plugin-vosk \ No newline at end of file diff --git a/requirements/websocket.txt b/requirements/websocket.txt index c9957af..2a5a773 100644 --- a/requirements/websocket.txt +++ b/requirements/websocket.txt @@ -1,7 +1,2 @@ neon-iris~=0.1,>=0.1.1a5 websockets~=12.0 -mock~=5.0 -ovos-dinkum-listener~=0.1 -ovos-vad-plugin-webrtcvad -#ovos-ww-plugin-pocketsphinx -ovos-ww-plugin-vosk \ No newline at end of file diff --git a/setup.py b/setup.py index b9928ee..3777817 100644 --- a/setup.py +++ b/setup.py @@ -74,7 +74,8 @@ def get_requirements(requirements_filename: str): license='BSD-3-Clause', packages=find_packages(), install_requires=get_requirements("requirements.txt"), - extras_require={"websocket": get_requirements("websocket.txt")}, + extras_require={"websocket": get_requirements("websocket.txt"), + "steaming": get_requirements("streaming.txt")}, zip_safe=True, classifiers=[ 'Intended Audience :: Developers',