Skip to content

Commit 85964f1

Browse files
committed
feat(integer): improve default neg
1 parent edcf930 commit 85964f1

File tree

2 files changed

+58
-29
lines changed

2 files changed

+58
-29
lines changed

tfhe/src/integer/server_key/radix_parallel/neg.rs

+14-12
Original file line numberDiff line numberDiff line change
@@ -78,19 +78,19 @@ impl ServerKey {
7878
where
7979
T: IntegerRadixCiphertext,
8080
{
81-
let mut tmp_ctxt;
82-
83-
let ct = if ctxt.block_carries_are_empty() {
84-
ctxt
81+
if ctxt.block_carries_are_empty() {
82+
let mut result = self.bitnot(ctxt);
83+
self.scalar_add_assign_parallelized(&mut result, 1);
84+
result
85+
} else if self.is_neg_possible(ctxt).is_ok() {
86+
let mut result = self.unchecked_neg(ctxt);
87+
self.full_propagate_parallelized(&mut result);
88+
result
8589
} else {
86-
tmp_ctxt = ctxt.clone();
87-
self.full_propagate_parallelized(&mut tmp_ctxt);
88-
&tmp_ctxt
89-
};
90-
91-
let mut ct = self.unchecked_neg(ct);
92-
self.full_propagate_parallelized(&mut ct);
93-
ct
90+
let mut cleaned_ctxt = ctxt.clone();
91+
self.full_propagate_parallelized(&mut cleaned_ctxt);
92+
self.neg_parallelized(&cleaned_ctxt)
93+
}
9494
}
9595

9696
pub fn overflowing_neg_parallelized<T>(&self, ctxt: &T) -> (T, BooleanBlock)
@@ -99,6 +99,8 @@ impl ServerKey {
9999
{
100100
let mut tmp_ctxt;
101101

102+
// As we want to compute the overflow we need a truly clean state
103+
// And so we cannot avoid the full_propagate like we may in non overflowing_block
102104
let ct = if ctxt.block_carries_are_empty() {
103105
ctxt
104106
} else {

tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_neg.rs

+44-17
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ use crate::integer::server_key::radix_parallel::tests_cases_unsigned::{FunctionE
44
use crate::integer::server_key::radix_parallel::tests_unsigned::{
55
nb_tests_for_params, nb_tests_smaller_for_params,
66
panic_if_any_block_info_exceeds_max_degree_or_noise, panic_if_any_block_is_not_clean,
7-
panic_if_any_block_values_exceeds_its_degree, unsigned_modulus, CpuFunctionExecutor,
8-
ExpectedDegrees, ExpectedNoiseLevels, MAX_NB_CTXT,
7+
panic_if_any_block_values_exceeds_its_degree, random_non_zero_value, unsigned_modulus,
8+
CpuFunctionExecutor, ExpectedDegrees, ExpectedNoiseLevels, MAX_NB_CTXT,
99
};
1010
use crate::integer::tests::create_parameterized_test;
1111
use crate::integer::{BooleanBlock, IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey};
@@ -174,7 +174,7 @@ where
174174
// Default Tests
175175
//=============================================================================
176176

177-
pub(crate) fn default_neg_test<P, T>(param: P, mut executor: T)
177+
pub(crate) fn default_neg_test<P, T>(param: P, mut neg: T)
178178
where
179179
P: Into<PBSParameters>,
180180
T: for<'a> FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>,
@@ -185,28 +185,55 @@ where
185185
let cks = RadixClientKey::from((cks, NB_CTXT));
186186

187187
sks.set_deterministic_pbs_execution(true);
188-
let sks = Arc::new(sks);
188+
let sks = Arc::new(sks.clone());
189189

190190
let mut rng = rand::thread_rng();
191191

192-
let modulus = unsigned_modulus(cks.parameters().message_modulus(), NB_CTXT as u32);
192+
neg.setup(&cks, sks.clone());
193193

194-
executor.setup(&cks, sks);
194+
let cks: crate::integer::ClientKey = cks.into();
195195

196-
for _ in 0..nb_tests_smaller {
197-
let clear = rng.gen::<u64>() % modulus;
196+
for num_blocks in 1..MAX_NB_CTXT {
197+
let modulus = unsigned_modulus(cks.parameters().message_modulus(), num_blocks as u32);
198198

199-
let ctxt = cks.encrypt(clear);
200-
panic_if_any_block_is_not_clean(&ctxt, &cks);
199+
for _ in 0..nb_tests_smaller {
200+
let mut clear = rng.gen_range(0..modulus);
201+
let mut ctxt = cks.encrypt_radix(clear, num_blocks);
201202

202-
let ct_res = executor.execute(&ctxt);
203-
let tmp = executor.execute(&ctxt);
204-
assert!(ct_res.block_carries_are_empty());
205-
assert_eq!(ct_res, tmp);
203+
let ct_res = neg.execute(&ctxt);
204+
panic_if_any_block_is_not_clean(&ct_res, &cks);
206205

207-
let dec: u64 = cks.decrypt(&ct_res);
208-
let clear_result = clear.wrapping_neg() % modulus;
209-
assert_eq!(clear_result, dec);
206+
let dec_ct: u64 = cks.decrypt_radix(&ct_res);
207+
let expected = clear.wrapping_neg() % modulus;
208+
assert_eq!(
209+
dec_ct, expected,
210+
"Invalid result for neg({clear}),\n\
211+
Expected {expected}, got {dec_ct}\n\
212+
num_blocks: {num_blocks}, modulus: {modulus}"
213+
);
214+
215+
let ct_res2 = neg.execute(&ctxt);
216+
assert_eq!(ct_res, ct_res2, "Failed determinism check");
217+
218+
// Test with non clean carries
219+
let random_non_zero = random_non_zero_value(&mut rng, modulus);
220+
sks.unchecked_scalar_add_assign(&mut ctxt, random_non_zero);
221+
clear = clear.wrapping_add(random_non_zero) % modulus;
222+
223+
let ct_res = neg.execute(&ctxt);
224+
panic_if_any_block_is_not_clean(&ct_res, &cks);
225+
226+
let dec_ct: u64 = cks.decrypt_radix(&ct_res);
227+
let expected = clear.wrapping_neg() % modulus;
228+
assert_eq!(
229+
dec_ct, expected,
230+
"Invalid result for neg({clear}),\n\
231+
Expected {expected}, got {dec_ct}\n\
232+
num_blocks: {num_blocks}, modulus: {modulus}"
233+
);
234+
let ct_res2 = neg.execute(&ctxt);
235+
assert_eq!(ct_res, ct_res2, "Failed determinism check");
236+
}
210237
}
211238
}
212239

0 commit comments

Comments
 (0)