forked from ggerganov/whisper.cpp
-
Notifications
You must be signed in to change notification settings - Fork 0
/
whisper-mel-cuda.cu
339 lines (276 loc) · 12.4 KB
/
whisper-mel-cuda.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
#define CUB_IGNORE_DEPRECATED_CPP_DIALECT
#include "whisper-mel-cuda.hpp"
#include "whisper.h"
#include <cuda.h>
#include <cuda_runtime.h>
#include <cufft.h>
#include <cublas_v2.h>
#include <cuComplex.h>
#include <cub/device/device_reduce.cuh>
#include <device_launch_parameters.h>
#include <algorithm>
#if defined(_MSC_VER)
#pragma warning(disable: 4324) // added padding
#endif
#ifndef NDEBUG
# define DO_CHECKS 1
#else
# define DO_CHECKS 0
#endif
namespace {
#if DO_CHECKS
// Translates a cuFFT status code into a readable message for CUFFT_CHECK
// (defined below). Only compiled when DO_CHECKS is enabled, i.e. when NDEBUG
// is not defined. Any code not listed maps to "Unknown error".
const char* cufftGetErrorString(cufftResult_t res) {
switch (res) {
case CUFFT_SUCCESS: return "The cuFFT operation was successful";
case CUFFT_INVALID_PLAN: return "cuFFT was passed an invalid plan handle";
case CUFFT_ALLOC_FAILED: return "cuFFT failed to allocate GPU or CPU memory";
case CUFFT_INVALID_TYPE: return "No longer used";
case CUFFT_INVALID_VALUE: return "User specified an invalid pointer or parameter";
case CUFFT_INTERNAL_ERROR: return "Driver or internal cuFFT library error";
case CUFFT_EXEC_FAILED: return "Failed to execute an FFT on the GPU";
case CUFFT_SETUP_FAILED: return "The cuFFT library failed to initialize";
case CUFFT_INVALID_SIZE: return "User specified an invalid transform size";
case CUFFT_UNALIGNED_DATA: return "No longer used";
case CUFFT_INCOMPLETE_PARAMETER_LIST: return "Missing parameters in call";
case CUFFT_INVALID_DEVICE: return "Execution of a plan was on different GPU than plan creation";
case CUFFT_PARSE_ERROR: return "Internal plan database error";
case CUFFT_NO_WORKSPACE: return "No workspace has been provided prior to plan execution";
case CUFFT_NOT_IMPLEMENTED: return "Function does not implement functionality for parameters given.";
case CUFFT_LICENSE_ERROR: return "Used in previous versions.";
case CUFFT_NOT_SUPPORTED: return "Operation is not supported for parameters given.";
default: return "Unknown error";
}
}
// CUDA_CHECK_GEN(err, success, error_fn), checked variant:
// evaluates the expression `err` exactly once. If the result differs from
// `success`, it prints the stringified expression, the file/line and
// error_fn(result) to stderr. The failure is only logged and execution
// continues afterwards; nothing aborts or throws.
// NOTE(review): fprintf is used without an explicit <cstdio> include here;
// presumably it arrives transitively via the project headers.
# define CUDA_CHECK_GEN(err, success, error_fn) \
do { \
auto err_ = (err); \
if (err_ != (success)) { \
fprintf(stderr, "%s %s:%d - %s\n", #err, __FILE__, __LINE__, error_fn(err_)); \
} \
} while (0)
#else
// Unchecked variant (NDEBUG defined): the wrapped call still executes, but
// its status value is discarded.
# define CUDA_CHECK_GEN(err, success, error_fn) err
#endif
// Per-library wrappers. Each pairs the library's success value with the
// function that turns a failure code into text.
#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublasGetStatusString)
#define CUFFT_CHECK(err) CUDA_CHECK_GEN(err, CUFFT_SUCCESS, cufftGetErrorString)
// Builds the contiguous STFT input: for frame y and sample x,
//   stft_in[y * WHISPER_N_FFT + x] = padded_samples[y * WHISPER_HOP_LENGTH + x] * hann_window[x]
//
// Launch layout: x indexes the WHISPER_N_FFT samples of a frame, y indexes
// frames. Threads outside [0, WHISPER_N_FFT) x [0, n_frames) return early, so
// any grid/block rounding stays in bounds. (Previously these guards were
// commented out and n_frames was unused.)
// Preconditions (not checked): padded_samples holds at least
// (n_frames - 1) * WHISPER_HOP_LENGTH + WHISPER_N_FFT floats, hann_window holds
// WHISPER_N_FFT floats, and stft_in holds n_frames * WHISPER_N_FFT floats.
__global__ void k_fill_stft_input(
    const float * padded_samples,
    const int n_frames,
    const float * hann_window,
    float * stft_in
) {
    const int y = int(blockIdx.y * blockDim.y + threadIdx.y);
    const int x = int(blockIdx.x * blockDim.x + threadIdx.x);
    if (y >= n_frames || x >= WHISPER_N_FFT) return;
    // size_t offsets: frame starts of very long inputs must not wrap 32-bit math
    const float * line = padded_samples + size_t(y) * WHISPER_HOP_LENGTH;
    float * out_line = stft_in + size_t(y) * WHISPER_N_FFT;
    out_line[x] = line[x] * hann_window[x];
}
// Power spectrum of the R2C FFT output: for frame y and bin x,
//   magnitudes[y * WHISPER_N_FFT_HALF + x] = re^2 + im^2
// where stft_out uses the same row-major layout (WHISPER_N_FFT_HALF complex
// bins per frame). Only the first n_frames rows are read, so callers may
// process fewer frames than the FFT produced.
//
// Launch layout: x indexes bins, y indexes frames. Threads outside
// [0, WHISPER_N_FFT_HALF) x [0, n_frames) return early. (Previously these
// guards were commented out and n_frames was unused.)
__global__ void k_calc_magnitudes(
    const cuComplex* stft_out,
    const int n_frames,
    float * magnitudes
) {
    const int y = int(blockIdx.y * blockDim.y + threadIdx.y);
    const int x = int(blockIdx.x * blockDim.x + threadIdx.x);
    if (y >= n_frames || x >= WHISPER_N_FFT_HALF) return;
    // size_t index so very long inputs cannot overflow 32-bit math
    const size_t idx = size_t(y) * WHISPER_N_FFT_HALF + x;
    const float r = stft_out[idx].x;
    const float i = stft_out[idx].y;
    magnitudes[idx] = r * r + i * i;
}
// Element-wise log-mel normalisation over n_mel values. Inside this kernel
// n_mel is the flat element count used to bound the 1-D index:
//   v   = log10(max(mel_data[i], 1e-10))
//   v   = max(v, log10(*max_val) - 8)
//   out = (v + 4) / 4
// *max_val is read from device memory on every thread. calc_log_mel fills it
// with a CUB max-reduction on the same stream before launching this kernel.
// Launch: 1-D, one thread per element. Threads with index >= n_mel do nothing.
__global__ void k_calc_log_mel(
    const float * mel_data,
    const int n_mel,
    const float * max_val,
    float * log_mel
) {
    const auto idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n_mel) {
        constexpr float kFloor = 1e-10f;
        const float raw = mel_data[idx];
        float v = log10(raw < kFloor ? kFloor : raw);
        const float clamp_lo = log10(*max_val) - 8.f;
        v = v < clamp_lo ? clamp_lo : v;
        log_mel[idx] = (v + 4) / 4;
    }
}
// Host launcher for k_fill_stft_input on `stream`. It uses one block of
// WHISPER_N_FFT threads per frame, with frames mapped to gridDim.y.
//
// gridDim.y is limited to 65535, so frames are processed in chunks of at most
// that many. Each chunk offsets the input/output pointers, so the kernel sees a
// self-contained sub-problem. Previously a single launch with
// grid(1, n_frames) failed with an invalid-configuration error once n_frames
// exceeded 65535 (about 65535 * WHISPER_HOP_LENGTH padded samples). That left
// stft_in unwritten without any report. n_frames <= 0 now launches nothing
// instead of an invalid zero-sized grid.
void fill_stft_input(
    const float * padded_samples,
    int n_frames,
    const float * hann_window,
    float * stft_in,
    cudaStream_t stream
) {
    constexpr int max_grid_y = 65535; // CUDA limit on gridDim.y
    const dim3 block(WHISPER_N_FFT, 1);
    int first = 0;
    while (first < n_frames) {
        const int count = std::min(max_grid_y, n_frames - first);
        const dim3 grid(1, count);
        k_fill_stft_input<<<grid, block, 0, stream>>>(
            padded_samples + size_t(first) * WHISPER_HOP_LENGTH,
            count,
            hann_window,
            stft_in + size_t(first) * WHISPER_N_FFT);
        // launches do not return errors; surface bad configurations here
        CUDA_CHECK(cudaGetLastError());
        first += count;
    }
}
// Host launcher for k_calc_magnitudes on `stream`. It uses one block of
// WHISPER_N_FFT_HALF threads per frame, with frames mapped to gridDim.y.
//
// As in fill_stft_input, gridDim.y is limited to 65535. Frames are therefore
// processed in chunks, and the stft_out/magnitudes pointers are offset by
// chunk (both use WHISPER_N_FFT_HALF elements per frame). A single oversized
// launch used to fail silently for long inputs. n_frames <= 0 launches nothing.
void calc_magnitudes(
    const cuComplex* stft_out,
    int n_frames,
    float * magnitudes,
    cudaStream_t stream
) {
    constexpr int max_grid_y = 65535; // CUDA limit on gridDim.y
    const dim3 block(WHISPER_N_FFT_HALF, 1);
    int first = 0;
    while (first < n_frames) {
        const int count = std::min(max_grid_y, n_frames - first);
        const size_t offset = size_t(first) * WHISPER_N_FFT_HALF;
        const dim3 grid(1, count);
        k_calc_magnitudes<<<grid, block, 0, stream>>>(stft_out + offset, count, magnitudes + offset);
        // launches do not return errors; surface bad configurations here
        CUDA_CHECK(cudaGetLastError());
        first += count;
    }
}
// Bytes reserved at the front of the log-mel scratch buffer. calc_log_mel
// writes the reduced maximum (a single float) there, and the CUB temp storage
// starts right after this prefix.
constexpr auto LOG_MEL_PREFIX_SIZE = 256;
// Size of the log-mel scratch buffer for the largest input planned for up
// front. That input is a padded buffer of 2 * WHISPER_N_SAMPLES + WHISPER_N_FFT
// samples and up to 160 mel bands. The result is the CUB
// DeviceReduce::Max requirement for that many floats plus LOG_MEL_PREFIX_SIZE.
size_t get_log_mel_temp_storage_size() {
    constexpr auto padded_sample_cap = 2 * WHISPER_N_SAMPLES + WHISPER_N_FFT;
    constexpr auto frame_cap = (padded_sample_cap - WHISPER_N_FFT) / WHISPER_HOP_LENGTH + 1;
    constexpr auto mel_cap = 160;
    constexpr auto element_cap = frame_cap * mel_cap;

    float * const no_data = nullptr;
    size_t cub_bytes = 0;
    // With a null temp-storage pointer, CUB only reports how many bytes it needs.
    cub::DeviceReduce::Max(nullptr, cub_bytes, no_data, no_data, element_cap);
    return LOG_MEL_PREFIX_SIZE + cub_bytes;
}
// Normalised log-mel of `n_mel` floats (the flat element count) from mel_data
// into log_mel, enqueued on `stream`.
//
// tempStorage layout (tempStorageSize bytes in total):
//   [0, LOG_MEL_PREFIX_SIZE)        -> float maximum of mel_data
//   [LOG_MEL_PREFIX_SIZE, size)     -> CUB DeviceReduce temp storage
// The maximum is written by cub::DeviceReduce::Max and read by k_calc_log_mel
// in stream order, so no host synchronisation is needed in between.
// CUB's status and the launch status are now checked. Previously a failed
// reduction (e.g. insufficient temp storage) went unreported, and the
// normalisation used an uninitialised maximum. n_mel <= 0 is a no-op instead
// of a zero-sized grid launch.
void calc_log_mel(
    const float * mel_data,
    int n_mel,
    void * tempStorage,
    int tempStorageSize,
    float * log_mel,
    cudaStream_t stream
) {
    if (n_mel <= 0) {
        return;
    }
    assert(tempStorageSize > LOG_MEL_PREFIX_SIZE);
    float * max_val = reinterpret_cast<float *>(tempStorage);
    void * maxTemp = reinterpret_cast<char*>(tempStorage) + LOG_MEL_PREFIX_SIZE;
    size_t nbytes = size_t(tempStorageSize - LOG_MEL_PREFIX_SIZE);
    CUDA_CHECK(cub::DeviceReduce::Max(maxTemp, nbytes, mel_data, max_val, n_mel, stream));
    constexpr int block = 256;
    const int grid = (n_mel + block - 1) / block;
    k_calc_log_mel<<<grid, block, 0, stream>>>(mel_data, n_mel, max_val, log_mel);
    CUDA_CHECK(cudaGetLastError());
}
// whisper_mel_calc implementation that runs the log-mel pipeline on a private
// CUDA stream:
//   reflect/zero padding (host) -> Hann-windowed frames -> batched R2C cuFFT ->
//   power spectrum -> mel filter bank (cuBLAS SGEMM) -> log10 + normalisation.
// Scratch buffers (cuFFT work area, CUB reduction storage) are preallocated
// for inputs of up to WHISPER_N_SAMPLES samples and 160 mel bands. calculate()
// grows them on demand for anything larger. calculate() reuses the shared
// stream and scratch buffers, so concurrent calls on one instance are unsafe.
class mel_calc_cuda : public whisper_mel_calc {
    const int m_n_mel;
    ggml_backend_t m_backend = nullptr;
    cudaStream_t m_stream = nullptr;
    cublasHandle_t m_cublas_handle = nullptr;
    float * m_hann_window = nullptr;
    // Scratch buffers are `mutable` so the const calculate() can grow them.
    mutable size_t m_cufft_workspace_size = 0;
    mutable void * m_cufft_workspace = nullptr;
    float * m_filters = nullptr;
    mutable size_t m_log_mel_temp_storage_size = 0;
    mutable void * m_log_mel_temp_storage = nullptr;

    // Grows the cuFFT work area to at least `required` bytes using a
    // stream-ordered free + alloc on m_stream. No-op if it is already big enough.
    void ensure_cufft_workspace(size_t required) const {
        if (required <= m_cufft_workspace_size) {
            return;
        }
        if (m_cufft_workspace) {
            CUDA_CHECK(cudaFreeAsync(m_cufft_workspace, m_stream));
            m_cufft_workspace = nullptr;
        }
        CUDA_CHECK(cudaMallocAsync(&m_cufft_workspace, required, m_stream));
        m_cufft_workspace_size = m_cufft_workspace ? required : 0;
    }

    // Grows the log-mel scratch buffer so calc_log_mel can reduce `n_items`
    // floats. It needs LOG_MEL_PREFIX_SIZE bytes for the maximum plus whatever
    // cub::DeviceReduce::Max reports for that many items.
    void ensure_log_mel_temp_storage(int n_items) const {
        size_t cub_bytes = 0;
        const float * no_in = nullptr;
        float * no_out = nullptr;
        CUDA_CHECK(cub::DeviceReduce::Max(nullptr, cub_bytes, no_in, no_out, n_items));
        const size_t required = cub_bytes + LOG_MEL_PREFIX_SIZE;
        if (required <= m_log_mel_temp_storage_size) {
            return;
        }
        if (m_log_mel_temp_storage) {
            CUDA_CHECK(cudaFreeAsync(m_log_mel_temp_storage, m_stream));
            m_log_mel_temp_storage = nullptr;
        }
        CUDA_CHECK(cudaMallocAsync(&m_log_mel_temp_storage, required, m_stream));
        m_log_mel_temp_storage_size = m_log_mel_temp_storage ? required : 0;
    }

public:
    // Throws std::invalid_argument if filters.n_fft != WHISPER_N_FFT_HALF.
    // Otherwise creates the stream and cuBLAS handle, and uploads the Hann
    // window and filter bank. It also preallocates the scratch buffers.
    mel_calc_cuda(ggml_backend_t backend, const whisper_filters & filters)
        : m_n_mel(filters.n_mel)
        , m_backend(backend)
    {
        if (filters.n_fft != WHISPER_N_FFT_HALF) {
            throw std::invalid_argument("MelFilters n_frames must be WHISPER_N_FFT_HALF");
        }
        assert(filters.data.size() == filters.n_mel * WHISPER_N_FFT_HALF);
        CUDA_CHECK(cudaStreamCreate(&m_stream));
        CUBLAS_CHECK(cublasCreate(&m_cublas_handle));
        CUBLAS_CHECK(cublasSetMathMode(m_cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
        CUBLAS_CHECK(cublasSetStream(m_cublas_handle, m_stream));
        // create Hann window
        {
            auto hw = whisper_mel_calc::hann_window();
            CUDA_CHECK(cudaMallocAsync(&m_hann_window, hw.len * sizeof(float), m_stream));
            CUDA_CHECK(cudaMemcpyAsync(m_hann_window, hw.data, hw.len * sizeof(float), cudaMemcpyHostToDevice, m_stream));
        }
        // create working area (estimate for WHISPER_N_SAMPLES of input; grown later if needed)
        {
            constexpr auto maxPaddedSamples = 2 * WHISPER_N_SAMPLES + WHISPER_N_FFT;
            constexpr auto maxFrames = 1 + (maxPaddedSamples - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
            CUFFT_CHECK(cufftEstimate1d(WHISPER_N_FFT, CUFFT_R2C, maxFrames, &m_cufft_workspace_size));
            CUDA_CHECK(cudaMallocAsync(&m_cufft_workspace, m_cufft_workspace_size, m_stream));
        }
        // fill filters
        {
            auto& f = filters.data;
            CUDA_CHECK(cudaMallocAsync(&m_filters, f.size() * sizeof(float), m_stream));
            CUDA_CHECK(cudaMemcpyAsync(m_filters, f.data(), f.size() * sizeof(float), cudaMemcpyHostToDevice, m_stream));
        }
        {
            m_log_mel_temp_storage_size = get_log_mel_temp_storage_size();
            CUDA_CHECK(cudaMallocAsync(&m_log_mel_temp_storage, m_log_mel_temp_storage_size, m_stream));
        }
    }

    // Owns raw device resources: copying would double-free them.
    mel_calc_cuda(const mel_calc_cuda &) = delete;
    mel_calc_cuda & operator=(const mel_calc_cuda &) = delete;

    ~mel_calc_cuda() {
        CUDA_CHECK(cudaStreamSynchronize(m_stream));
        // the cuBLAS handle was previously never destroyed (leak)
        CUBLAS_CHECK(cublasDestroy(m_cublas_handle));
        CUDA_CHECK(cudaFree(m_hann_window));
        CUDA_CHECK(cudaFree(m_cufft_workspace));
        CUDA_CHECK(cudaFree(m_filters));
        CUDA_CHECK(cudaFree(m_log_mel_temp_storage));
        CUDA_CHECK(cudaStreamDestroy(m_stream));
    }

    virtual whisper_mel calculate(whisper_span<const float> samples, int /*n_threads*/) const override {
        const size_t mirror_pad = WHISPER_N_FFT / 2;
        const size_t padded_size = samples.len + WHISPER_N_SAMPLES + WHISPER_N_FFT;

        // pad: [reflected head | samples | zeros]. The vector is value-initialised,
        // so every element not written below is already 0.
        std::vector<float> padded_samples(padded_size);
        // Reflect samples[1..n_reflect] in front of the signal. n_reflect is capped
        // at len - 1: the old code always read samples[1..mirror_pad] and overran
        // clips shorter than mirror_pad + 1 samples. Such clips now get a partial
        // reflection next to the signal, with zeros further out.
        const size_t n_reflect = samples.len > 1
            ? std::min<size_t>(mirror_pad, size_t(samples.len) - 1)
            : 0;
        if (n_reflect > 0) {
            std::reverse_copy(samples.data + 1, samples.data + 1 + n_reflect,
                              padded_samples.begin() + (mirror_pad - n_reflect)); // reflect
        }
        std::copy(samples.data, samples.data + samples.len, padded_samples.begin() + mirror_pad); // copy
        // The tail should canonically be mirrored as well; it is left as zeros,
        // i.e. the last WHISPER_N_FFT/2 samples are assumed to be zero.

        const auto n_frames = 1 + (padded_samples.size() - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
        float * cu_padded_samples = nullptr;
        CUDA_CHECK(cudaMallocAsync(&cu_padded_samples, padded_samples.size() * sizeof(float), m_stream));
        CUDA_CHECK(cudaMemcpyAsync(cu_padded_samples, padded_samples.data(), padded_samples.size() * sizeof(float), cudaMemcpyHostToDevice, m_stream));

        float * stft_in = nullptr; // contiguous buffer for stft input
        CUDA_CHECK(cudaMallocAsync(&stft_in, n_frames * WHISPER_N_FFT * sizeof(float), m_stream));
        fill_stft_input(cu_padded_samples, int(n_frames), m_hann_window, stft_in, m_stream);

        cufftComplex* stft_out;
        CUDA_CHECK(cudaMallocAsync(&stft_out, n_frames * WHISPER_N_FFT_HALF * sizeof(cufftComplex), m_stream));

        cufftHandle plan;
        CUFFT_CHECK(cufftCreate(&plan));
        CUFFT_CHECK(cufftSetAutoAllocation(plan, 0));
        {
            size_t waSize = 0;
            CUFFT_CHECK(cufftMakePlan1d(plan, WHISPER_N_FFT, CUFFT_R2C, int(n_frames), &waSize));
            // The preallocated work area was estimated for at most WHISPER_N_SAMPLES
            // of input. Longer audio needs more, and the old debug-only assert
            // could not prevent cuFFT from overrunning it in release builds.
            ensure_cufft_workspace(waSize);
            CUFFT_CHECK(cufftSetWorkArea(plan, m_cufft_workspace));
            CUFFT_CHECK(cufftSetStream(plan, m_stream));
        }
        CUFFT_CHECK(cufftExecR2C(plan, stft_in, stft_out));

        const auto n_mag_frames = n_frames - 1; // drop last frame
        float * magnitudes;
        CUDA_CHECK(cudaMallocAsync(&magnitudes, n_mag_frames * WHISPER_N_FFT_HALF * sizeof(float), m_stream));
        calc_magnitudes(stft_out, int(n_mag_frames), magnitudes, m_stream);

        // mel_data (column-major n_mag_frames x m_n_mel) = magnitudes^T * filters
        float * mel_data = nullptr;
        CUDA_CHECK(cudaMallocAsync(&mel_data, m_n_mel * n_mag_frames * sizeof(float), m_stream));
        const float fone = 1.0f, fzero = 0.0f;
        CUBLAS_CHECK(cublasSgemm(m_cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N,
            int(n_mag_frames), m_n_mel, WHISPER_N_FFT_HALF,
            &fone,
            magnitudes, WHISPER_N_FFT_HALF,
            m_filters, WHISPER_N_FFT_HALF,
            &fzero,
            mel_data, int(n_mag_frames)));

        whisper_mel ret;
        // Calculate semi-padded sample length to ensure compatibility.
        // Signed arithmetic instead of relying on size_t wrap-around for clips
        // shorter than WHISPER_N_FFT - mirror_pad samples (same result).
        const int n_len_org = 1 + (int(samples.len) + int(mirror_pad) - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
        whisper_mel_init(ret, m_backend, int(n_mag_frames), n_len_org, m_n_mel);
        assert(ggml_nbytes(ret.tensor) == m_n_mel * n_mag_frames * sizeof(float));
        // NOTE(review): assumes m_backend places ret.tensor->data in device memory,
        // since the log-mel kernel writes it directly - confirm.
        float* log_mels = reinterpret_cast<float*>(ret.tensor->data);

        const int n_log_mel_items = int(m_n_mel * n_mag_frames);
        // preallocated size only covers WHISPER_N_SAMPLES of input and <= 160 mels
        ensure_log_mel_temp_storage(n_log_mel_items);
        calc_log_mel(
            mel_data, n_log_mel_items,
            m_log_mel_temp_storage, int(m_log_mel_temp_storage_size),
            log_mels, m_stream);

        CUDA_CHECK(cudaStreamSynchronize(m_stream));

        // cleanup
        CUFFT_CHECK(cufftDestroy(plan));
        CUDA_CHECK(cudaFreeAsync(mel_data, m_stream));
        CUDA_CHECK(cudaFreeAsync(magnitudes, m_stream));
        CUDA_CHECK(cudaFreeAsync(stft_out, m_stream));
        CUDA_CHECK(cudaFreeAsync(stft_in, m_stream));
        CUDA_CHECK(cudaFreeAsync(cu_padded_samples, m_stream));
        return ret;
    }
};
}
// Factory for the CUDA mel calculator. It returns nullptr when the filter
// bank's n_fft differs from WHISPER_N_FFT_HALF, the only layout mel_calc_cuda
// accepts. Otherwise it returns a new heap-allocated mel_calc_cuda.
// NOTE(review): presumably the caller takes ownership and deletes it - confirm.
whisper_mel_calc * whisper_mel_calc_create_cuda(ggml_backend_t backend, const whisper_filters & filters) {
    const bool supported = filters.n_fft == WHISPER_N_FFT_HALF;
    return supported ? new mel_calc_cuda(backend, filters) : nullptr;
}