Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
thewh1teagle committed Jul 12, 2024
2 parents c2e399e + b6dbfce commit ad636a6
Show file tree
Hide file tree
Showing 16 changed files with 163 additions and 75 deletions.
149 changes: 92 additions & 57 deletions examples/diarize.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
/*
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/nemo_en_speakerverification_speakernet.onnx
cargo run --example diarize
wget https://github.com/thewh1teagle/sherpa-rs/releases/download/v0.1.0/motivation.wav -O motivation.wav
cargo run --example diarize motivation.wav
*/

use eyre::{bail, Result};
Expand All @@ -11,9 +12,73 @@ use sherpa_rs::{
};
use std::io::Cursor;

/// Resolves the speaker label for one embedding, enrolling a new speaker when allowed.
///
/// * While `*speaker_counter` is `0`, the embedding is registered straight away
///   (no lookup) under the name `"speaker 0"`.
/// * While `*speaker_counter <= max_speakers`, the manager is searched with a
///   threshold of `0.5`; a hit returns the stored name, a miss registers the
///   embedding as `"speaker <counter>"`.
/// * Once the counter has passed `max_speakers`, nobody new is registered; the
///   first entry of `get_best_matches` (threshold `0.2`, at most
///   `*speaker_counter` candidates) is used, or `"unknown"` when there is none.
///
/// `speaker_counter` is incremented every time a new speaker is registered.
///
/// # Panics
///
/// Panics if `EmbeddingManager::add` returns an error.
fn get_speaker_name(
    embedding_manager: &mut embedding_manager::EmbeddingManager,
    embedding: &mut [f32],
    speaker_counter: &mut i32,
    max_speakers: i32,
) -> String {
    let current = *speaker_counter;
    let may_enroll = current == 0 || current <= max_speakers;

    if !may_enroll {
        // Speaker cap reached: fall back to the closest already-known speaker.
        return embedding_manager
            .get_best_matches(embedding, 0.2, current)
            .into_iter()
            .next()
            .map(|best| best.name)
            .unwrap_or_else(|| String::from("unknown"));
    }

    // A counter of zero skips the lookup and enrolls directly.
    if current != 0 {
        if let Some(known) = embedding_manager.search(embedding, 0.5) {
            return known;
        }
    }

    let label = format!("speaker {}", current);
    embedding_manager.add(label.clone(), embedding).unwrap();
    *speaker_counter += 1;
    label
}

/// Drains every segment currently queued in `vad`, labels it with a speaker
/// name and prints its time span.
///
/// For each queued segment the function:
/// 1. converts the segment's start offset and sample count to seconds using
///    `sample_rate`,
/// 2. computes a speaker embedding from the segment's samples,
/// 3. resolves a speaker name through `get_speaker_name`,
/// 4. prints `(<name>) start=<s>s end=<s>s`, and
/// 5. pops the segment off the detector.
///
/// # Errors
///
/// Returns the error from `compute_speaker_embedding` if embedding extraction
/// fails; the failing segment is left at the front of `vad` in that case.
fn process_speech_segment(
    vad: &mut Vad,
    sample_rate: i32,
    embedding_manager: &mut embedding_manager::EmbeddingManager,
    extractor: &mut speaker_id::EmbeddingExtractor,
    speaker_counter: &mut i32,
    max_speakers: i32,
) -> Result<()> {
    // Loop-invariant: samples-per-second as a float for the time conversions.
    let rate = sample_rate as f32;

    while !vad.is_empty() {
        let segment = vad.front();

        // Timing must be read before `segment.samples` is moved into the extractor.
        let start_sec = (segment.start as f32) / rate;
        let duration_sec = (segment.samples.len() as f32) / rate;

        // Compute the speaker embedding
        let mut embedding = extractor.compute_speaker_embedding(sample_rate, segment.samples)?;

        // Pass the manager as a plain reborrow; the previous `&mut embedding_manager`
        // created a `&mut &mut EmbeddingManager` that only worked via deref coercion.
        let name = get_speaker_name(
            embedding_manager,
            &mut embedding,
            speaker_counter,
            max_speakers,
        );
        println!(
            "({}) start={}s end={}s",
            name,
            start_sec,
            start_sec + duration_sec
        );
        vad.pop();
    }
    Ok(())
}

fn main() -> Result<()> {
// Read audio data from the file
let audio_data: &[u8] = include_bytes!("../samples/motivation.wav");
let file_path = std::env::args().nth(1).expect("Missing file path argument");
let audio_data = std::fs::read(file_path)?;
let max_speakers = 2;

let cursor = Cursor::new(audio_data);
let mut reader = hound::WavReader::new(cursor)?;
Expand Down Expand Up @@ -43,14 +108,14 @@ fn main() -> Result<()> {
let mut embedding_manager =
embedding_manager::EmbeddingManager::new(extractor.embedding_size.try_into().unwrap()); // Assuming dimension 512 for embeddings

let mut speaker_counter = 0;
let mut speaker_counter = 1;

let vad_model = "silero_vad.onnx".into();
let window_size: usize = 512;
let config = VadConfig::new(
vad_model,
0.4,
0.4,
0.5,
0.5,
0.5,
sample_rate,
window_size.try_into().unwrap(),
Expand All @@ -66,61 +131,31 @@ fn main() -> Result<()> {
vad.accept_waveform(window.to_vec()); // Convert slice to Vec
if vad.is_speech() {
while !vad.is_empty() {
let segment = vad.front();
let start_sec = (segment.start as f32) / sample_rate as f32;
let duration_sec = (segment.samples.len() as f32) / sample_rate as f32;

// Compute the speaker embedding
let mut embedding =
extractor.compute_speaker_embedding(sample_rate, segment.samples)?;

let name = if let Some(speaker_name) = embedding_manager.search(&embedding, 0.45) {
speaker_name
} else {
// Register a new speaker and add the embedding
let name = format!("speaker {}", speaker_counter);
embedding_manager.add(name.clone(), &mut embedding)?;

speaker_counter += 1;
name
};
println!(
"({}) start={}s end={}s",
name,
start_sec,
start_sec + duration_sec
);
vad.pop();
process_speech_segment(
&mut vad,
sample_rate,
&mut embedding_manager,
&mut extractor,
&mut speaker_counter,
max_speakers,
)?;
}
}

index += window_size;
}

if index < samples.len() {
let remaining_samples = &samples[index..];
vad.accept_waveform(remaining_samples.to_vec());
while !vad.is_empty() {
let segment = vad.front();
let start_sec = (segment.start as f32) / sample_rate as f32;
let duration_sec = (segment.samples.len() as f32) / sample_rate as f32;

// Compute the speaker embedding
let mut embedding =
extractor.compute_speaker_embedding(sample_rate, segment.samples)?;

let name = if let Some(speaker_name) = embedding_manager.search(&embedding, 0.45) {
speaker_name
} else {
// Register a new speaker and add the embedding
let name = format!("speaker {}", speaker_counter);
embedding_manager.add(name.clone(), &mut embedding)?;

speaker_counter += 1;
name
};
println!("({}) start={}s duration={}s", name, start_sec, duration_sec);
vad.pop();
}
vad.flush();
// process remaining
while !vad.is_empty() {
process_speech_segment(
&mut vad,
sample_rate,
&mut embedding_manager,
&mut extractor,
&mut speaker_counter,
max_speakers,
)?;
}

Ok(())
}
4 changes: 2 additions & 2 deletions examples/diarize_whisper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-o
wget https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/nemo_en_speakerverification_speakernet.onnx
tar xvf sherpa-onnx-whisper-tiny.tar.bz2
wget https://github.com/thewh1teagle/sherpa-rs/releases/download/v0.1.0/sam_altman.wav -O samples/sam_altman.wav
wget https://github.com/thewh1teagle/sherpa-rs/releases/download/v0.1.0/sam_altman.wav -O sam_altman.wav
cargo run --example diarize_whisper
*/

Expand Down Expand Up @@ -34,7 +34,7 @@ fn read_audio_file(path: &str) -> Result<(i32, Vec<f32>)> {

fn main() -> Result<()> {
// Read audio data from the file
let (sample_rate, mut samples) = read_audio_file("samples/sam_altman.wav")?;
let (sample_rate, mut samples) = read_audio_file("sam_altman.wav")?;

// Pad with 3 seconds of silence so the VAD is able to detect the end of speech
for _ in 0..3 * sample_rate {
Expand Down
7 changes: 4 additions & 3 deletions examples/language_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.tar.bz2
tar xvf sherpa-onnx-whisper-tiny.tar.bz2
rm sherpa-onnx-whisper-tiny.tar.bz2
cargo run --example language_id
wget https://github.com/thewh1teagle/sherpa-rs/releases/download/v0.1.0/16hz_mono_pcm_s16le.wav -O 16hz_mono_pcm_s16le.wav
cargo run --example language_id 16hz_mono_pcm_s16le.wav
*/

use eyre::{bail, Result};
use sherpa_rs::language_id;
use std::io::Cursor;

fn main() -> Result<()> {
// Read audio data from the file
let audio_data: &[u8] = include_bytes!("../samples/16hz_mono_pcm_s16le.wav");
let file_path = std::env::args().nth(1).expect("Missing file path argument");
let audio_data = std::fs::read(file_path)?;

let cursor = Cursor::new(audio_data);
let mut reader = hound::WavReader::new(cursor)?;
Expand Down
7 changes: 4 additions & 3 deletions examples/speaker_embedding.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/*
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/nemo_en_speakerverification_speakernet.onnx
cargo run --example speaker_embedding
wget https://github.com/thewh1teagle/sherpa-rs/releases/download/v0.1.0/16hz_mono_pcm_s16le.wav -O 16hz_mono_pcm_s16le.wav
cargo run --example speaker_embedding 16hz_mono_pcm_s16le.wav
*/

use eyre::{bail, Result};
Expand All @@ -9,8 +10,8 @@ use std::io::Cursor;
use std::path::PathBuf;

fn main() -> Result<()> {
// Read audio data from the file
let audio_data: &[u8] = include_bytes!("../samples/16hz_mono_pcm_s16le.wav");
let file_path = std::env::args().nth(1).expect("Missing file path argument");
let audio_data = std::fs::read(file_path)?;

// Use Cursor to create a reader from the byte slice
let cursor = Cursor::new(audio_data);
Expand Down
9 changes: 3 additions & 6 deletions examples/speaker_id.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
/*
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/nemo_en_speakerverification_speakernet.onnx
wget https://github.com/thewh1teagle/sherpa-rs/releases/download/v0.1.0/biden.wav -O biden.wav
wget https://github.com/thewh1teagle/sherpa-rs/releases/download/v0.1.0/obama.wav -O obama.wav
cargo run --example speaker_id
*/
use eyre::{bail, Result};
Expand Down Expand Up @@ -29,12 +31,7 @@ fn main() -> Result<()> {
env_logger::init();

// Define paths to the audio files
let audio_files = vec![
"samples/obama.wav",
"samples/trump.wav",
"samples/biden.wav",
"samples/biden1.wav",
];
let audio_files = vec!["samples/obama.wav", "biden.wav"];

// Create the extractor configuration and extractor
let mut model_path = PathBuf::from(std::env::current_dir()?);
Expand Down
3 changes: 2 additions & 1 deletion examples/transcribe.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/*
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.tar.bz2
tar xvf sherpa-onnx-whisper-tiny.tar.bz2
wget https://github.com/thewh1teagle/sherpa-rs/releases/download/v0.1.0/motivation.wav -O motivation.wav
cargo run --example transcribe
*/

Expand All @@ -26,7 +27,7 @@ fn read_audio_file(path: &str) -> Result<(i32, Vec<f32>)> {
}

fn main() -> Result<()> {
let (sample_rate, samples) = read_audio_file("samples/motivation.wav")?;
let (sample_rate, samples) = read_audio_file("motivation.wav")?;

// Check if the sample rate is 16000
if sample_rate != 16000 {
Expand Down
5 changes: 3 additions & 2 deletions examples/vad_segment.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
/*
wget https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx
wget https://github.com/thewh1teagle/sherpa-rs/releases/download/v0.1.0/motivation.wav -O motivation.wav
cargo run --example vad_segment motivation.wav
*/
use eyre::{bail, Result};
use sherpa_rs::vad::{Vad, VadConfig};
use std::io::Cursor;

fn main() -> Result<()> {
// Read audio data from the file
let audio_data: &[u8] = include_bytes!("../samples/motivation.wav");
let file_path = std::env::args().nth(1).expect("Missing file path argument");
let audio_data = std::fs::read(file_path)?;

let cursor = Cursor::new(audio_data);
let mut reader = hound::WavReader::new(cursor)?;
Expand Down
Binary file removed samples/16hz_mono_pcm_s16le.wav
Binary file not shown.
Binary file removed samples/biden.wav
Binary file not shown.
Binary file removed samples/biden1.wav
Binary file not shown.
Binary file removed samples/motivation.wav
Binary file not shown.
Binary file removed samples/obama.wav
Binary file not shown.
Binary file removed samples/trump.wav
Binary file not shown.
41 changes: 40 additions & 1 deletion src/embedding_manager.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
use eyre::{bail, Result};
use std::ffi::{CStr, CString};

#[derive(Debug)]
use crate::cstr_to_string;

#[derive(Debug, Clone)]
pub struct EmbeddingManager {
pub(crate) manager: *const sherpa_rs_sys::SherpaOnnxSpeakerEmbeddingManager,
}

/// One candidate returned by [`EmbeddingManager::get_best_matches`].
#[derive(Debug, Clone)]
pub struct SpeakerMatch {
    /// Name of the matched speaker, copied out of the C result into an owned string.
    pub name: String,
    /// Score reported by the underlying library for this candidate.
    /// NOTE(review): presumably a similarity score where higher means closer — confirm
    /// against the sherpa-onnx C API.
    pub score: f32,
}

impl EmbeddingManager {
pub fn new(dimension: i32) -> Self {
unsafe {
Expand All @@ -29,6 +37,37 @@ impl EmbeddingManager {
}
}

pub fn get_best_matches(
&mut self,
embedding: &[f32],
threshold: f32,
n: i32,
) -> Vec<SpeakerMatch> {
unsafe {
let result_ptr = sherpa_rs_sys::SherpaOnnxSpeakerEmbeddingManagerGetBestMatches(
self.manager,
embedding.to_owned().as_mut_ptr(),
threshold,
n,
);
if result_ptr.is_null() {
return Vec::new();
}
let result = result_ptr.read();

let matches_c = std::slice::from_raw_parts(result.matches, result.count as usize);
let mut matches: Vec<SpeakerMatch> = Vec::new();
for i in 0..result.count {
let match_c = matches_c[i as usize];
let name = cstr_to_string!(match_c.name);
let score = match_c.score;
matches.push(SpeakerMatch { name, score });
}
sherpa_rs_sys::SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches(result_ptr);
matches
}
}

pub fn add(&mut self, name: String, embedding: &mut [f32]) -> Result<()> {
let name_cstr = CString::new(name.clone())?;

Expand Down
7 changes: 7 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,10 @@ macro_rules! cstr {
CString::new($s).expect("Failed to create CString")
};
}

/// Copies a NUL-terminated C string into an owned Rust `String`.
///
/// Bytes that are not valid UTF-8 are replaced with U+FFFD (lossy conversion).
///
/// The expansion calls the `unsafe` function `CStr::from_ptr`, so the macro
/// must be invoked inside an `unsafe` block, and the pointer must be non-null
/// and point to a valid NUL-terminated string for the duration of the call.
#[macro_export]
macro_rules! cstr_to_string {
    ($c_str:expr) => {
        ::std::ffi::CStr::from_ptr($c_str)
            .to_string_lossy()
            .into_owned()
    };
}
6 changes: 6 additions & 0 deletions src/vad.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ impl Vad {
}
}

/// Forwards to `SherpaOnnxVoiceActivityDetectorFlush` on the wrapped detector.
///
/// NOTE(review): presumably this forces audio still buffered inside the
/// detector to be emitted as a final segment — confirm against the sherpa-onnx
/// C API documentation.
pub fn flush(&mut self) {
    // SAFETY: `self.vad` is the detector handle held by this `Vad`; it is
    // assumed to be valid for as long as `self` is alive.
    unsafe { sherpa_rs_sys::SherpaOnnxVoiceActivityDetectorFlush(self.vad) };
}

pub fn accept_waveform(&mut self, mut samples: Vec<f32>) {
let samples_ptr = samples.as_mut_ptr();
let samples_length = samples.len();
Expand Down

0 comments on commit ad636a6

Please sign in to comment.