Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Added structs and functions necessary for GPU inferencing support #8

Merged
merged 8 commits on
Aug 17, 2024
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ members = [
"vosk",
"vosk-sys",
]
resolver = "2"
2 changes: 1 addition & 1 deletion vosk-sys/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ extern "C" {
#[doc = " Creates the batch recognizer object"]
#[doc = ""]
#[doc = " @returns model object or NULL if problem occured"]
pub fn vosk_batch_model_new() -> *mut VoskBatchModel;
pub fn vosk_batch_model_new(model_path: *const ::std::os::raw::c_char) -> *mut VoskBatchModel;

#[doc = " Releases batch model object"]
pub fn vosk_batch_model_free(model: *mut VoskBatchModel);
Expand Down
3 changes: 3 additions & 0 deletions vosk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,6 @@ serde = { version = "1.0", features = ["derive"] }
cpal = "0.14"
dasp = "0.11"
hound = "3.5"

[features]
cuda = []
14 changes: 14 additions & 0 deletions vosk/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,17 @@ mod recognition;
pub use log::*;
pub use models::*;
pub use recognition::*;

/// Init, automatically select a CUDA device and allow multithreading.
/// Must be called once from the main thread.
#[cfg(feature = "cuda")]
pub fn gpu_init() {
    // SAFETY: zero-argument FFI call. Upstream requires it to run once from
    // the main thread before any GPU model is created — TODO confirm against
    // the vosk_api.h documentation.
    unsafe { vosk_sys::vosk_gpu_init() }
}

/// Init CUDA device in a multi-threaded environment.
/// Must be called for each thread.
#[cfg(feature = "cuda")]
pub fn gpu_thread_init() {
    // SAFETY: zero-argument FFI call. Upstream requires it to run once per
    // worker thread that touches the GPU — TODO confirm against vosk_api.h.
    unsafe { vosk_sys::vosk_gpu_thread_init() }
}
38 changes: 38 additions & 0 deletions vosk/src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,41 @@ impl Drop for SpeakerModel {

unsafe impl Send for SpeakerModel {}
unsafe impl Sync for SpeakerModel {}

#[cfg(feature = "cuda")]
pub mod batch_model {
    use std::{ffi::CString, ptr::NonNull};
    use vosk_sys::*;

    /// The same as [`Model`], but uses a CUDA-enabled Nvidia GPU and dynamic
    /// batching to enable higher throughput.
    ///
    /// [`Model`]: crate::Model
    pub struct BatchModel(pub(crate) NonNull<VoskBatchModel>);

    impl BatchModel {
        /// Loads model data from the file and returns the model object, or
        /// [`None`] if a problem occurred (either the path contains an
        /// interior NUL byte or the native loader returned null).
        ///
        /// * `model_path` - the path to the model directory.
        #[must_use]
        pub fn new(model_path: impl Into<String>) -> Option<Self> {
            let path = CString::new(model_path.into()).ok()?;
            // SAFETY: `path` is a valid NUL-terminated C string that lives
            // for the duration of the call.
            let raw = unsafe { vosk_batch_model_new(path.as_ptr()) };
            NonNull::new(raw).map(Self)
        }

        /// Waits for inferencing to finish.
        pub fn wait(&self) {
            // SAFETY: the wrapped pointer is non-null for the lifetime of `self`.
            unsafe { vosk_batch_model_wait(self.0.as_ptr()) };
        }
    }

    impl Drop for BatchModel {
        fn drop(&mut self) {
            // SAFETY: the pointer came from `vosk_batch_model_new` and is
            // released exactly once here.
            unsafe { vosk_batch_model_free(self.0.as_ptr()) }
        }
    }

    // NOTE(review): presumably the native batch model is safe to share across
    // threads (mirroring `Model`) — confirm against upstream documentation.
    unsafe impl Send for BatchModel {}
    unsafe impl Sync for BatchModel {}
}
89 changes: 89 additions & 0 deletions vosk/src/recognition/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::{Model, SpeakerModel};

use serde::Deserialize;
use std::{
ffi::{CStr, CString},
Expand Down Expand Up @@ -189,6 +190,11 @@ impl Recognizer {
unsafe { vosk_recognizer_set_partial_words(self.0.as_ptr(), i32::from(enable)) }
}

/// Enables or disables Natural Language Semantics Markup Language (NLSML) in the output
pub fn set_nlsml(&mut self, enable: bool) {
unsafe { vosk_recognizer_set_nlsml(self.0.as_ptr(), i32::from(enable)) }
}

/// Accept and process new chunk of voice data.
///
/// * `data` - Audio data in PCM 16-bit mono format.
Expand Down Expand Up @@ -277,3 +283,86 @@ impl Drop for Recognizer {
unsafe { vosk_recognizer_free(self.0.as_ptr()) }
}
}

#[cfg(feature = "cuda")]
pub mod batch_recognizer {
    use crate::batch_model::BatchModel;
    use vosk_sys::*;

    use std::{ffi::CStr, ptr::NonNull};

    pub use crate::recognition::results::*;

    /// The main object which processes data using GPU inferencing.
    ///
    /// Takes audio as input and returns decoded information as words,
    /// confidences, times, and other metadata.
    pub struct BatchRecognizer(NonNull<VoskBatchRecognizer>);

    impl BatchRecognizer {
        /// Creates the recognizer object. Returns [`None`] if a problem occurred.
        ///
        /// The recognizers process the speech and return text using shared model data.
        ///
        /// * `model` - [`BatchModel`] containing static data for recognizer. Model can be
        ///   shared across recognizers, even running in different threads.
        ///
        /// * `sample_rate` - The sample rate of the audio you are going to feed into the
        ///   recognizer. Make sure this rate matches the audio content; a mismatch is a
        ///   common cause of accuracy problems.
        ///
        /// [`BatchModel`]: crate::batch_model::BatchModel
        #[must_use]
        pub fn new(model: &BatchModel, sample_rate: f32) -> Option<Self> {
            // SAFETY: `model` holds a non-null pointer owned by a live `BatchModel`.
            let recognizer_ptr =
                unsafe { vosk_batch_recognizer_new(model.0.as_ptr(), sample_rate) };
            Some(Self(NonNull::new(recognizer_ptr)?))
        }

        /// Enables or disables Natural Language Semantics Markup Language (NLSML)
        /// in the output.
        pub fn set_nlsml(&mut self, enable: bool) {
            // SAFETY: the wrapped pointer is non-null for the lifetime of `self`.
            unsafe { vosk_batch_recognizer_set_nlsml(self.0.as_ptr(), i32::from(enable)) }
        }

        /// Accept and process new chunk of voice data.
        ///
        /// * `data` - Audio data in PCM 16-bit mono format as an array of i8.
        ///
        /// NOTE(review): `data.len() as i32` truncates for slices longer than
        /// `i32::MAX` bytes; unrealistic for audio chunks, but callers should
        /// not pass gigantic buffers.
        pub fn accept_waveform(&mut self, data: &[i8]) {
            // SAFETY: `data` is a valid, initialized buffer of `data.len()` bytes
            // and the recognizer pointer is non-null.
            unsafe {
                vosk_batch_recognizer_accept_waveform(
                    self.0.as_ptr(),
                    data.as_ptr(),
                    data.len() as i32,
                )
            };
        }

        /// Closes the stream to the model.
        pub fn finish_stream(&mut self) {
            // SAFETY: the wrapped pointer is non-null for the lifetime of `self`.
            unsafe { vosk_batch_recognizer_finish_stream(self.0.as_ptr()) };
        }

        /// Gets the front of the result queue.
        pub fn front_result(&mut self) -> Result<Word, serde_json::Error> {
            // SAFETY: presumably the native call returns a valid, NUL-terminated
            // JSON string and never null — TODO confirm against vosk_api.h.
            let json =
                unsafe { CStr::from_ptr(vosk_batch_recognizer_front_result(self.0.as_ptr())) };
            // `from_slice` validates UTF-8 itself, so invalid bytes surface as a
            // `serde_json::Error` instead of the panic that the previous
            // `.to_str().unwrap()` caused.
            serde_json::from_slice(json.to_bytes())
        }

        /// Removes the front of the result queue.
        pub fn pop(&mut self) {
            // SAFETY: the wrapped pointer is non-null for the lifetime of `self`.
            unsafe { vosk_batch_recognizer_pop(self.0.as_ptr()) }
        }

        /// Gets the number of chunks that have yet to be processed.
        pub fn get_pending_chunks(&mut self) -> usize {
            // SAFETY: the wrapped pointer is non-null for the lifetime of `self`.
            let pending =
                unsafe { vosk_batch_recognizer_get_pending_chunks(self.0.as_ptr()) };
            // The C API reports a signed count; clamp a (nonsensical) negative
            // value to 0 rather than letting `as usize` wrap it to a huge number.
            if pending < 0 {
                0
            } else {
                pending as usize
            }
        }
    }

    impl Drop for BatchRecognizer {
        fn drop(&mut self) {
            // SAFETY: the pointer came from `vosk_batch_recognizer_new` and is
            // freed exactly once here.
            unsafe { vosk_batch_recognizer_free(self.0.as_ptr()) }
        }
    }
}
Loading