@@ -28,7 +28,7 @@ template <typename Torus> uint64_t get_shared_mem_size_tgemm() {
 // Initialize decomposition by performing rounding
 // and decomposing one level of an array of Torus LWEs. Only
 // decomposes the mask elements of the incoming LWEs.
-template <typename Torus, typename TorusVec>
+template <typename Torus>
 __global__ void decompose_vectorize_init(Torus const *lwe_in, Torus *lwe_out,
                                          uint32_t lwe_dimension,
                                          uint32_t num_lwe, uint32_t base_log,
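For context, here is a minimal sketch of the per-coefficient arithmetic such an init kernel performs: round so only the `base_log * level_count` most significant bits survive, then peel off one signed digit. The helper names are hypothetical (they are not the kernel's), the digit-balancing trick follows the signed decomposer used elsewhere in tfhe-rs, and a 64-bit Torus with `base_log * level_count < 64` is assumed.

```cuda
#include <cstdint>

// Round the input so that only the base_log * level_count most significant
// bits are kept (round-to-nearest on the first dropped bit), then shift the
// result down so the state holds exactly the representable bits.
__host__ __device__ inline uint64_t
init_decomposer_state(uint64_t input, uint32_t base_log, uint32_t level_count) {
  uint32_t non_rep_bits = 64 - base_log * level_count; // assumed > 0
  uint64_t rounded = (input >> (non_rep_bits - 1)) + 1; // round half up
  return rounded >> 1;
}

// Emit one signed digit in [-B/2, B/2) with B = 2^base_log, updating the
// state so subsequent calls yield the remaining levels.
__host__ __device__ inline uint64_t decompose_one_level(uint64_t &state,
                                                        uint32_t base_log) {
  uint64_t mod_mask = (1ull << base_log) - 1ull;
  uint64_t digit = state & mod_mask; // raw digit in [0, B)
  state >>= base_log;                // consume base_log bits
  uint64_t carry = ((digit - 1ull) | state) & digit;
  carry >>= base_log - 1;            // 1 iff the digit must be balanced
  state += carry;                    // propagate the carry to higher levels
  return digit - (carry << base_log);
}
```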
@@ -63,7 +63,7 @@ __global__ void decompose_vectorize_init(Torus const *lwe_in, Torus *lwe_out,
 // Continue decomposition of an array of Torus elements in place. Supposes
 // that the array contains already decomposed elements and
 // computes the new decomposed level in place.
-template <typename Torus, typename TorusVec>
+template <typename Torus>
 __global__ void
 decompose_vectorize_step_inplace(Torus *buffer_in, uint32_t lwe_dimension,
                                  uint32_t num_lwe, uint32_t base_log,
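The two kernels cooperate on one running state per mask coefficient: the init kernel produces the first digit, and each `step_inplace` launch consumes `base_log` more bits of the same state. A host-side reference of that schedule (a sketch reusing the hypothetical helpers above; buffer layout details are assumptions):

```cuda
#include <cstdint>
#include <vector>

// Reference schedule (sketch): one signed digit per level, all drawn from the
// same decomposer state, mirroring init + repeated step_inplace launches.
inline std::vector<uint64_t>
decompose_all_levels(uint64_t coeff, uint32_t base_log, uint32_t level_count) {
  std::vector<uint64_t> digits;
  uint64_t state = init_decomposer_state(coeff, base_log, level_count);
  for (uint32_t level = 0; level < level_count; ++level)
    digits.push_back(decompose_one_level(state, base_log));
  return digits; // one digit (stored as Torus) per level
}
```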
@@ -101,7 +101,7 @@ decompose_vectorize_step_inplace(Torus *buffer_in, uint32_t lwe_dimension,
 // This code is adapted by generalizing the 1d block-tiling
 // kernel from https://github.com/siboehm/SGEMM_CUDA
 // to any matrix dimension
-template <typename Torus, typename TorusVec>
+template <typename Torus>
 __global__ void tgemm(int M, int N, int K, const Torus *A, const Torus *B,
                       int stride_B, Torus *C) {
 
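A naive reference of the product `tgemm` computes may help: `A` is M×K, `B` is K×N but read through a row stride `stride_B` (so `B` can be a column block of a larger key-switching-key matrix), and `C` is M×N. Whether `tgemm` overwrites or accumulates into `C` is not visible in this hunk; the per-level loop below suggests accumulation, which this sketch assumes. The CUDA kernel computes the same result with shared-memory block tiling.

```cuda
// Reference semantics (sketch) for tgemm: C += A * B, with B addressed
// through a row stride so a KSK sub-block can be used without copying it.
template <typename Torus>
void gemm_reference(int M, int N, int K, const Torus *A, const Torus *B,
                    int stride_B, Torus *C) {
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) {
      Torus acc = 0;
      for (int k = 0; k < K; ++k)
        acc += A[m * K + k] * B[k * stride_B + n];
      C[m * N + n] += acc; // accumulate: the level loop reuses the output
    }
}
```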
@@ -251,8 +251,8 @@ __global__ void polynomial_accumulate_monic_monomial_mul_many_neg_and_add_C(
       degree, coeffIdx, polynomial_size, 1, true);
 }
 
-template <typename Torus, typename TorusVec>
-__host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
+template <typename Torus>
+__host__ void host_packing_keyswitch_lwe_list_to_glwe(
     cudaStream_t stream, uint32_t gpu_index, Torus *glwe_out,
     Torus const *lwe_array_in, Torus const *fp_ksk_array, int8_t *fp_ks_buffer,
     uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
@@ -296,10 +296,8 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
   dim3 threads_decomp(BLOCK_SIZE_DECOMP, BLOCK_SIZE_DECOMP);
 
   // decompose first level
-  decompose_vectorize_init<Torus, TorusVec>
-      <<<grid_decomp, threads_decomp, 0, stream>>>(lwe_array_in, d_mem_0,
-                                                   lwe_dimension, num_lwes,
-                                                   base_log, level_count);
+  decompose_vectorize_init<Torus><<<grid_decomp, threads_decomp, 0, stream>>>(
+      lwe_array_in, d_mem_0, lwe_dimension, num_lwes, base_log, level_count);
   check_cuda_error(cudaGetLastError());
 
   // gemm to ks the individual LWEs to GLWEs
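`grid_decomp` is computed earlier in the function and does not appear in this diff. A plausible reconstruction (a pure assumption for illustration, with a stand-in constant for `BLOCK_SIZE_DECOMP`) tiles the `num_lwes` × `lwe_dimension` mask matrix with square thread blocks:

```cuda
#include <cuda_runtime.h>
#include <cstdint>

// Hypothetical launch geometry for the decomposition kernels: one thread per
// mask element, tiled in BLOCK_SIZE_DECOMP x BLOCK_SIZE_DECOMP blocks.
constexpr uint32_t kBlockSizeDecomp = 8; // stand-in for BLOCK_SIZE_DECOMP

inline dim3 make_grid_decomp(uint32_t num_lwes, uint32_t lwe_dimension) {
  auto ceil_div = [](uint32_t a, uint32_t b) { return (a + b - 1) / b; };
  return dim3(ceil_div(num_lwes, kBlockSizeDecomp),
              ceil_div(lwe_dimension, kBlockSizeDecomp));
}
```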
@@ -310,23 +308,22 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
   auto stride_KSK_buffer = glwe_accumulator_size * level_count;
 
   uint32_t shared_mem_size = get_shared_mem_size_tgemm<Torus>();
-  tgemm<Torus, TorusVec><<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
+  tgemm<Torus><<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
       num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0, fp_ksk_array,
       stride_KSK_buffer, d_mem_1);
   check_cuda_error(cudaGetLastError());
 
   auto ksk_block_size = glwe_accumulator_size;
 
   for (int li = 1; li < level_count; ++li) {
-    decompose_vectorize_step_inplace<Torus, TorusVec>
+    decompose_vectorize_step_inplace<Torus>
         <<<grid_decomp, threads_decomp, 0, stream>>>(
             d_mem_0, lwe_dimension, num_lwes, base_log, level_count);
     check_cuda_error(cudaGetLastError());
 
-    tgemm<Torus, TorusVec>
-        <<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
-            num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0,
-            fp_ksk_array + li * ksk_block_size, stride_KSK_buffer, d_mem_1);
+    tgemm<Torus><<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
+        num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0,
+        fp_ksk_array + li * ksk_block_size, stride_KSK_buffer, d_mem_1);
     check_cuda_error(cudaGetLastError());
   }
 
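Putting the pieces together: per level `li`, the loop accumulates one decomposed-digit matrix times one KSK column block into `d_mem_1`, with the column offset `li * ksk_block_size` and row stride `stride_KSK_buffer` from the code above. A compact host-side reference of that accumulation (a sketch reusing `gemm_reference` from the earlier sketch; note the device code never materializes all levels at once, since `d_mem_0` is rewritten in place, so this reference takes precomputed digits purely for clarity):

```cuda
#include <cstddef>

// End-to-end sketch: for each level li, acc += digits(li) * ksk_block(li).
// Layouts are assumptions: digits is [level][num_lwes][lwe_dimension],
// ksk packs level_count blocks of glwe_acc_size columns per KSK row.
template <typename Torus>
void packing_ks_reference(int num_lwes, int glwe_acc_size, int lwe_dimension,
                          int level_count, const Torus *digits,
                          const Torus *ksk, Torus *acc) {
  int stride = glwe_acc_size * level_count; // mirrors stride_KSK_buffer
  for (int li = 0; li < level_count; ++li)
    gemm_reference(num_lwes, glwe_acc_size, lwe_dimension,
                   digits + (size_t)li * num_lwes * lwe_dimension,
                   ksk + (size_t)li * glwe_acc_size, // mirrors ksk_block_size
                   stride, acc);
}
```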