@@ -818,6 +818,8 @@ struct whisper_state {
 
     whisper_decoder decoders[WHISPER_MAX_DECODERS];
 
+    ggml_backend_t backend = nullptr;
+
     // ggml-alloc:
     // - stores meta info about the intermediate tensors into the `meta` buffers
     // - stores the actual tensor data into the `data` buffers
@@ -2261,7 +2263,7 @@ static bool whisper_encode_internal(
         }
 
         if (!whisper_encode_external(wstate)) {
-            if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) {
+            if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
                 return false;
             }
         } else {
@@ -2284,7 +2286,7 @@ static bool whisper_encode_internal(
             return false;
         }
 
-        if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) {
+        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
             return false;
         }
     }
@@ -2300,7 +2302,7 @@ static bool whisper_encode_internal(
             return false;
         }
 
-        if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) {
+        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
             return false;
         }
     }
@@ -2801,7 +2803,7 @@ static bool whisper_decode_internal(
 
         logits = gf->nodes[gf->n_nodes - 1];
 
-        if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) {
+        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
             return false;
         }
     }
@@ -3248,6 +3250,13 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     whisper_state * state = new whisper_state;
 
+    state->backend = whisper_backend_init(ctx->params);
+    if (!state->backend) {
+        WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
+        whisper_free_state(state);
+        return nullptr;
+    }
+
     // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
     // in theory, there can be a case where this is not enough, but in practice it should always be enough
     const int factor = 3;
@@ -3684,6 +3693,8 @@ void whisper_free_state(struct whisper_state * state) {
         ggml_gallocr_free(state->alloc_cross.alloc);
         ggml_gallocr_free(state->alloc_decode.alloc);
 
+        ggml_backend_free(state->backend);
+
        // [EXPERIMENTAL] Token-level timestamps with DTW
        aheads_masks_free(state->aheads_masks);
 
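Taken together, the hunks move ownership of the ggml backend from whisper_context into whisper_state: whisper_init_state() now allocates the backend via whisper_backend_init(), whisper_free_state() releases it with ggml_backend_free(), and the encoder/decoder graph-compute calls read wstate.backend instead of wctx.backend. A minimal usage sketch of the resulting lifecycle, assuming the public whisper.h API; the model filename is only an example and error handling is trimmed for brevity:

    // sketch only: illustrates the per-state backend lifecycle introduced above
    #include "whisper.h"

    int main() {
        struct whisper_context_params cparams = whisper_context_default_params();

        // load the model without creating a default state
        struct whisper_context * ctx =
            whisper_init_from_file_with_params_no_state("ggml-base.en.bin", cparams);

        // each state now allocates (and owns) its own compute backend
        struct whisper_state * state = whisper_init_state(ctx);

        // ... run whisper_full_with_state(ctx, state, ...) here ...

        whisper_free_state(state); // frees state->backend via ggml_backend_free()
        whisper_free(ctx);
        return 0;
    }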