Skip to content

Commit c429a94

Browse files
committed
candle whisper basic version completed
1 parent 0980972 commit c429a94

File tree

6 files changed

+138
-78
lines changed

6 files changed

+138
-78
lines changed

examples/candle_whisper/Cargo.toml

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,16 @@ serde = { version = "1.0", features = ["derive"] }
2626
silent = { path = "../../silent", features = ["full"] }
2727
symphonia = { version = "0.5.3", features = ["all"] }
2828
anyhow = "1.0.79"
29+
tokio = { version = "1.35.1", features = ["full"] }
30+
31+
#candle-core = { version = "0.3.2" }
32+
#candle-nn = { version = "0.3.2" }
33+
#candle-transformers = { version = "0.3.2" }
34+
# version = "0.3.2" is not working for metal
35+
candle-core = { git = "https://github.com/huggingface/candle" }
36+
candle-nn = { git = "https://github.com/huggingface/candle" }
37+
candle-transformers = { git = "https://github.com/huggingface/candle" }
2938

30-
candle-core = { version = "0.3.2" }
31-
candle-nn = { version = "0.3.2" }
32-
candle-transformers = { version = "0.3.2" }
3339
tokenizers = { version = "0.15.0", features = ["onig"] }
3440
rand = "0.8.5"
3541
serde_json = "1.0.109"

examples/candle_whisper/src/args.rs

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -26,39 +26,10 @@ pub(crate) struct Args {
2626
#[arg(long, default_value = "tiny.en")]
2727
pub(crate) model: WhichModel,
2828

29-
/// The input to be processed, in wav format, will default to `jfk.wav`. Alternatively
30-
/// this can be set to sample:jfk, sample:gb1, ... to fetch a sample from the following
31-
/// repo: https://huggingface.co/datasets/Narsil/candle_demo/
32-
#[arg(long)]
33-
pub(crate) input: String,
34-
3529
/// The seed to use when generating random samples.
3630
#[arg(long, default_value_t = 299792458)]
3731
pub(crate) seed: u64,
3832

39-
/// Enable tracing (generates a trace-timestamp.json file).
40-
#[arg(long)]
41-
tracing: bool,
42-
4333
#[arg(long)]
4434
quantized: bool,
45-
46-
/// Language.
47-
#[arg(long)]
48-
pub(crate) language: Option<String>,
49-
50-
/// Task, when no task is specified, the input tokens contain only the sot token which can
51-
/// improve things when in no-timestamp mode.
52-
#[arg(long)]
53-
pub(crate) task: Option<Task>,
54-
55-
/// Timestamps mode, this is not fully implemented yet.
56-
#[arg(long)]
57-
pub(crate) timestamps: bool,
58-
59-
/// Print the full DecodingResult structure rather than just the text.
60-
#[arg(long)]
61-
pub(crate) verbose: bool,
62-
#[arg(long, default_value_t = 0.0)]
63-
pub(crate) temperature: f64,
6435
}

examples/candle_whisper/src/handlers.rs

Lines changed: 42 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ use candle_nn::VarBuilder;
2020
use candle_transformers::models::whisper::{self as m, audio, Config};
2121
use silent::{Request, Response, Result as SilentResult, SilentError, StatusCode};
2222
use std::path::PathBuf;
23+
use std::sync::Arc;
2324
use tokenizers::Tokenizer;
25+
use tokio::sync::Mutex;
2426

2527
use crate::pcm_decode::pcm_decode;
2628
use crate::types::{CreateTranscriptionRequest, CreateTranscriptionResponse};
@@ -40,6 +42,7 @@ pub(crate) struct WhisperModel {
4042
mel_filters: Vec<f32>,
4143
device: candle::Device,
4244
}
45+
4346
pub(crate) fn init_model(args: Args) -> Result<WhisperModel> {
4447
let device = device(args.cpu)?;
4548
let model_id = args.model_id;
@@ -75,23 +78,11 @@ pub(crate) fn init_model(args: Args) -> Result<WhisperModel> {
7578
})
7679
}
7780

78-
fn handle(req: CreateTranscriptionRequest) -> CreateTranscriptionResponse {
79-
CreateTranscriptionResponse::new(vec![], req.response_format.clone())
80-
}
81-
82-
async fn create_transcription(mut req: Request) -> SilentResult<Response> {
83-
let req = req.form_data().await?.try_into().map_err(|e| {
84-
SilentError::business_error(
85-
StatusCode::BAD_REQUEST,
86-
format!("failed to parse request: {}", e),
87-
)
88-
})?;
89-
let res = handle(req);
90-
Ok(res.into())
91-
}
92-
pub(crate) fn handle1(args: Args) -> Result<()> {
93-
let mut whisper_model = init_model(args.clone())?;
94-
let input = PathBuf::from(args.input);
81+
fn handle(
82+
req: CreateTranscriptionRequest,
83+
mut whisper_model: WhisperModel,
84+
) -> Result<CreateTranscriptionResponse> {
85+
let input = req.file.path().clone();
9586

9687
let pcm_data = pcm_decode(input);
9788
let config = whisper_model.config.clone();
@@ -111,7 +102,7 @@ pub(crate) fn handle1(args: Args) -> Result<()> {
111102
let mut model = whisper_model.model.clone();
112103
let tokenizer = whisper_model.tokenizer.clone();
113104
let device = whisper_model.device.clone();
114-
let language_token = match (args.model.is_multilingual(), args.language) {
105+
let language_token = match (req.model.is_multilingual(), req.language) {
115106
(true, None) => Some(multilingual::detect_language(&mut model, &tokenizer, &mel)?),
116107
(false, None) => None,
117108
(true, Some(language)) => match token_id(&tokenizer, &format!("<|{language}|>")) {
@@ -126,16 +117,40 @@ pub(crate) fn handle1(args: Args) -> Result<()> {
126117
let mut dc = Decoder::new(
127118
model,
128119
tokenizer,
129-
args.seed,
120+
299792458,
130121
&device,
131122
language_token,
132-
args.task,
133-
args.timestamps,
134-
args.verbose,
135-
args.temperature,
123+
None,
124+
req.response_format.has_timestamps(),
125+
req.response_format.is_verbose(),
126+
req.temperature,
136127
)?;
137-
println!("starting decoding");
138-
let s = dc.run(&mel)?;
139-
println!("done: {:?}", s);
140-
Ok(())
128+
let segments = dc.run(&mel)?;
129+
Ok(CreateTranscriptionResponse::new(
130+
segments,
131+
req.response_format.clone(),
132+
))
133+
}
134+
135+
pub(crate) async fn create_transcription(mut req: Request) -> SilentResult<Response> {
136+
let whisper_model = req
137+
.configs()
138+
.get::<Arc<Mutex<WhisperModel>>>()
139+
.unwrap()
140+
.lock()
141+
.await
142+
.clone();
143+
let req = req.form_data().await?.try_into().map_err(|e| {
144+
SilentError::business_error(
145+
StatusCode::BAD_REQUEST,
146+
format!("failed to parse request: {}", e),
147+
)
148+
})?;
149+
let res = handle(req, whisper_model).map_err(|e| {
150+
SilentError::business_error(
151+
StatusCode::INTERNAL_SERVER_ERROR,
152+
format!("failed to create transcription: {}", e),
153+
)
154+
})?;
155+
Ok(res.into())
141156
}

examples/candle_whisper/src/main.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@ mod pcm_decode;
33
mod types;
44

55
use crate::args::Args;
6-
use crate::handlers::{handle1, init_model};
6+
use crate::handlers::{create_transcription, init_model};
77
use clap::Parser;
88
use silent::prelude::*;
9-
use std::path::PathBuf;
10-
use tokenizers::Tokenizer;
9+
use std::sync::Arc;
10+
use tokio::sync::Mutex;
1111

1212
mod args;
1313
mod decoder;
@@ -19,7 +19,11 @@ fn main() {
1919
logger::fmt().with_max_level(Level::INFO).init();
2020
let args = Args::parse();
2121
let mut configs = Configs::default();
22-
handle1(args).unwrap();
23-
// let route = Route::new("").get(|_req| async { Ok("hello world") });
24-
// Server::new().run(route);
22+
let whisper_model = init_model(args.clone()).expect("failed to initialize model");
23+
configs.insert(Arc::new(Mutex::new(whisper_model)));
24+
let route = Route::new("/v1/audio/transcriptions").post(create_transcription);
25+
Server::new()
26+
.with_configs(configs)
27+
.bind("0.0.0.0:8000".parse().unwrap())
28+
.run(route);
2529
}

examples/candle_whisper/src/model.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use anyhow::{bail, Error};
12
use candle_core::{self as candle, Tensor};
23
use candle_transformers::models::whisper::{self as m, Config};
34
use clap::ValueEnum;
@@ -26,6 +27,29 @@ pub(crate) enum WhichModel {
2627
DistilLargeV2,
2728
}
2829

30+
impl TryFrom<String> for WhichModel {
31+
type Error = Error;
32+
33+
fn try_from(value: String) -> Result<Self, Self::Error> {
34+
Ok(match value {
35+
s if s == "tiny" => Self::Tiny,
36+
s if s == "tiny.en" => Self::TinyEn,
37+
s if s == "base" => Self::Base,
38+
s if s == "base.en" => Self::BaseEn,
39+
s if s == "small" => Self::Small,
40+
s if s == "small.en" => Self::SmallEn,
41+
s if s == "medium" => Self::Medium,
42+
s if s == "medium.en" => Self::MediumEn,
43+
s if s == "large" => Self::Large,
44+
s if s == "large-v2" => Self::LargeV2,
45+
s if s == "large-v3" => Self::LargeV3,
46+
s if s == "distil-medium.en" => Self::DistilMediumEn,
47+
s if s == "distil-large-v2" => Self::DistilLargeV2,
48+
_ => bail!("invalid model"),
49+
})
50+
}
51+
}
52+
2953
impl WhichModel {
3054
pub(crate) fn is_multilingual(&self) -> bool {
3155
match self {
@@ -61,6 +85,7 @@ impl WhichModel {
6185
}
6286
}
6387
}
88+
6489
#[derive(Clone, Debug)]
6590
pub enum Model {
6691
Normal(m::model::Whisper),

examples/candle_whisper/src/types.rs

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use crate::decoder::Segment;
22
use crate::model::WhichModel;
3+
use anyhow::Error;
34
use serde::Deserialize;
45
use serde_json::json;
56
use silent::prelude::{FilePart, FormData};
@@ -14,11 +15,39 @@ pub(crate) enum ResponseFormat {
1415
Vtt,
1516
}
1617

18+
impl From<String> for ResponseFormat {
19+
fn from(value: String) -> Self {
20+
match value.as_str() {
21+
"json" => Self::Json,
22+
"text" => Self::Text,
23+
"srt" => Self::Srt,
24+
"verbose_json" => Self::VerboseJson,
25+
"vtt" => Self::Vtt,
26+
_ => Self::Json,
27+
}
28+
}
29+
}
30+
31+
impl ResponseFormat {
32+
pub(crate) fn is_verbose(&self) -> bool {
33+
match self {
34+
Self::Json | Self::Text => false,
35+
Self::VerboseJson | Self::Srt | Self::Vtt => true,
36+
}
37+
}
38+
pub(crate) fn has_timestamps(&self) -> bool {
39+
match self {
40+
Self::VerboseJson | Self::Json | Self::Text => false,
41+
Self::Srt | Self::Vtt => true,
42+
}
43+
}
44+
}
45+
1746
#[derive(Debug, Clone)]
1847
pub struct CreateTranscriptionRequest {
1948
// The audio file object (not file name) to transcribe, in one of these formats: wav.
2049
pub(crate) file: FilePart,
21-
// ID of the model to use. Only whisper-large-v3 is currently available.
50+
// ID of the model to use. Only large-v3 is currently available.
2251
pub(crate) model: WhichModel,
2352
// The language of the input audio. Supplying the input language in ISO-639-1 format will improve accuracy and latency.
2453
pub(crate) language: Option<String>,
@@ -42,16 +71,24 @@ impl TryFrom<&FormData> for CreateTranscriptionRequest {
4271
StatusCode::BAD_REQUEST,
4372
"file is required".to_string(),
4473
))?;
45-
let model = serde_json::from_str(&value.fields.get("model").cloned().ok_or(
46-
SilentError::business_error(StatusCode::BAD_REQUEST, "model is required".to_string()),
47-
)?)?;
48-
let response_format = serde_json::from_str(
49-
&value
50-
.fields
51-
.get("response_format")
52-
.cloned()
53-
.unwrap_or("json".to_string()),
54-
)?;
74+
let model = value
75+
.fields
76+
.get("model")
77+
.cloned()
78+
.ok_or(SilentError::business_error(
79+
StatusCode::BAD_REQUEST,
80+
"model is required".to_string(),
81+
))?
82+
.try_into()
83+
.map_err(|e: Error| {
84+
SilentError::business_error(StatusCode::BAD_REQUEST, e.to_string())
85+
})?;
86+
let response_format = value
87+
.fields
88+
.get("response_format")
89+
.cloned()
90+
.unwrap_or("json".to_string())
91+
.into();
5592
Ok(Self {
5693
file,
5794
model,
@@ -73,6 +110,7 @@ impl TryFrom<&FormData> for CreateTranscriptionRequest {
73110
})
74111
}
75112
}
113+
76114
pub struct CreateTranscriptionResponse {
77115
segments: Vec<Segment>,
78116
format: ResponseFormat,
@@ -88,18 +126,19 @@ impl From<CreateTranscriptionResponse> for Response {
88126
fn from(value: CreateTranscriptionResponse) -> Self {
89127
match value.format {
90128
ResponseFormat::Json => json!({
91-
"text": value.segments.iter().map(|s| s.text()).collect::<Vec<_>>(),
129+
"text": value.segments.iter().map(|s| s.text()).collect::<Vec<_>>().join(""),
92130
})
93131
.into(),
94132
ResponseFormat::Text => value
95133
.segments
96134
.iter()
97135
.map(|s| s.text())
98136
.collect::<Vec<_>>()
137+
.join("")
99138
.into(),
100139
ResponseFormat::Srt => !unimplemented!("Srt"),
101140
ResponseFormat::VerboseJson => json!({
102-
"text": value.segments.iter().map(|s| s.text()).collect::<Vec<_>>(),
141+
"text": value.segments.iter().map(|s| s.text()).collect::<Vec<_>>().join(""),
103142
})
104143
.into(),
105144
ResponseFormat::Vtt => !unimplemented!("Vtt"),

0 commit comments

Comments
 (0)