Skip to content

Commit 2996f65

Browse files
iboBiThalay
authored andcommitted
whisper : add CUDA-specific computation mel spectrograms (ggml-org#2206)
* whisper : use polymorphic class to calculate mel spectrogram * whisper : add cuda-specific mel spectrogram calculation * whisper : conditionally compile cufftGetErrorString to avoid warnings * build : add new files to makefile * ruby : add new files to conf script * build : fix typo in makefile * whisper : suppress cub warning for deprecated C++ std in whisper-mel-cuda
1 parent 3189596 commit 2996f65

File tree

8 files changed

+497
-99
lines changed

8 files changed

+497
-99
lines changed

CMakeLists.txt

+7-3
Original file line numberDiff line numberDiff line change
@@ -337,12 +337,12 @@ if (WHISPER_CUDA)
337337
if (WHISPER_STATIC)
338338
if (WIN32)
339339
# As of 12.3.1 CUDA Tookit for Windows does not offer a static cublas library
340-
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
340+
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt CUDA::cufft)
341341
else ()
342-
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
342+
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static CUDA::cufft_static)
343343
endif()
344344
else()
345-
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
345+
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt CUDA::cufft)
346346
endif()
347347

348348
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cuda_driver)
@@ -629,6 +629,10 @@ add_library(${TARGET}
629629
whisper.cpp
630630
)
631631

632+
if (WHISPER_CUDA)
633+
target_sources(${TARGET} PRIVATE whisper-mel-cuda.cu)
634+
endif()
635+
632636
include_directories (
633637
.
634638
)

Makefile

+6-3
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,8 @@ ifdef WHISPER_CUDA
286286

287287
CFLAGS += -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
288288
CXXFLAGS += -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
289-
LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L/usr/lib/wsl/lib
290-
WHISPER_OBJ += ggml-cuda.o
289+
LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lcufft -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L/usr/lib/wsl/lib
290+
WHISPER_OBJ += ggml-cuda.o whisper-mel-cuda.o
291291
WHISPER_OBJ += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu))
292292
NVCC = nvcc
293293
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=$(CUDA_ARCH_FLAG)
@@ -299,6 +299,9 @@ ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml.h ggml-backend.h ggml-backend-impl.h
299299
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
300300
endif
301301

302+
whisper-mel-cuda.o: whisper-mel-cuda.cu whisper.h ggml.h ggml-backend.h whisper-mel.hpp whisper-mel-cuda.hpp
303+
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
304+
302305
ifdef WHISPER_HIPBLAS
303306
ROCM_PATH ?= /opt/rocm
304307
HIPCC ?= $(ROCM_PATH)/bin/hipcc
@@ -404,7 +407,7 @@ ggml-quants.o: ggml-quants.c ggml.h ggml-quants.h
404407

405408
WHISPER_OBJ += ggml.o ggml-alloc.o ggml-backend.o ggml-quants.o
406409

407-
whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h
410+
whisper.o: whisper.cpp whisper.h whisper-mel.hpp ggml.h ggml-cuda.h
408411
$(CXX) $(CXXFLAGS) -c $< -o $@
409412

410413
ifndef WHISPER_COREML

bindings/ruby/ext/extconf.rb

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
require 'mkmf'
22
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.cpp')} .")
33
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.h')} .")
4+
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper-mel.hpp')} .")
45
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.h')} .")
56
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.c')} .")
67
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml-impl.h')} .")

0 commit comments

Comments
 (0)