Skip to content

Commit

Permalink
Added structs and functions necessary for GPU inferencing support (#8)
Browse files Browse the repository at this point in the history
* Added bindings to batch/gpu inferencing

* Finished adding rudimentary documentation

* Finish documentation I missed

* Add a line to top level Cargo.toml to remove a warning while compiling

* Put gpu inferencing features behind cuda feature flag

* Fixed whitespace error

* Added in-file modules for batch operations

* Changed return type of get_pending_chunks
  • Loading branch information
lightningpwr28 authored Aug 17, 2024
1 parent 2ceb01a commit 5fc7ad1
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 1 deletion.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ members = [
"vosk",
"vosk-sys",
]
resolver = "2"
2 changes: 1 addition & 1 deletion vosk-sys/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ extern "C" {
#[doc = " Creates the batch recognizer object"]
#[doc = ""]
#[doc = " @returns model object or NULL if problem occured"]
pub fn vosk_batch_model_new() -> *mut VoskBatchModel;
pub fn vosk_batch_model_new(model_path: *const ::std::os::raw::c_char) -> *mut VoskBatchModel;

#[doc = " Releases batch model object"]
pub fn vosk_batch_model_free(model: *mut VoskBatchModel);
Expand Down
3 changes: 3 additions & 0 deletions vosk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,6 @@ serde = { version = "1.0", features = ["derive"] }
cpal = "0.14"
dasp = "0.11"
hound = "3.5"

[features]
cuda = []
14 changes: 14 additions & 0 deletions vosk/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,17 @@ mod recognition;
pub use log::*;
pub use models::*;
pub use recognition::*;

/// Init, automatically select a CUDA device and allow multithreading.
/// Must be called once from the main thread.
///
/// Thin wrapper over `vosk_gpu_init` from the Vosk C API; only compiled
/// when the `cuda` feature is enabled, since it requires a CUDA build of
/// libvosk.
#[cfg(feature = "cuda")]
pub fn gpu_init() {
// SAFETY: zero-argument FFI call. The only precondition stated by the
// binding is the call-once-from-main-thread rule documented above —
// upholding that is the caller's responsibility.
unsafe { vosk_sys::vosk_gpu_init() }
}

/// Init CUDA device in a multi-threaded environment.
/// Must be called for each thread.
///
/// Thin wrapper over `vosk_gpu_thread_init` from the Vosk C API; gated on
/// the `cuda` feature like [`gpu_init`].
#[cfg(feature = "cuda")]
pub fn gpu_thread_init() {
// SAFETY: zero-argument FFI call. The binding's only stated precondition
// is the once-per-thread rule documented above.
unsafe { vosk_sys::vosk_gpu_thread_init() }
}
38 changes: 38 additions & 0 deletions vosk/src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,41 @@ impl Drop for SpeakerModel {

unsafe impl Send for SpeakerModel {}
unsafe impl Sync for SpeakerModel {}

#[cfg(feature = "cuda")]
pub mod batch_model {
    use std::{ffi::CString, ptr::NonNull};
    use vosk_sys::*;

    /// The same as [`Model`], but uses a CUDA enabled Nvidia GPU and dynamic
    /// batching to enable higher throughput.
    ///
    /// [`Model`]: crate::Model
    pub struct BatchModel(pub(crate) NonNull<VoskBatchModel>);

    impl BatchModel {
        /// Loads model data from the file and returns the model object, or
        /// [`None`] if a problem occurred (the path could not be converted to
        /// a C string, or the C library returned NULL).
        ///
        /// * `model_path` - the path to the model directory.
        #[must_use]
        pub fn new(model_path: impl Into<String>) -> Option<Self> {
            // An interior NUL byte in the path makes CString::new fail;
            // treat that as a load failure rather than panicking.
            let model_path_c = CString::new(model_path.into()).ok()?;

            // SAFETY: `model_path_c` is a valid NUL-terminated C string that
            // outlives this call; the C side receives a read-only pointer.
            let model_ptr = unsafe { vosk_batch_model_new(model_path_c.as_ptr()) };

            // NonNull::new maps a NULL return (load failure) to None.
            Some(Self(NonNull::new(model_ptr)?))
        }

        /// Waits for inferencing to finish.
        pub fn wait(&self) {
            // SAFETY: `self.0` was returned non-null by `vosk_batch_model_new`
            // and has not been freed (freeing only happens in Drop).
            unsafe { vosk_batch_model_wait(self.0.as_ptr()) };
        }
    }

    impl Drop for BatchModel {
        fn drop(&mut self) {
            // SAFETY: the pointer was allocated by `vosk_batch_model_new` and
            // Drop runs at most once, so it is freed exactly once.
            unsafe { vosk_batch_model_free(self.0.as_ptr()) }
        }
    }

    // SAFETY: mirrors the Send/Sync impls on the non-batch models in this
    // file; assumes libvosk permits sharing the batch-model handle across
    // threads (the C docs describe sharing models between recognizers) —
    // NOTE(review): confirm against upstream vosk_api.h documentation.
    unsafe impl Send for BatchModel {}
    unsafe impl Sync for BatchModel {}
}
89 changes: 89 additions & 0 deletions vosk/src/recognition/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::{Model, SpeakerModel};

use serde::Deserialize;
use std::{
ffi::{CStr, CString},
Expand Down Expand Up @@ -189,6 +190,11 @@ impl Recognizer {
unsafe { vosk_recognizer_set_partial_words(self.0.as_ptr(), i32::from(enable)) }
}

/// Enables or disables Natural Language Semantics Markup Language (NLSML) in the output
///
/// * `enable` - `true` to produce NLSML-formatted output, `false` to disable it
/// (forwarded to the C API as 1/0 via `i32::from`).
pub fn set_nlsml(&mut self, enable: bool) {
// SAFETY: `self.0` is the recognizer pointer this wrapper owns; `&mut self`
// guarantees exclusive access for the duration of the call.
unsafe { vosk_recognizer_set_nlsml(self.0.as_ptr(), i32::from(enable)) }
}

/// Accept and process new chunk of voice data.
///
/// * `data` - Audio data in PCM 16-bit mono format.
Expand Down Expand Up @@ -277,3 +283,86 @@ impl Drop for Recognizer {
unsafe { vosk_recognizer_free(self.0.as_ptr()) }
}
}

#[cfg(feature = "cuda")]
pub mod batch_recognizer {
    use crate::batch_model::BatchModel;
    use vosk_sys::*;

    use std::{ffi::CStr, ptr::NonNull};

    pub use crate::recognition::results::*;

    /// The main object which processes data using GPU inferencing.
    /// Takes audio as input and returns decoded information as words,
    /// confidences, times, and other metadata.
    pub struct BatchRecognizer(NonNull<VoskBatchRecognizer>);

    impl BatchRecognizer {
        /// Creates the recognizer object. Returns [`None`] if a problem occurred.
        ///
        /// The recognizers process the speech and return text using shared model data.
        ///
        /// * `model` - [`BatchModel`] containing static data for recognizer. Model can be shared
        /// across recognizers, even running in different threads.
        ///
        /// * `sample_rate` - The sample rate of the audio you are going to feed into the recognizer.
        /// Make sure this rate matches the audio content, it is a common issue causing accuracy problems.
        ///
        /// [`BatchModel`]: crate::BatchModel
        #[must_use]
        pub fn new(model: &BatchModel, sample_rate: f32) -> Option<Self> {
            // SAFETY: `model.0` is a live batch-model pointer (kept alive by
            // the `&BatchModel` borrow for the duration of this call).
            let recognizer_ptr =
                unsafe { vosk_batch_recognizer_new(model.0.as_ptr(), sample_rate) };

            // A NULL return from the C constructor becomes None.
            Some(Self(NonNull::new(recognizer_ptr)?))
        }

        /// Enables or disables Natural Language Semantics Markup Language (NLSML) in the output
        pub fn set_nlsml(&mut self, enable: bool) {
            // SAFETY: `self.0` is a valid recognizer pointer owned by this struct.
            unsafe { vosk_batch_recognizer_set_nlsml(self.0.as_ptr(), i32::from(enable)) }
        }

        /// Accept and process new chunk of voice data.
        ///
        /// * `data` - Audio data in PCM 16-bit mono format as an array of i8.
        pub fn accept_waveform(&mut self, data: &[i8]) {
            // SAFETY: pointer and length come from the same slice, so the C
            // side reads exactly `data.len()` valid bytes.
            unsafe {
                vosk_batch_recognizer_accept_waveform(
                    self.0.as_ptr(),
                    data.as_ptr(),
                    data.len() as i32,
                )
            };
        }

        /// Closes the stream to the model
        pub fn finish_stream(&mut self) {
            // SAFETY: `self.0` is a valid recognizer pointer owned by this struct.
            unsafe { vosk_batch_recognizer_finish_stream(self.0.as_ptr()) };
        }

        /// Gets the front of the result queue
        ///
        /// # Panics
        /// Panics if the C library returns a result string that is not valid
        /// UTF-8 (the inner `to_str().unwrap()`).
        pub fn front_result(&mut self) -> Result<Word, serde_json::Error> {
            // SAFETY: assumes `vosk_batch_recognizer_front_result` returns a
            // valid, NUL-terminated C string that stays alive until the next
            // call into the recognizer — NOTE(review): the binding does not
            // show what happens on an empty queue; confirm it never returns
            // NULL, otherwise this is UB.
            serde_json::from_str(
                unsafe { CStr::from_ptr(vosk_batch_recognizer_front_result(self.0.as_ptr())) }
                    .to_str()
                    .unwrap(),
            )
        }

        /// Removes the front of the result queue
        pub fn pop(&mut self) {
            // SAFETY: `self.0` is a valid recognizer pointer owned by this struct.
            unsafe { vosk_batch_recognizer_pop(self.0.as_ptr()) }
        }

        /// Gets the number of chunks that have yet to be processed
        pub fn get_pending_chunks(&mut self) -> i32 {
            // The cast normalizes the C return type (c_int) to i32.
            (unsafe { vosk_batch_recognizer_get_pending_chunks(self.0.as_ptr()) }) as i32
        }
    }

    impl Drop for BatchRecognizer {
        fn drop(&mut self) {
            // SAFETY: the pointer was allocated by `vosk_batch_recognizer_new`
            // and Drop runs at most once, so it is freed exactly once.
            unsafe { vosk_batch_recognizer_free(self.0.as_ptr()) }
        }
    }
}

0 comments on commit 5fc7ad1

Please sign in to comment.