Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 33 additions & 17 deletions crates/cactus/src/llm/complete.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
use std::cell::{Cell, UnsafeCell};
use std::ffi::{CStr, CString};

use crate::error::{Error, Result};
use crate::ffi_utils::{RESPONSE_BUF_SIZE, parse_buf};
use crate::model::Model;
use crate::model::{InferenceGuard, Model};

use super::{CompleteOptions, CompletionResult, Message};

type TokenCallback = unsafe extern "C" fn(*const std::ffi::c_char, u32, *mut std::ffi::c_void);

struct CallbackState<'a, F: FnMut(&str) -> bool> {
on_token: &'a mut F,
on_token: UnsafeCell<&'a mut F>,
model: &'a Model,
stopped: bool,
stopped: Cell<bool>,
in_callback: Cell<bool>,
}

unsafe extern "C" fn token_trampoline<F: FnMut(&str) -> bool>(
Expand All @@ -23,21 +25,28 @@
return;
}

let state = unsafe { &mut *(user_data as *mut CallbackState<F>) };
if state.stopped {
// SAFETY: We only create a shared reference to CallbackState. Interior
// mutability (Cell/UnsafeCell) handles mutation. The `in_callback` guard
// prevents re-entrant access to the UnsafeCell contents.
let state = unsafe { &*(user_data as *const CallbackState<F>) };
if state.stopped.get() || state.in_callback.get() {
return;
}
state.in_callback.set(true);

let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let chunk = unsafe { CStr::from_ptr(token) }.to_string_lossy();
if !(state.on_token)(&chunk) {
state.stopped = true;
// SAFETY: The `in_callback` flag ensures exclusive access to the closure.
let on_token = unsafe { &mut *state.on_token.get() };
if !on_token(&chunk) {
state.stopped.set(true);
state.model.stop();
}
}));

state.in_callback.set(false);
if result.is_err() {
state.stopped = true;
state.stopped.set(true);
state.model.stop();
}
}
Expand All @@ -58,6 +67,7 @@
impl Model {
fn call_complete(
&self,
guard: &InferenceGuard<'_>,
messages_c: &CString,
options_c: &CString,
callback: Option<TokenCallback>,
Expand All @@ -67,7 +77,7 @@

let rc = unsafe {
cactus_sys::cactus_complete(
self.raw_handle(),
guard.raw_handle(),
messages_c.as_ptr(),
buf.as_mut_ptr().cast::<std::ffi::c_char>(),
buf.len(),
Expand All @@ -86,9 +96,10 @@
messages: &[Message],
options: &CompleteOptions,
) -> Result<CompletionResult> {
let _guard = self.lock_inference();
let guard = self.lock_inference();
let (messages_c, options_c) = serialize_complete_request(messages, options)?;
let (rc, buf) = self.call_complete(&messages_c, &options_c, None, std::ptr::null_mut());
let (rc, buf) =
self.call_complete(&guard, &messages_c, &options_c, None, std::ptr::null_mut());

if rc < 0 {
return Err(complete_error(rc));
Expand All @@ -106,23 +117,28 @@
where
F: FnMut(&str) -> bool,
{
let _guard = self.lock_inference();
let guard = self.lock_inference();
let (messages_c, options_c) = serialize_complete_request(messages, options)?;

let mut state = CallbackState {
on_token: &mut on_token,
let state = CallbackState {
on_token: UnsafeCell::new(&mut on_token),
model: self,
stopped: false,
stopped: Cell::new(false),
in_callback: Cell::new(false),
};

// SAFETY: `state` is stack-allocated and lives for the duration of the
// FFI call. The C++ side must not retain this pointer beyond the return
// of `cactus_complete`.
let (rc, buf) = self.call_complete(
&guard,
&messages_c,
&options_c,
Some(token_trampoline::<F>),
(&mut state as *mut CallbackState<F>).cast::<std::ffi::c_void>(),
(&state as *const CallbackState<F> as *mut std::ffi::c_void),

Check warning on line 138 in crates/cactus/src/llm/complete.rs

View workflow job for this annotation

GitHub Actions / desktop_ci (linux-aarch64, depot-ubuntu-22.04-arm-8)

unnecessary parentheses around method argument

Check warning on line 138 in crates/cactus/src/llm/complete.rs

View workflow job for this annotation

GitHub Actions / desktop_ci (linux-aarch64, depot-ubuntu-22.04-arm-8)

unnecessary parentheses around method argument

Check warning on line 138 in crates/cactus/src/llm/complete.rs

View workflow job for this annotation

GitHub Actions / desktop_ci (macos, depot-macos-15)

unnecessary parentheses around method argument

Check warning on line 138 in crates/cactus/src/llm/complete.rs

View workflow job for this annotation

GitHub Actions / desktop_ci (macos, depot-macos-15)

unnecessary parentheses around method argument
);

if rc < 0 && !state.stopped {
if rc < 0 && !state.stopped.get() {
return Err(complete_error(rc));
}

Expand Down
37 changes: 26 additions & 11 deletions crates/cactus/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,23 @@ pub struct Model {
}

unsafe impl Send for Model {}
// SAFETY: All FFI methods that touch model state are serialized by `inference_lock`.
// SAFETY: All FFI methods that touch model state are serialized by `inference_lock`,
// which is enforced at compile time via `InferenceGuard` — the model's raw handle is
// only accessible through the guard returned by `lock_inference()`.
// The sole exception is `stop()`, which only sets a `std::atomic<bool>` on the C++ side.
unsafe impl Sync for Model {}

/// Pairs a raw model handle with a held mutex guard.
///
/// The handle can only be read through [`InferenceGuard::raw_handle`], so a
/// caller needs a live guard (and therefore the lock) to obtain the pointer.
/// The lock is released when this value is dropped.
pub(crate) struct InferenceGuard<'a> {
    /// Non-null pointer handed to the FFI layer.
    handle: NonNull<std::ffi::c_void>,
    /// Kept only for its `Drop`; never read.
    _guard: MutexGuard<'a, ()>,
}

impl InferenceGuard<'_> {
    /// Returns the stored handle as a raw mutable pointer for FFI calls.
    pub(crate) fn raw_handle(&self) -> *mut std::ffi::c_void {
        // `NonNull` is `Copy`, so this reads the field without moving it.
        NonNull::as_ptr(self.handle)
    }
}

pub struct ModelBuilder {
model_path: PathBuf,
}
Expand Down Expand Up @@ -53,27 +66,29 @@ impl Model {
}

pub fn reset(&mut self) {
let _guard = self.lock_inference();
let guard = self.lock_inference();
unsafe {
cactus_sys::cactus_reset(self.handle.as_ptr());
cactus_sys::cactus_reset(guard.raw_handle());
}
}

pub(crate) fn lock_inference(&self) -> MutexGuard<'_, ()> {
self.inference_lock
pub(crate) fn lock_inference(&self) -> InferenceGuard<'_> {
let guard = self
.inference_lock
.lock()
.unwrap_or_else(|e| e.into_inner())
}

pub(crate) fn raw_handle(&self) -> *mut std::ffi::c_void {
self.handle.as_ptr()
.unwrap_or_else(|e| e.into_inner());
InferenceGuard {
handle: self.handle,
_guard: guard,
}
}
}

impl Drop for Model {
fn drop(&mut self) {
let guard = self.lock_inference();
unsafe {
cactus_sys::cactus_destroy(self.handle.as_ptr());
cactus_sys::cactus_destroy(guard.raw_handle());
}
}
}
4 changes: 2 additions & 2 deletions crates/cactus/src/stt/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ impl Model {
input: TranscribeInput<'_>,
options: &TranscribeOptions,
) -> Result<TranscriptionResult> {
let _guard = self.lock_inference();
let guard = self.lock_inference();
let prompt_c = CString::new(build_whisper_prompt(options))?;
let options_c = CString::new(serde_json::to_string(options)?)?;
let mut buf = vec![0u8; RESPONSE_BUF_SIZE];
Expand All @@ -44,7 +44,7 @@ impl Model {

let rc = unsafe {
cactus_sys::cactus_transcribe(
self.raw_handle(),
guard.raw_handle(),
path_ptr,
prompt_c.as_ptr(),
buf.as_mut_ptr() as *mut std::ffi::c_char,
Expand Down
4 changes: 2 additions & 2 deletions crates/cactus/src/stt/transcriber.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,11 @@ impl std::str::FromStr for StreamResult {

impl<'a> Transcriber<'a> {
pub fn new(model: &'a Model, options: &TranscribeOptions, cloud: CloudConfig) -> Result<Self> {
let _guard = model.lock_inference();
let guard = model.lock_inference();
let options_c = serialize_stream_options(options, &cloud)?;

let raw = unsafe {
cactus_sys::cactus_stream_transcribe_start(model.raw_handle(), options_c.as_ptr())
cactus_sys::cactus_stream_transcribe_start(guard.raw_handle(), options_c.as_ptr())
};

let handle = NonNull::new(raw).ok_or_else(|| {
Expand Down
4 changes: 2 additions & 2 deletions crates/cactus/src/vad.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ impl Model {
pcm: Option<&[u8]>,
options: &VadOptions,
) -> Result<VadResult> {
let _guard = self.lock_inference();
let guard = self.lock_inference();
let options_c = CString::new(serde_json::to_string(options)?)?;
let mut buf = vec![0u8; RESPONSE_BUF_SIZE];

Expand All @@ -64,7 +64,7 @@ impl Model {

let rc = unsafe {
cactus_sys::cactus_vad(
self.raw_handle(),
guard.raw_handle(),
path.map_or(std::ptr::null(), |p| p.as_ptr()),
buf.as_mut_ptr() as *mut std::ffi::c_char,
buf.len(),
Expand Down
Loading