Skip to content

Commit ffef323

Browse files
authored
whisper : add CUDA-specific computation mel spectrograms (ggml-org#2206)
* whisper : use polymorphic class to calculate mel spectrogram * whisper : add cuda-specific mel spectrogram calculation * whisper : conditionally compile cufftGetErrorString to avoid warnings * build : add new files to makefile * ruby : add new files to conf script * build : fix typo in makefile * whisper : suppress cub warning for deprecated C++ std in whisper-mel-cuda
1 parent af5833e commit ffef323

File tree

8 files changed

+497
-99
lines changed

8 files changed

+497
-99
lines changed

CMakeLists.txt

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -364,12 +364,12 @@ if (WHISPER_CUDA)
364364
if (WHISPER_STATIC)
365365
if (WIN32)
366366
# As of 12.3.1 CUDA Tookit for Windows does not offer a static cublas library
367-
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
367+
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt CUDA::cufft)
368368
else ()
369-
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
369+
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static CUDA::cufft_static)
370370
endif()
371371
else()
372-
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
372+
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt CUDA::cufft)
373373
endif()
374374

375375
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cuda_driver)
@@ -679,6 +679,10 @@ add_library(${TARGET}
679679
whisper.cpp
680680
)
681681

682+
if (WHISPER_CUDA)
683+
target_sources(${TARGET} PRIVATE whisper-mel-cuda.cu)
684+
endif()
685+
682686
include_directories (
683687
.
684688
)

Makefile

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,8 @@ ifdef WHISPER_CUDA
286286

287287
CFLAGS += -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
288288
CXXFLAGS += -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
289-
LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L/usr/lib/wsl/lib
290-
WHISPER_OBJ += ggml-cuda.o
289+
LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lcufft -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L/usr/lib/wsl/lib
290+
WHISPER_OBJ += ggml-cuda.o whisper-mel-cuda.o
291291
WHISPER_OBJ += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu))
292292
NVCC = nvcc
293293
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=$(CUDA_ARCH_FLAG)
@@ -299,6 +299,9 @@ ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml.h ggml-backend.h ggml-backend-impl.h
299299
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
300300
endif
301301

302+
whisper-mel-cuda.o: whisper-mel-cuda.cu whisper.h ggml.h ggml-backend.h whisper-mel.hpp whisper-mel-cuda.hpp
303+
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
304+
302305
ifdef WHISPER_HIPBLAS
303306
ROCM_PATH ?= /opt/rocm
304307
HIPCC ?= $(ROCM_PATH)/bin/hipcc
@@ -404,7 +407,7 @@ ggml-quants.o: ggml-quants.c ggml.h ggml-quants.h
404407

405408
WHISPER_OBJ += ggml.o ggml-alloc.o ggml-backend.o ggml-quants.o
406409

407-
whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h
410+
whisper.o: whisper.cpp whisper.h whisper-mel.hpp ggml.h ggml-cuda.h
408411
$(CXX) $(CXXFLAGS) -c $< -o $@
409412

410413
ifndef WHISPER_COREML

bindings/ruby/ext/extconf.rb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
require 'mkmf'
22
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.cpp')} .")
33
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.h')} .")
4+
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper-mel.hpp')} .")
45
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.h')} .")
56
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.c')} .")
67
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml-impl.h')} .")

0 commit comments

Comments
 (0)