|
16 | 16 | #include "polynomial/polynomial_math.cuh"
|
17 | 17 | #include "programmable_bootstrap_cg_classic.cuh"
|
18 | 18 | #include "types/complex/operations.cuh"
|
| 19 | +#include <cuda/experimental/stf.cuh> |
19 | 20 | #include <vector>
|
20 | 21 |
|
| 22 | +namespace cudastf = cuda::experimental::stf; |
| 23 | + |
21 | 24 | template <typename Torus, class params>
|
22 | 25 | __device__ uint32_t calculates_monomial_degree(const Torus *lwe_array_group,
|
23 | 26 | uint32_t ggsw_idx,
|
@@ -683,46 +686,70 @@ __host__ void host_multi_bit_programmable_bootstrap(
|
683 | 686 | uint32_t base_log, uint32_t level_count, uint32_t num_samples,
|
684 | 687 | uint32_t num_many_lut, uint32_t lut_stride) {
|
685 | 688 |
|
| 689 | + // Generate a CUDA graph if the USE_CUDA_GRAPH is set to a non-null value |
| 690 | + const char *use_graph_env = getenv("USE_CUDA_GRAPH"); |
| 691 | + |
| 692 | + cudastf::context ctx(stream); |
| 693 | + if (use_graph_env && atoi(use_graph_env) != 0) { |
| 694 | + ctx = cudastf::graph_ctx(stream); |
| 695 | + } |
| 696 | + |
| 697 | + auto buffer_token = ctx.logical_token(); |
686 | 698 | auto lwe_chunk_size = buffer->lwe_chunk_size;
|
687 | 699 |
|
688 | 700 | for (uint32_t lwe_offset = 0; lwe_offset < (lwe_dimension / grouping_factor);
|
689 | 701 | lwe_offset += lwe_chunk_size) {
|
690 | 702 |
|
| 703 | + auto key_token = ctx.logical_token(); |
| 704 | + auto result_token = ctx.logical_token(); |
| 705 | + |
691 | 706 | // Compute a keybundle
|
692 |
| - execute_compute_keybundle<Torus, params>( |
693 |
| - stream, gpu_index, lwe_array_in, lwe_input_indexes, bootstrapping_key, |
694 |
| - buffer, num_samples, lwe_dimension, glwe_dimension, polynomial_size, |
695 |
| - grouping_factor, level_count, lwe_offset); |
| 707 | + ctx.task(key_token.write(), buffer_token.write()) |
| 708 | + .set_symbol("compute_keybundle") |
| 709 | + ->*[&](cudaStream_t stf_stream) { |
| 710 | + execute_compute_keybundle<Torus, params>( |
| 711 | + stf_stream, gpu_index, lwe_array_in, lwe_input_indexes, |
| 712 | + bootstrapping_key, buffer, num_samples, lwe_dimension, |
| 713 | + glwe_dimension, polynomial_size, grouping_factor, |
| 714 | + level_count, lwe_offset); |
| 715 | + }; |
696 | 716 | // Accumulate
|
697 | 717 | uint32_t chunk_size = std::min(
|
698 | 718 | lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset);
|
699 | 719 | for (uint32_t j = 0; j < chunk_size; j++) {
|
700 |
| - bool is_first_iter = (j + lwe_offset) == 0; |
701 |
| - bool is_last_iter = |
702 |
| - (j + lwe_offset) + 1 == (lwe_dimension / grouping_factor); |
703 |
| - if (is_first_iter) { |
704 |
| - execute_step_one<Torus, params, true>( |
705 |
| - stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in, |
706 |
| - lwe_input_indexes, buffer, num_samples, lwe_dimension, |
707 |
| - glwe_dimension, polynomial_size, base_log, level_count); |
708 |
| - } else { |
709 |
| - execute_step_one<Torus, params, false>( |
710 |
| - stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in, |
711 |
| - lwe_input_indexes, buffer, num_samples, lwe_dimension, |
712 |
| - glwe_dimension, polynomial_size, base_log, level_count); |
713 |
| - } |
714 |
| - |
715 |
| - if (is_last_iter) { |
716 |
| - execute_step_two<Torus, params, true>( |
717 |
| - stream, gpu_index, lwe_array_out, lwe_output_indexes, buffer, |
718 |
| - num_samples, glwe_dimension, polynomial_size, level_count, j, |
719 |
| - num_many_lut, lut_stride); |
720 |
| - } else { |
721 |
| - execute_step_two<Torus, params, false>( |
722 |
| - stream, gpu_index, lwe_array_out, lwe_output_indexes, buffer, |
723 |
| - num_samples, glwe_dimension, polynomial_size, level_count, j, |
724 |
| - num_many_lut, lut_stride); |
725 |
| - } |
| 720 | + ctx.task(key_token.read(), buffer_token.rw(), result_token.rw()) |
| 721 | + .set_symbol("step_one_two") |
| 722 | + ->* |
| 723 | + [&](cudaStream_t stf_stream) { |
| 724 | + bool is_first_iter = (j + lwe_offset) == 0; |
| 725 | + bool is_last_iter = |
| 726 | + (j + lwe_offset) + 1 == (lwe_dimension / grouping_factor); |
| 727 | + if (is_first_iter) { |
| 728 | + execute_step_one<Torus, params, true>( |
| 729 | + stf_stream, gpu_index, lut_vector, lut_vector_indexes, |
| 730 | + lwe_array_in, lwe_input_indexes, buffer, num_samples, |
| 731 | + lwe_dimension, glwe_dimension, polynomial_size, base_log, |
| 732 | + level_count); |
| 733 | + } else { |
| 734 | + execute_step_one<Torus, params, false>( |
| 735 | + stf_stream, gpu_index, lut_vector, lut_vector_indexes, |
| 736 | + lwe_array_in, lwe_input_indexes, buffer, num_samples, |
| 737 | + lwe_dimension, glwe_dimension, polynomial_size, base_log, |
| 738 | + level_count); |
| 739 | + } |
| 740 | + |
| 741 | + if (is_last_iter) { |
| 742 | + execute_step_two<Torus, params, true>( |
| 743 | + stf_stream, gpu_index, lwe_array_out, lwe_output_indexes, |
| 744 | + buffer, num_samples, glwe_dimension, polynomial_size, |
| 745 | + level_count, j, num_many_lut, lut_stride); |
| 746 | + } else { |
| 747 | + execute_step_two<Torus, params, false>( |
| 748 | + stf_stream, gpu_index, lwe_array_out, lwe_output_indexes, |
| 749 | + buffer, num_samples, glwe_dimension, polynomial_size, |
| 750 | + level_count, j, num_many_lut, lut_stride); |
| 751 | + } |
| 752 | + }; |
726 | 753 | }
|
727 | 754 | }
|
728 | 755 | }
|
|
0 commit comments