Skip to content

Commit 7e3a5fd

Browse files
committed
feat(gpu): add necessary entry points for 128 bit compression
1 parent d9a3bd4 commit 7e3a5fd

File tree

15 files changed

+316
-50
lines changed

15 files changed

+316
-50
lines changed

backends/tfhe-cuda-backend/build.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ fn main() {
7878
"cuda/include/integer/compression/compression.h",
7979
"cuda/include/integer/integer.h",
8080
"cuda/include/zk/zk.h",
81-
"cuda/include/keyswitch.h",
81+
"cuda/include/keyswitch/keyswitch.h",
8282
"cuda/include/keyswitch/ks_enums.h",
8383
"cuda/include/linear_algebra.h",
8484
"cuda/include/fft/fft128.h",

backends/tfhe-cuda-backend/cuda/include/ciphertext.h

+5
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ void cuda_improve_noise_modulus_switch_64(
3131
void const *lwe_array_in, void const *encrypted_zeros, uint32_t lwe_size,
3232
uint32_t num_lwes, uint32_t num_zeros, double input_variance,
3333
double r_sigma, double bound, uint32_t log_modulus);
34+
35+
void cuda_glwe_sample_extract_128(
36+
void *stream, uint32_t gpu_index, void *lwe_array_out,
37+
void const *glwe_array_in, uint32_t const *nth_array, uint32_t num_nths,
38+
uint32_t lwe_per_glwe, uint32_t glwe_dimension, uint32_t polynomial_size);
3439
}
3540

3641
#endif

backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include "integer.h"
55
#include "integer/radix_ciphertext.cuh"
66
#include "integer/radix_ciphertext.h"
7-
#include "keyswitch.h"
7+
#include "keyswitch/keyswitch.h"
88
#include "pbs/programmable_bootstrap.cuh"
99
#include <cmath>
1010
#include <functional>

backends/tfhe-cuda-backend/cuda/include/keyswitch.h renamed to backends/tfhe-cuda-backend/cuda/include/keyswitch/keyswitch.h

+12
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,18 @@ void cuda_packing_keyswitch_lwe_list_to_glwe_64(
3131
uint32_t output_polynomial_size, uint32_t base_log, uint32_t level_count,
3232
uint32_t num_lwes);
3333

34+
void scratch_packing_keyswitch_lwe_list_to_glwe_128(
35+
void *stream, uint32_t gpu_index, int8_t **fp_ks_buffer,
36+
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
37+
uint32_t num_lwes, bool allocate_gpu_memory);
38+
39+
void cuda_packing_keyswitch_lwe_list_to_glwe_128(
40+
void *stream, uint32_t gpu_index, void *glwe_array_out,
41+
void const *lwe_array_in, void const *fp_ksk_array, int8_t *fp_ks_buffer,
42+
uint32_t input_lwe_dimension, uint32_t output_glwe_dimension,
43+
uint32_t output_polynomial_size, uint32_t base_log, uint32_t level_count,
44+
uint32_t num_lwes);
45+
3446
void cleanup_packing_keyswitch_lwe_list_to_glwe(void *stream,
3547
uint32_t gpu_index,
3648
int8_t **fp_ks_buffer,

backends/tfhe-cuda-backend/cuda/src/crypto/ciphertext.cu

+42
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,45 @@ void cuda_improve_noise_modulus_switch_64(
9696
static_cast<const uint64_t *>(encrypted_zeros), lwe_size, num_lwes,
9797
num_zeros, input_variance, r_sigma, bound, log_modulus);
9898
}
99+
100+
void cuda_glwe_sample_extract_128(
101+
void *stream, uint32_t gpu_index, void *lwe_array_out,
102+
void const *glwe_array_in, uint32_t const *nth_array, uint32_t num_nths,
103+
uint32_t lwe_per_glwe, uint32_t glwe_dimension, uint32_t polynomial_size) {
104+
105+
switch (polynomial_size) {
106+
case 256:
107+
host_sample_extract<__uint128_t, AmortizedDegree<256>>(
108+
static_cast<cudaStream_t>(stream), gpu_index,
109+
(__uint128_t *)lwe_array_out, (__uint128_t const *)glwe_array_in,
110+
(uint32_t const *)nth_array, num_nths, lwe_per_glwe, glwe_dimension);
111+
break;
112+
case 512:
113+
host_sample_extract<__uint128_t, AmortizedDegree<512>>(
114+
static_cast<cudaStream_t>(stream), gpu_index,
115+
(__uint128_t *)lwe_array_out, (__uint128_t const *)glwe_array_in,
116+
(uint32_t const *)nth_array, num_nths, lwe_per_glwe, glwe_dimension);
117+
break;
118+
case 1024:
119+
host_sample_extract<__uint128_t, AmortizedDegree<1024>>(
120+
static_cast<cudaStream_t>(stream), gpu_index,
121+
(__uint128_t *)lwe_array_out, (__uint128_t const *)glwe_array_in,
122+
(uint32_t const *)nth_array, num_nths, lwe_per_glwe, glwe_dimension);
123+
break;
124+
case 2048:
125+
host_sample_extract<__uint128_t, AmortizedDegree<2048>>(
126+
static_cast<cudaStream_t>(stream), gpu_index,
127+
(__uint128_t *)lwe_array_out, (__uint128_t const *)glwe_array_in,
128+
(uint32_t const *)nth_array, num_nths, lwe_per_glwe, glwe_dimension);
129+
break;
130+
case 4096:
131+
host_sample_extract<__uint128_t, AmortizedDegree<4096>>(
132+
static_cast<cudaStream_t>(stream), gpu_index,
133+
(__uint128_t *)lwe_array_out, (__uint128_t const *)glwe_array_in,
134+
(uint32_t const *)nth_array, num_nths, lwe_per_glwe, glwe_dimension);
135+
break;
136+
default:
137+
PANIC("Cuda error: unsupported polynomial size. Supported "
138+
"N's are powers of two in the interval [256..4096].")
139+
}
140+
}

backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cu

+31-5
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
#include "fast_packing_keyswitch.cuh"
21
#include "keyswitch.cuh"
3-
#include "keyswitch.h"
4-
#include <cstdint>
5-
#include <stdio.h>
2+
#include "keyswitch/keyswitch.h"
3+
#include "packing_keyswitch.cuh"
64

75
/* Perform keyswitch on a batch of 32 bits input LWE ciphertexts.
86
* Head out to the equivalent operation on 64 bits for more details.
@@ -73,7 +71,7 @@ void cuda_packing_keyswitch_lwe_list_to_glwe_64(
7371
uint32_t output_polynomial_size, uint32_t base_log, uint32_t level_count,
7472
uint32_t num_lwes) {
7573

76-
host_fast_packing_keyswitch_lwe_list_to_glwe<uint64_t, ulonglong4>(
74+
host_packing_keyswitch_lwe_list_to_glwe<uint64_t>(
7775
static_cast<cudaStream_t>(stream), gpu_index,
7876
static_cast<uint64_t *>(glwe_array_out),
7977
static_cast<const uint64_t *>(lwe_array_in),
@@ -90,3 +88,31 @@ void cleanup_packing_keyswitch_lwe_list_to_glwe(void *stream,
9088
static_cast<cudaStream_t>(stream),
9189
gpu_index, gpu_memory_allocated);
9290
}
91+
92+
void scratch_packing_keyswitch_lwe_list_to_glwe_128(
93+
void *stream, uint32_t gpu_index, int8_t **fp_ks_buffer,
94+
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
95+
uint32_t num_lwes, bool allocate_gpu_memory) {
96+
scratch_packing_keyswitch_lwe_list_to_glwe<__uint128_t>(
97+
static_cast<cudaStream_t>(stream), gpu_index, fp_ks_buffer, lwe_dimension,
98+
glwe_dimension, polynomial_size, num_lwes, allocate_gpu_memory);
99+
}
100+
101+
/* Perform functional packing keyswitch on a batch of 64 bits input LWE
102+
* ciphertexts.
103+
*/
104+
105+
void cuda_packing_keyswitch_lwe_list_to_glwe_128(
106+
void *stream, uint32_t gpu_index, void *glwe_array_out,
107+
void const *lwe_array_in, void const *fp_ksk_array, int8_t *fp_ks_buffer,
108+
uint32_t input_lwe_dimension, uint32_t output_glwe_dimension,
109+
uint32_t output_polynomial_size, uint32_t base_log, uint32_t level_count,
110+
uint32_t num_lwes) {
111+
host_packing_keyswitch_lwe_list_to_glwe<__uint128_t>(
112+
static_cast<cudaStream_t>(stream), gpu_index,
113+
static_cast<__uint128_t *>(glwe_array_out),
114+
static_cast<const __uint128_t *>(lwe_array_in),
115+
static_cast<const __uint128_t *>(fp_ksk_array), fp_ks_buffer,
116+
input_lwe_dimension, output_glwe_dimension, output_polynomial_size,
117+
base_log, level_count, num_lwes);
118+
}

backends/tfhe-cuda-backend/cuda/src/crypto/fast_packing_keyswitch.cuh renamed to backends/tfhe-cuda-backend/cuda/src/crypto/packing_keyswitch.cuh

+12-15
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ template <typename Torus> uint64_t get_shared_mem_size_tgemm() {
2828
// Initialize decomposition by performing rounding
2929
// and decomposing one level of an array of Torus LWEs. Only
3030
// decomposes the mask elements of the incoming LWEs.
31-
template <typename Torus, typename TorusVec>
31+
template <typename Torus>
3232
__global__ void decompose_vectorize_init(Torus const *lwe_in, Torus *lwe_out,
3333
uint32_t lwe_dimension,
3434
uint32_t num_lwe, uint32_t base_log,
@@ -63,7 +63,7 @@ __global__ void decompose_vectorize_init(Torus const *lwe_in, Torus *lwe_out,
6363
// Continue decomposiion of an array of Torus elements in place. Supposes
6464
// that the array contains already decomposed elements and
6565
// computes the new decomposed level in place.
66-
template <typename Torus, typename TorusVec>
66+
template <typename Torus>
6767
__global__ void
6868
decompose_vectorize_step_inplace(Torus *buffer_in, uint32_t lwe_dimension,
6969
uint32_t num_lwe, uint32_t base_log,
@@ -101,7 +101,7 @@ decompose_vectorize_step_inplace(Torus *buffer_in, uint32_t lwe_dimension,
101101
// This code is adapted by generalizing the 1d block-tiling
102102
// kernel from https://github.com/siboehm/SGEMM_CUDA
103103
// to any matrix dimension
104-
template <typename Torus, typename TorusVec>
104+
template <typename Torus>
105105
__global__ void tgemm(int M, int N, int K, const Torus *A, const Torus *B,
106106
int stride_B, Torus *C) {
107107

@@ -251,8 +251,8 @@ __global__ void polynomial_accumulate_monic_monomial_mul_many_neg_and_add_C(
251251
degree, coeffIdx, polynomial_size, 1, true);
252252
}
253253

254-
template <typename Torus, typename TorusVec>
255-
__host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
254+
template <typename Torus>
255+
__host__ void host_packing_keyswitch_lwe_list_to_glwe(
256256
cudaStream_t stream, uint32_t gpu_index, Torus *glwe_out,
257257
Torus const *lwe_array_in, Torus const *fp_ksk_array, int8_t *fp_ks_buffer,
258258
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
@@ -296,10 +296,8 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
296296
dim3 threads_decomp(BLOCK_SIZE_DECOMP, BLOCK_SIZE_DECOMP);
297297

298298
// decompose first level
299-
decompose_vectorize_init<Torus, TorusVec>
300-
<<<grid_decomp, threads_decomp, 0, stream>>>(lwe_array_in, d_mem_0,
301-
lwe_dimension, num_lwes,
302-
base_log, level_count);
299+
decompose_vectorize_init<Torus><<<grid_decomp, threads_decomp, 0, stream>>>(
300+
lwe_array_in, d_mem_0, lwe_dimension, num_lwes, base_log, level_count);
303301
check_cuda_error(cudaGetLastError());
304302

305303
// gemm to ks the individual LWEs to GLWEs
@@ -310,23 +308,22 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
310308
auto stride_KSK_buffer = glwe_accumulator_size * level_count;
311309

312310
uint32_t shared_mem_size = get_shared_mem_size_tgemm<Torus>();
313-
tgemm<Torus, TorusVec><<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
311+
tgemm<Torus><<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
314312
num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0, fp_ksk_array,
315313
stride_KSK_buffer, d_mem_1);
316314
check_cuda_error(cudaGetLastError());
317315

318316
auto ksk_block_size = glwe_accumulator_size;
319317

320318
for (int li = 1; li < level_count; ++li) {
321-
decompose_vectorize_step_inplace<Torus, TorusVec>
319+
decompose_vectorize_step_inplace<Torus>
322320
<<<grid_decomp, threads_decomp, 0, stream>>>(
323321
d_mem_0, lwe_dimension, num_lwes, base_log, level_count);
324322
check_cuda_error(cudaGetLastError());
325323

326-
tgemm<Torus, TorusVec>
327-
<<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
328-
num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0,
329-
fp_ksk_array + li * ksk_block_size, stride_KSK_buffer, d_mem_1);
324+
tgemm<Torus><<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
325+
num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0,
326+
fp_ksk_array + li * ksk_block_size, stride_KSK_buffer, d_mem_1);
330327
check_cuda_error(cudaGetLastError());
331328
}
332329

backends/tfhe-cuda-backend/cuda/src/integer/compression/compression.cuh

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
#define CUDA_INTEGER_COMPRESSION_CUH
33

44
#include "ciphertext.h"
5-
#include "crypto/fast_packing_keyswitch.cuh"
65
#include "crypto/keyswitch.cuh"
6+
#include "crypto/packing_keyswitch.cuh"
77
#include "device.h"
88
#include "integer/compression/compression.h"
99
#include "integer/compression/compression_utilities.h"
@@ -116,7 +116,7 @@ host_integer_compress(cudaStream_t const *streams, uint32_t const *gpu_indexes,
116116
while (rem_lwes > 0) {
117117
auto chunk_size = min(rem_lwes, mem_ptr->lwe_per_glwe);
118118

119-
host_fast_packing_keyswitch_lwe_list_to_glwe<Torus, ulonglong4>(
119+
host_packing_keyswitch_lwe_list_to_glwe<Torus>(
120120
streams[0], gpu_indexes[0], glwe_out, lwe_subset, fp_ksk[0],
121121
fp_ks_buffer, input_lwe_dimension, compression_params.glwe_dimension,
122122
compression_params.polynomial_size, compression_params.ks_base_log,

backends/tfhe-cuda-backend/cuda/tests_and_benchmarks/include/setup_and_teardown.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
#ifndef SETUP_AND_TEARDOWN_H
22
#define SETUP_AND_TEARDOWN_H
33

4+
#include "keyswitch/keyswitch.h"
45
#include "pbs/programmable_bootstrap.h"
56
#include "pbs/programmable_bootstrap_multibit.h"
67
#include <device.h>
7-
#include <keyswitch.h>
88
#include <utils.h>
99

1010
void programmable_bootstrap_classical_setup(

backends/tfhe-cuda-backend/src/bindings.rs

+41
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,19 @@ unsafe extern "C" {
6060
log_modulus: u32,
6161
);
6262
}
63+
unsafe extern "C" {
64+
pub fn cuda_glwe_sample_extract_128(
65+
stream: *mut ffi::c_void,
66+
gpu_index: u32,
67+
lwe_array_out: *mut ffi::c_void,
68+
glwe_array_in: *const ffi::c_void,
69+
nth_array: *const u32,
70+
num_nths: u32,
71+
lwe_per_glwe: u32,
72+
glwe_dimension: u32,
73+
polynomial_size: u32,
74+
);
75+
}
6376
pub const PBS_TYPE_MULTI_BIT: PBS_TYPE = 0;
6477
pub const PBS_TYPE_CLASSICAL: PBS_TYPE = 1;
6578
pub type PBS_TYPE = ffi::c_uint;
@@ -1429,6 +1442,34 @@ unsafe extern "C" {
14291442
num_lwes: u32,
14301443
);
14311444
}
1445+
unsafe extern "C" {
1446+
pub fn scratch_packing_keyswitch_lwe_list_to_glwe_128(
1447+
stream: *mut ffi::c_void,
1448+
gpu_index: u32,
1449+
fp_ks_buffer: *mut *mut i8,
1450+
lwe_dimension: u32,
1451+
glwe_dimension: u32,
1452+
polynomial_size: u32,
1453+
num_lwes: u32,
1454+
allocate_gpu_memory: bool,
1455+
);
1456+
}
1457+
unsafe extern "C" {
1458+
pub fn cuda_packing_keyswitch_lwe_list_to_glwe_128(
1459+
stream: *mut ffi::c_void,
1460+
gpu_index: u32,
1461+
glwe_array_out: *mut ffi::c_void,
1462+
lwe_array_in: *const ffi::c_void,
1463+
fp_ksk_array: *const ffi::c_void,
1464+
fp_ks_buffer: *mut i8,
1465+
input_lwe_dimension: u32,
1466+
output_glwe_dimension: u32,
1467+
output_polynomial_size: u32,
1468+
base_log: u32,
1469+
level_count: u32,
1470+
num_lwes: u32,
1471+
);
1472+
}
14321473
unsafe extern "C" {
14331474
pub fn cleanup_packing_keyswitch_lwe_list_to_glwe(
14341475
stream: *mut ffi::c_void,

backends/tfhe-cuda-backend/wrapper.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#include "cuda/include/integer/compression/compression.h"
44
#include "cuda/include/integer/integer.h"
55
#include "cuda/include/zk/zk.h"
6-
#include "cuda/include/keyswitch.h"
6+
#include "cuda/include/keyswitch/keyswitch.h"
77
#include "cuda/include/keyswitch/ks_enums.h"
88
#include "cuda/include/linear_algebra.h"
99
#include "cuda/include/fft/fft128.h"

tfhe/benches/core_crypto/ks_bench.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,7 @@ mod cuda {
497497
use tfhe::core_crypto::gpu::glwe_ciphertext_list::CudaGlweCiphertextList;
498498
use tfhe::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
499499
use tfhe::core_crypto::gpu::{
500-
cuda_keyswitch_lwe_ciphertext, cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext,
500+
cuda_keyswitch_lwe_ciphertext, cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_64,
501501
get_number_of_gpus, CudaStreams,
502502
};
503503
use tfhe::core_crypto::prelude::*;
@@ -796,7 +796,7 @@ mod cuda {
796796
{
797797
bench_group.bench_function(&bench_id, |b| {
798798
b.iter(|| {
799-
cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext(
799+
cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_64(
800800
gpu_keys.pksk.as_ref().unwrap(),
801801
&d_input_lwe_list,
802802
&mut d_output_glwe,
@@ -879,7 +879,7 @@ mod cuda {
879879
((i, input_lwe_list), output_glwe_list),
880880
local_stream,
881881
)| {
882-
cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext(
882+
cuda_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_64(
883883
gpu_keys_vec[i].pksk.as_ref().unwrap(),
884884
input_lwe_list,
885885
output_glwe_list,

0 commit comments

Comments
 (0)