diff --git a/Cargo.toml b/Cargo.toml index eea4c66..97ccde7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,3 +3,4 @@ members = [ "vosk", "vosk-sys", ] +resolver = "2" diff --git a/vosk-sys/src/lib.rs b/vosk-sys/src/lib.rs index 2a5b4e3..b870bbc 100644 --- a/vosk-sys/src/lib.rs +++ b/vosk-sys/src/lib.rs @@ -304,7 +304,7 @@ extern "C" { #[doc = " Creates the batch recognizer object"] #[doc = ""] #[doc = " @returns model object or NULL if problem occured"] - pub fn vosk_batch_model_new() -> *mut VoskBatchModel; + pub fn vosk_batch_model_new(model_path: *const ::std::os::raw::c_char) -> *mut VoskBatchModel; #[doc = " Releases batch model object"] pub fn vosk_batch_model_free(model: *mut VoskBatchModel); diff --git a/vosk/Cargo.toml b/vosk/Cargo.toml index 8117b60..19c4334 100644 --- a/vosk/Cargo.toml +++ b/vosk/Cargo.toml @@ -19,3 +19,6 @@ serde = { version = "1.0", features = ["derive"] } cpal = "0.14" dasp = "0.11" hound = "3.5" + +[features] +cuda = [] diff --git a/vosk/src/lib.rs b/vosk/src/lib.rs index af10495..9b765d3 100644 --- a/vosk/src/lib.rs +++ b/vosk/src/lib.rs @@ -15,3 +15,17 @@ mod recognition; pub use log::*; pub use models::*; pub use recognition::*; + +/// Init, automatically select a CUDA device and allow multithreading. +/// Must be called once from the main thread. +#[cfg(feature = "cuda")] +pub fn gpu_init() { + unsafe { vosk_sys::vosk_gpu_init() } +} + +/// Init CUDA device in a multi-threaded environment. +/// Must be called for each thread. +#[cfg(feature = "cuda")] +pub fn gpu_thread_init() { + unsafe { vosk_sys::vosk_gpu_thread_init() } +} diff --git a/vosk/src/models.rs b/vosk/src/models.rs index 8cfa04f..b5a5046 100644 --- a/vosk/src/models.rs +++ b/vosk/src/models.rs @@ -76,3 +76,41 @@ impl Drop for SpeakerModel { unsafe impl Send for SpeakerModel {} unsafe impl Sync for SpeakerModel {} + +#[cfg(feature = "cuda")] +pub mod batch_model { + use std::{ffi::CString, ptr::NonNull}; + use vosk_sys::*; + /// The same as [`Model`], but uses a CUDA enabled Nvidia GPU and dynamic batching to enable higher throughput. + + pub struct BatchModel(pub(crate) NonNull); + + impl BatchModel { + /// Loads model data from the file and returns the model object, or [`None`] + /// if a problem occured. + /// + /// * `model_path` - the path to the model directory. + #[must_use] + pub fn new(model_path: impl Into) -> Option { + let model_path_c = CString::new(model_path.into()).ok()?; + let model_ptr = unsafe { vosk_batch_model_new(model_path_c.as_ptr()) }; + + Some(Self(NonNull::new(model_ptr)?)) + } + + /// Waits for inferencing to finish + pub fn wait(&self) { + unsafe { vosk_batch_model_wait(self.0.as_ptr()) }; + } + } + + impl Drop for BatchModel { + fn drop(&mut self) { + unsafe { vosk_batch_model_free(self.0.as_ptr()) } + } + } + + unsafe impl Send for BatchModel {} + + unsafe impl Sync for BatchModel {} +} diff --git a/vosk/src/recognition/mod.rs b/vosk/src/recognition/mod.rs index 9e1d4e1..d34f14d 100644 --- a/vosk/src/recognition/mod.rs +++ b/vosk/src/recognition/mod.rs @@ -1,4 +1,5 @@ use crate::{Model, SpeakerModel}; + use serde::Deserialize; use std::{ ffi::{CStr, CString}, @@ -189,6 +190,11 @@ impl Recognizer { unsafe { vosk_recognizer_set_partial_words(self.0.as_ptr(), i32::from(enable)) } } + /// Enables or disables Natural Language Semantics Markup Language (NLSML) in the output + pub fn set_nlsml(&mut self, enable: bool) { + unsafe { vosk_recognizer_set_nlsml(self.0.as_ptr(), i32::from(enable)) } + } + /// Accept and process new chunk of voice data. /// /// * `data` - Audio data in PCM 16-bit mono format. @@ -277,3 +283,86 @@ impl Drop for Recognizer { unsafe { vosk_recognizer_free(self.0.as_ptr()) } } } + +#[cfg(feature = "cuda")] +pub mod batch_recognizer { + use crate::batch_model::BatchModel; + use vosk_sys::*; + + use std::{ffi::CStr, ptr::NonNull}; + + pub use crate::recognition::results::*; + + /// The main object which processes data using GPU inferencing. + /// Takes audio as input and returns decoded information as words, confidences, times, and other metadata. + + pub struct BatchRecognizer(std::ptr::NonNull); + + impl BatchRecognizer { + /// Creates the recognizer object. Returns [`None`] if a problem occured. + /// + /// The recognizers process the speech and return text using shared model data. + /// + /// * `model` - [`BatchModel`] containing static data for recognizer. Model can be shared + /// across recognizers, even running in different threads. + /// + /// * `sample_rate` - The sample rate of the audio you going to feed into the recognizer. + /// Make sure this rate matches the audio content, it is a common issue causing accuracy problems. + /// + /// [`BatchModel`]: crate::BatchModel + #[must_use] + pub fn new(model: &BatchModel, sample_rate: f32) -> Option { + let recognizer_ptr = + unsafe { vosk_batch_recognizer_new(model.0.as_ptr(), sample_rate) }; + Some(Self(NonNull::new(recognizer_ptr)?)) + } + + /// Enables or disables Natural Language Semantics Markup Language (NLSML) in the output + pub fn set_nlsml(&mut self, enable: bool) { + unsafe { vosk_batch_recognizer_set_nlsml(self.0.as_ptr(), i32::from(enable)) } + } + + /// Accept and process new chunk of voice data. + /// + /// * `data` - Audio data in PCM 16-bit mono format as an array of i8. + pub fn accept_waveform(&mut self, data: &[i8]) { + unsafe { + vosk_batch_recognizer_accept_waveform( + self.0.as_ptr(), + data.as_ptr(), + data.len() as i32, + ) + }; + } + + /// Closes the stream to the model + pub fn finish_stream(&mut self) { + unsafe { vosk_batch_recognizer_finish_stream(self.0.as_ptr()) }; + } + + /// Gets the front of the result queue + pub fn front_result(&mut self) -> Result { + serde_json::from_str( + unsafe { CStr::from_ptr(vosk_batch_recognizer_front_result(self.0.as_ptr())) } + .to_str() + .unwrap(), + ) + } + + /// Removes the front of the result queue + pub fn pop(&mut self) { + unsafe { vosk_batch_recognizer_pop(self.0.as_ptr()) } + } + + /// Gets the number of chunks that have yet to be processed + pub fn get_pending_chunks(&mut self) -> i32 { + (unsafe { vosk_batch_recognizer_get_pending_chunks(self.0.as_ptr()) }) as i32 + } + } + + impl Drop for BatchRecognizer { + fn drop(&mut self) { + unsafe { vosk_batch_recognizer_free(self.0.as_ptr()) } + } + } +}