From e130b666425879af4b538f2441f741cc70b6f9d7 Mon Sep 17 00:00:00 2001 From: Borislav Stanimirov Date: Wed, 29 May 2024 19:09:21 +0300 Subject: [PATCH] whisper: use global cache for sin/cos vals and Hann window (#2194) - also rename Hanning to Hann as it's named after Julius von Hann as per Wikipedia --- whisper.cpp | 97 +++++++++++++++++++++++++++++------------------------ 1 file changed, 54 insertions(+), 43 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index 7b8c683fca7..a22da8896bb 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -2857,20 +2857,44 @@ static std::string to_timestamp(int64_t t, bool comma = false) { } #define SIN_COS_N_COUNT WHISPER_N_FFT -static float sin_vals[SIN_COS_N_COUNT]; -static float cos_vals[SIN_COS_N_COUNT]; +namespace { +struct whisper_global_cache { + // In FFT, we frequently use sine and cosine operations with the same values. + // We can use precalculated values to speed up the process. + float sin_vals[SIN_COS_N_COUNT]; + float cos_vals[SIN_COS_N_COUNT]; + + // Hann window (Use cosf to eliminate difference) + // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html + // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147 + float hann_window[WHISPER_N_FFT]; + float hann_window2x[WHISPER_N_FFT * 2]; + + whisper_global_cache() { + fill_sin_cos_table(); +#define FILL_HANN_WINDOW(arr) fill_hann_window(sizeof(arr) / sizeof(arr[0]), true, arr) + FILL_HANN_WINDOW(hann_window); + FILL_HANN_WINDOW(hann_window2x); + } + + void fill_sin_cos_table() { + for (int i = 0; i < SIN_COS_N_COUNT; i++) { + double theta = (2 * M_PI * i) / SIN_COS_N_COUNT; + sin_vals[i] = sinf(theta); + cos_vals[i] = cosf(theta); + } + } -// In FFT, we frequently use sine and cosine operations with the same values. -// We can use precalculated values to speed up the process. -static void fill_sin_cos_table() { - static bool is_filled = false; - if (is_filled) return; - for (int i = 0; i < SIN_COS_N_COUNT; i++) { - double theta = (2*M_PI*i)/SIN_COS_N_COUNT; - sin_vals[i] = sinf(theta); - cos_vals[i] = cosf(theta); + void fill_hann_window(int length, bool periodic, float* output) { + int offset = -1; + if (periodic) { + offset = 0; + } + for (int i = 0; i < length; i++) { + output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); + } } - is_filled = true; +} global_cache; } // naive Discrete Fourier Transform @@ -2888,8 +2912,8 @@ static void dft(const std::vector & in, std::vector & out) { for (int n = 0; n < N; n++) { int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N - re += in[n]*cos_vals[idx]; // cos(t) - im -= in[n]*sin_vals[idx]; // sin(t) + re += in[n]*global_cache.cos_vals[idx]; // cos(t) + im -= in[n]*global_cache.sin_vals[idx]; // sin(t) } out[k*2 + 0] = re; @@ -2940,8 +2964,8 @@ static void fft(const std::vector & in, std::vector & out) { const int sin_cos_step = SIN_COS_N_COUNT / N; for (int k = 0; k < N/2; k++) { int idx = k * sin_cos_step; // t = 2*M_PI*k/N - float re = cos_vals[idx]; // cos(t) - float im = -sin_vals[idx]; // sin(t) + float re = global_cache.cos_vals[idx]; // cos(t) + float im = -global_cache.sin_vals[idx]; // sin(t) float re_odd = odd_fft[2*k + 0]; float im_odd = odd_fft[2*k + 1]; @@ -2954,22 +2978,7 @@ static void fft(const std::vector & in, std::vector & out) { } } -static bool hann_window(int length, bool periodic, std::vector & output) { - if (output.size() < static_cast(length)) { - output.resize(length); - } - int offset = -1; - if (periodic) { - offset = 0; - } - for (int i = 0; i < length; i++) { - output[i] = 0.5*(1.0 - cosf((2.0*M_PI*i)/(length + offset))); - } - - return true; -} - -static void log_mel_spectrogram_worker_thread(int ith, const std::vector & hann, const std::vector & samples, +static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector & samples, int n_samples, int frame_size, int frame_step, int n_threads, const whisper_filters & filters, whisper_mel & mel) { std::vector fft_in(frame_size, 0.0); @@ -2984,7 +2993,7 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) { const int offset = i * frame_step; - // apply Hanning window (~10% faster) + // apply Hann window (~10% faster) for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) { fft_in[j] = hann[j] * samples[offset + j]; } @@ -3051,12 +3060,16 @@ static bool log_mel_spectrogram( whisper_mel & mel) { const int64_t t_start_us = ggml_time_us(); - // Hanning window (Use cosf to eliminate difference) - // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html - // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147 - std::vector hann; - hann_window(frame_size, true, hann); - + // Hann window + const float * hann = nullptr; + if (frame_size == WHISPER_N_FFT) { + hann = global_cache.hann_window; + } else if (frame_size == 2 * WHISPER_N_FFT) { + hann = global_cache.hann_window2x; + } else { + WHISPER_ASSERT(false && "Unsupported frame_size"); + return false; + } // Calculate the length of padding int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30; @@ -3086,7 +3099,7 @@ static bool log_mel_spectrogram( std::vector workers(n_threads - 1); for (int iw = 0; iw < n_threads - 1; ++iw) { workers[iw] = std::thread( - log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples_padded, + log_mel_spectrogram_worker_thread, iw + 1, hann, samples_padded, n_samples + stage_2_pad, frame_size, frame_step, n_threads, std::cref(filters), std::ref(mel)); } @@ -3246,8 +3259,6 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) { #endif struct whisper_state * whisper_init_state(whisper_context * ctx) { - fill_sin_cos_table(); - whisper_state * state = new whisper_state; state->backend = whisper_backend_init(ctx->params); @@ -7235,7 +7246,7 @@ static void whisper_exp_compute_token_level_timestamps_dtw( // operation (after median filter) // IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims // OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims - w = ggml_norm(gctx, w, 1e-9); + w = ggml_norm(gctx, w, 1e-9f); w = ggml_permute(gctx, ggml_permute(gctx, w, 2, 1, 0 ,3), 0, 2, 1, 3); // Pass median filter - this is done over AUDIO_TOKENS dimension.