Skip to content

Commit a44db3a

Browse files
committed
chore(gpu): stf experiment
1 parent 11a2919 commit a44db3a

File tree

4 files changed

+154
-51
lines changed

4 files changed

+154
-51
lines changed

backends/tfhe-cuda-backend/cuda/CMakeLists.txt

+26-1
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,34 @@ else()
8888
set(OPTIMIZATION_FLAGS "${OPTIMIZATION_FLAGS} -O3")
8989
endif()
9090

91+
# Fetch CPM.cmake directly from GitHub if not already present
92+
include(FetchContent)
93+
FetchContent_Declare(
94+
CPM
95+
GIT_REPOSITORY https://github.com/cpm-cmake/CPM.cmake
96+
GIT_TAG v0.38.5 # replace with the desired version or main for latest
97+
)
98+
FetchContent_MakeAvailable(CPM)
99+
100+
include(${cpm_SOURCE_DIR}/cmake/CPM.cmake)
101+
102+
# This will automatically clone CCCL from GitHub and make the exported cmake targets available
103+
cpmaddpackage(
104+
NAME
105+
CCCL
106+
GITHUB_REPOSITORY
107+
"nvidia/cccl"
108+
GIT_TAG
109+
"main"
110+
# The following is required to make the `CCCL::cudax` target available:
111+
OPTIONS
112+
"CCCL_ENABLE_UNSTABLE ON")
113+
91114
# in production, should use -arch=sm_70 --ptxas-options=-v to see register spills -lineinfo for better debugging
92115
set(CMAKE_CUDA_FLAGS
93116
"${CMAKE_CUDA_FLAGS} -ccbin ${CMAKE_CXX_COMPILER} ${OPTIMIZATION_FLAGS}\
94117
-std=c++17 --no-exceptions --expt-relaxed-constexpr -rdc=true \
95-
--use_fast_math -Xcompiler -fPIC")
118+
--use_fast_math -Xcompiler -fPIC -DCCCL_DISABLE_EXCEPTIONS -DCUDASTF_DISABLE_CODE_GENERATION")
96119

97120
set(INCLUDE_DIR include)
98121

@@ -101,6 +124,8 @@ enable_testing()
101124
add_subdirectory(tests_and_benchmarks)
102125
target_include_directories(tfhe_cuda_backend PRIVATE ${INCLUDE_DIR})
103126

127+
target_link_libraries(tfhe_cuda_backend PRIVATE CCCL::CCCL CCCL::cudax cuda)
128+
104129
# This is required for rust cargo build
105130
install(TARGETS tfhe_cuda_backend DESTINATION .)
106131

backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_multibit.cuh

+37-10
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@
1616
#include "programmable_bootstrap.cuh"
1717
#include "programmable_bootstrap_multibit.cuh"
1818
#include "types/complex/operations.cuh"
19+
#include <cuda/experimental/stf.cuh>
1920
#include <vector>
2021

22+
namespace cudastf = cuda::experimental::stf;
23+
2124
template <typename Torus, class params, sharedMemDegree SMD>
2225
__global__ void __launch_bounds__(params::degree / params::opt)
2326
device_multi_bit_programmable_bootstrap_cg_accumulate(
@@ -384,25 +387,49 @@ __host__ void host_cg_multi_bit_programmable_bootstrap(
384387
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
385388
uint32_t num_many_lut, uint32_t lut_stride) {
386389

390+
// Generate a CUDA graph if the USE_CUDA_GRAPH environment variable is set to a non-zero value
391+
const char *use_graph_env = getenv("USE_CUDA_GRAPH");
392+
393+
cudastf::context ctx(stream);
394+
if (use_graph_env && atoi(use_graph_env) != 0) {
395+
ctx = cudastf::graph_ctx(stream);
396+
}
397+
387398
auto lwe_chunk_size = buffer->lwe_chunk_size;
388399

400+
auto buffer_token = ctx.logical_token();
401+
389402
for (uint32_t lwe_offset = 0; lwe_offset < (lwe_dimension / grouping_factor);
390403
lwe_offset += lwe_chunk_size) {
391404

405+
auto key_token = ctx.logical_token();
406+
auto result_token = ctx.logical_token();
407+
392408
// Compute a keybundle
393-
execute_compute_keybundle<Torus, params>(
394-
stream, gpu_index, lwe_array_in, lwe_input_indexes, bootstrapping_key,
395-
buffer, num_samples, lwe_dimension, glwe_dimension, polynomial_size,
396-
grouping_factor, level_count, lwe_offset);
409+
ctx.task(key_token.write(), buffer_token.write())
410+
.set_symbol("compute_keybundle")
411+
->*[&](cudaStream_t stf_stream) {
412+
execute_compute_keybundle<Torus, params>(
413+
stf_stream, gpu_index, lwe_array_in, lwe_input_indexes,
414+
bootstrapping_key, buffer, num_samples, lwe_dimension,
415+
glwe_dimension, polynomial_size, grouping_factor,
416+
level_count, lwe_offset);
417+
};
397418

398419
// Accumulate
399-
execute_cg_external_product_loop<Torus, params>(
400-
stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in,
401-
lwe_input_indexes, lwe_array_out, lwe_output_indexes, buffer,
402-
num_samples, lwe_dimension, glwe_dimension, polynomial_size,
403-
grouping_factor, base_log, level_count, lwe_offset, num_many_lut,
404-
lut_stride);
420+
ctx.task(key_token.read(), buffer_token.rw(), result_token.write())
421+
.set_symbol("accumulate")
422+
->*
423+
[&](cudaStream_t stf_stream) {
424+
execute_cg_external_product_loop<Torus, params>(
425+
stf_stream, gpu_index, lut_vector, lut_vector_indexes,
426+
lwe_array_in, lwe_input_indexes, lwe_array_out,
427+
lwe_output_indexes, buffer, num_samples, lwe_dimension,
428+
glwe_dimension, polynomial_size, grouping_factor, base_log,
429+
level_count, lwe_offset, num_many_lut, lut_stride);
430+
};
405431
}
432+
ctx.finalize();
406433
}
407434

408435
// Verify if the grid size satisfies the cooperative group constraints

backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cuh

+57-30
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@
1616
#include "polynomial/polynomial_math.cuh"
1717
#include "programmable_bootstrap_cg_classic.cuh"
1818
#include "types/complex/operations.cuh"
19+
#include <cuda/experimental/stf.cuh>
1920
#include <vector>
2021

22+
namespace cudastf = cuda::experimental::stf;
23+
2124
template <typename Torus, class params>
2225
__device__ uint32_t calculates_monomial_degree(const Torus *lwe_array_group,
2326
uint32_t ggsw_idx,
@@ -683,46 +686,70 @@ __host__ void host_multi_bit_programmable_bootstrap(
683686
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
684687
uint32_t num_many_lut, uint32_t lut_stride) {
685688

689+
// Generate a CUDA graph if the USE_CUDA_GRAPH environment variable is set to a non-zero value
690+
const char *use_graph_env = getenv("USE_CUDA_GRAPH");
691+
692+
cudastf::context ctx(stream);
693+
if (use_graph_env && atoi(use_graph_env) != 0) {
694+
ctx = cudastf::graph_ctx(stream);
695+
}
696+
697+
auto buffer_token = ctx.logical_token();
686698
auto lwe_chunk_size = buffer->lwe_chunk_size;
687699

688700
for (uint32_t lwe_offset = 0; lwe_offset < (lwe_dimension / grouping_factor);
689701
lwe_offset += lwe_chunk_size) {
690702

703+
auto key_token = ctx.logical_token();
704+
auto result_token = ctx.logical_token();
705+
691706
// Compute a keybundle
692-
execute_compute_keybundle<Torus, params>(
693-
stream, gpu_index, lwe_array_in, lwe_input_indexes, bootstrapping_key,
694-
buffer, num_samples, lwe_dimension, glwe_dimension, polynomial_size,
695-
grouping_factor, level_count, lwe_offset);
707+
ctx.task(key_token.write(), buffer_token.write())
708+
.set_symbol("compute_keybundle")
709+
->*[&](cudaStream_t stf_stream) {
710+
execute_compute_keybundle<Torus, params>(
711+
stf_stream, gpu_index, lwe_array_in, lwe_input_indexes,
712+
bootstrapping_key, buffer, num_samples, lwe_dimension,
713+
glwe_dimension, polynomial_size, grouping_factor,
714+
level_count, lwe_offset);
715+
};
696716
// Accumulate
697717
uint32_t chunk_size = std::min(
698718
lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset);
699719
for (uint32_t j = 0; j < chunk_size; j++) {
700-
bool is_first_iter = (j + lwe_offset) == 0;
701-
bool is_last_iter =
702-
(j + lwe_offset) + 1 == (lwe_dimension / grouping_factor);
703-
if (is_first_iter) {
704-
execute_step_one<Torus, params, true>(
705-
stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in,
706-
lwe_input_indexes, buffer, num_samples, lwe_dimension,
707-
glwe_dimension, polynomial_size, base_log, level_count);
708-
} else {
709-
execute_step_one<Torus, params, false>(
710-
stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in,
711-
lwe_input_indexes, buffer, num_samples, lwe_dimension,
712-
glwe_dimension, polynomial_size, base_log, level_count);
713-
}
714-
715-
if (is_last_iter) {
716-
execute_step_two<Torus, params, true>(
717-
stream, gpu_index, lwe_array_out, lwe_output_indexes, buffer,
718-
num_samples, glwe_dimension, polynomial_size, level_count, j,
719-
num_many_lut, lut_stride);
720-
} else {
721-
execute_step_two<Torus, params, false>(
722-
stream, gpu_index, lwe_array_out, lwe_output_indexes, buffer,
723-
num_samples, glwe_dimension, polynomial_size, level_count, j,
724-
num_many_lut, lut_stride);
725-
}
720+
ctx.task(key_token.read(), buffer_token.rw(), result_token.rw())
721+
.set_symbol("step_one_two")
722+
->*
723+
[&](cudaStream_t stf_stream) {
724+
bool is_first_iter = (j + lwe_offset) == 0;
725+
bool is_last_iter =
726+
(j + lwe_offset) + 1 == (lwe_dimension / grouping_factor);
727+
if (is_first_iter) {
728+
execute_step_one<Torus, params, true>(
729+
stf_stream, gpu_index, lut_vector, lut_vector_indexes,
730+
lwe_array_in, lwe_input_indexes, buffer, num_samples,
731+
lwe_dimension, glwe_dimension, polynomial_size, base_log,
732+
level_count);
733+
} else {
734+
execute_step_one<Torus, params, false>(
735+
stf_stream, gpu_index, lut_vector, lut_vector_indexes,
736+
lwe_array_in, lwe_input_indexes, buffer, num_samples,
737+
lwe_dimension, glwe_dimension, polynomial_size, base_log,
738+
level_count);
739+
}
740+
741+
if (is_last_iter) {
742+
execute_step_two<Torus, params, true>(
743+
stf_stream, gpu_index, lwe_array_out, lwe_output_indexes,
744+
buffer, num_samples, glwe_dimension, polynomial_size,
745+
level_count, j, num_many_lut, lut_stride);
746+
} else {
747+
execute_step_two<Torus, params, false>(
748+
stf_stream, gpu_index, lwe_array_out, lwe_output_indexes,
749+
buffer, num_samples, glwe_dimension, polynomial_size,
750+
level_count, j, num_many_lut, lut_stride);
751+
}
752+
};
726753
}
727754
}
728755
}

backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_tbc_multibit.cuh

+34-10
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@
1616
#include "polynomial/polynomial_math.cuh"
1717
#include "programmable_bootstrap.cuh"
1818
#include "types/complex/operations.cuh"
19+
#include <cuda/experimental/stf.cuh>
1920
#include <vector>
2021

22+
namespace cudastf = cuda::experimental::stf;
23+
2124
template <typename Torus, class params, sharedMemDegree SMD>
2225
__global__ void __launch_bounds__(params::degree / params::opt)
2326
device_multi_bit_programmable_bootstrap_tbc_accumulate(
@@ -404,23 +407,44 @@ __host__ void host_tbc_multi_bit_programmable_bootstrap(
404407
uint32_t num_many_lut, uint32_t lut_stride) {
405408
cuda_set_device(gpu_index);
406409

410+
// Generate a CUDA graph if the USE_CUDA_GRAPH environment variable is set to a non-zero value
411+
const char *use_graph_env = getenv("USE_CUDA_GRAPH");
412+
413+
cudastf::context ctx(stream);
414+
if (use_graph_env && atoi(use_graph_env) != 0) {
415+
ctx = cudastf::graph_ctx(stream);
416+
}
417+
407418
auto lwe_chunk_size = buffer->lwe_chunk_size;
419+
auto buffer_token = ctx.logical_token();
408420
for (uint32_t lwe_offset = 0; lwe_offset < (lwe_dimension / grouping_factor);
409421
lwe_offset += lwe_chunk_size) {
410422

423+
auto key_token = ctx.logical_token();
424+
auto result_token = ctx.logical_token();
411425
// Compute a keybundle
412-
execute_compute_keybundle<Torus, params>(
413-
stream, gpu_index, lwe_array_in, lwe_input_indexes, bootstrapping_key,
414-
buffer, num_samples, lwe_dimension, glwe_dimension, polynomial_size,
415-
grouping_factor, level_count, lwe_offset);
426+
ctx.task(key_token.write(), buffer_token.write())
427+
.set_symbol("compute_keybundle")
428+
->*[&](cudaStream_t stf_stream) {
429+
execute_compute_keybundle<Torus, params>(
430+
stf_stream, gpu_index, lwe_array_in, lwe_input_indexes,
431+
bootstrapping_key, buffer, num_samples, lwe_dimension,
432+
glwe_dimension, polynomial_size, grouping_factor,
433+
level_count, lwe_offset);
434+
};
416435

417436
// Accumulate
418-
execute_tbc_external_product_loop<Torus, params>(
419-
stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in,
420-
lwe_input_indexes, lwe_array_out, lwe_output_indexes, buffer,
421-
num_samples, lwe_dimension, glwe_dimension, polynomial_size,
422-
grouping_factor, base_log, level_count, lwe_offset, num_many_lut,
423-
lut_stride);
437+
ctx.task(key_token.read(), buffer_token.rw(), result_token.write())
438+
.set_symbol("accumulate")
439+
->*
440+
[&](cudaStream_t stf_stream) {
441+
execute_tbc_external_product_loop<Torus, params>(
442+
stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in,
443+
lwe_input_indexes, lwe_array_out, lwe_output_indexes, buffer,
444+
num_samples, lwe_dimension, glwe_dimension, polynomial_size,
445+
grouping_factor, base_log, level_count, lwe_offset, num_many_lut,
446+
lut_stride);
447+
};
424448
}
425449
}
426450

0 commit comments

Comments
 (0)