 # limitations under the License.
 # ==============================================================================
 """Differentially private version of Keras optimizer v2."""
-from typing import Optional, Type
-import warnings

 import tensorflow as tf
-from tensorflow_privacy.privacy.dp_query import dp_query
-from tensorflow_privacy.privacy.dp_query import gaussian_query
-

-def _normalize(microbatch_gradient: tf.Tensor,
-               num_microbatches: float) -> tf.Tensor:
-  """Normalizes `microbatch_gradient` by `num_microbatches`."""
-  return tf.truediv(microbatch_gradient,
-                    tf.cast(num_microbatches, microbatch_gradient.dtype))
+from tensorflow_privacy.privacy.dp_query import gaussian_query


-def make_keras_generic_optimizer_class(
-    cls: Type[tf.keras.optimizers.Optimizer]):
-  """Returns a differentially private (DP) subclass of `cls`.
+def make_keras_optimizer_class(cls):
+  """Given a subclass of `tf.keras.optimizers.Optimizer`, returns a DP-SGD subclass of it.

   Args:
     cls: Class from which to derive a DP subclass. Should be a subclass of
       `tf.keras.optimizers.Optimizer`.
+
+  Returns:
+    A DP-SGD subclass of `cls`.
   """

   class DPOptimizerClass(cls):  # pylint: disable=empty-docstring
@@ -145,23 +138,24 @@ class DPOptimizerClass(cls): # pylint: disable=empty-docstring
     def __init__(
         self,
-        dp_sum_query: dp_query.DPQuery,
-        num_microbatches: Optional[int] = None,
-        gradient_accumulation_steps: int = 1,
+        l2_norm_clip,
+        noise_multiplier,
+        num_microbatches=None,
+        gradient_accumulation_steps=1,
         *args,  # pylint: disable=keyword-arg-before-vararg, g-doc-args
         **kwargs):
-      """Initializes the DPOptimizerClass.
+      """Initialize the DPOptimizerClass.

       Args:
-        dp_sum_query: `DPQuery` object, specifying differential privacy
-          mechanism to use.
+        l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients).
+        noise_multiplier: Ratio of the standard deviation to the clipping norm.
         num_microbatches: Number of microbatches into which each minibatch is
-          split. Default is `None` which means that number of microbatches is
-          equal to batch size (i.e. each microbatch contains exactly one
+          split. Default is `None` which means that number of microbatches
+          is equal to batch size (i.e. each microbatch contains exactly one
           example). If `gradient_accumulation_steps` is greater than 1 and
           `num_microbatches` is not `None` then the effective number of
-          microbatches is equal to `num_microbatches *
-          gradient_accumulation_steps`.
+          microbatches is equal to
+          `num_microbatches * gradient_accumulation_steps`.
         gradient_accumulation_steps: If greater than 1 then optimizer will be
           accumulating gradients for this number of optimizer steps before
           applying them to update model weights. If this argument is set to 1
           then updates will be applied on each optimizer step.
@@ -171,13 +165,13 @@ def __init__(
       """
       super().__init__(*args, **kwargs)
       self.gradient_accumulation_steps = gradient_accumulation_steps
+      self._l2_norm_clip = l2_norm_clip
+      self._noise_multiplier = noise_multiplier
       self._num_microbatches = num_microbatches
-      self._dp_sum_query = dp_sum_query
-      self._was_dp_gradients_called = False
-      # We initialize the self.`_global_state` within the gradient functions
-      # (and not here) because tensors must be initialized within the graph.
-
+      self._dp_sum_query = gaussian_query.GaussianSumQuery(
+          l2_norm_clip, l2_norm_clip * noise_multiplier)
       self._global_state = None
+      self._was_dp_gradients_called = False

     def _create_slots(self, var_list):
       super()._create_slots(var_list)  # pytype: disable=attribute-error
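
A note on the reverted constructor: the `GaussianSumQuery` it builds uses a noise standard deviation of `l2_norm_clip * noise_multiplier`. Below is a rough standalone sketch of the query lifecycle this file drives (initial state, per-record clipping and accumulation, noised sum), using only calls that appear elsewhere in this diff; the values are toy values, and the three-way unpacking of `get_noised_result` mirrors the `get_gradients` code further down (older tensorflow_privacy releases return a pair instead).

```python
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import gaussian_query

l2_norm_clip, noise_multiplier = 1.0, 1.1
query = gaussian_query.GaussianSumQuery(
    l2_norm_clip, l2_norm_clip * noise_multiplier)  # stddev = clip * multiplier

global_state = query.initial_global_state()
sample_params = query.derive_sample_params(global_state)

# Toy "records", e.g. per-microbatch gradients of a single variable.
records = [tf.constant([3.0, 4.0]), tf.constant([0.3, 0.0])]

sample_state = query.initial_sample_state(records[0])
for record in records:
  # accumulate_record clips each record to l2_norm_clip before adding it
  # to the running sum.
  sample_state = query.accumulate_record(sample_params, sample_state, record)

# Noised sum of the clipped records; unpacking follows the usage in this file.
noised_sum, global_state, _ = query.get_noised_result(sample_state, global_state)
```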
@@ -241,62 +235,66 @@ def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
       """DP-SGD version of base class method."""

       self._was_dp_gradients_called = True
-      if self._global_state is None:
-        self._global_state = self._dp_sum_query.initial_global_state()
-
       # Compute loss.
       if not callable(loss) and tape is None:
         raise ValueError('`tape` is required when a `Tensor` loss is passed.')
-
       tape = tape if tape is not None else tf.GradientTape()

-      with tape:
-        if callable(loss):
+      if callable(loss):
+        with tape:
           if not callable(var_list):
             tape.watch(var_list)

           loss = loss()
-        if self._num_microbatches is None:
-          num_microbatches = tf.shape(input=loss)[0]
-        else:
-          num_microbatches = self._num_microbatches
-        microbatch_losses = tf.reduce_mean(
-            tf.reshape(loss, [num_microbatches, -1]), axis=1)
-
-        if callable(var_list):
-          var_list = var_list()
+          if self._num_microbatches is None:
+            num_microbatches = tf.shape(input=loss)[0]
+          else:
+            num_microbatches = self._num_microbatches
+          microbatch_losses = tf.reduce_mean(
+              tf.reshape(loss, [num_microbatches, -1]), axis=1)
+
+          if callable(var_list):
+            var_list = var_list()
+      else:
+        with tape:
+          if self._num_microbatches is None:
+            num_microbatches = tf.shape(input=loss)[0]
+          else:
+            num_microbatches = self._num_microbatches
+          microbatch_losses = tf.reduce_mean(
+              tf.reshape(loss, [num_microbatches, -1]), axis=1)

       var_list = tf.nest.flatten(var_list)

-      sample_params = (
-          self._dp_sum_query.derive_sample_params(self._global_state))
-
       # Compute the per-microbatch losses using helpful jacobian method.
       with tf.keras.backend.name_scope(self._name + '/gradients'):
-        jacobian_per_var = tape.jacobian(
+        jacobian = tape.jacobian(
             microbatch_losses, var_list, unconnected_gradients='zero')

-        def process_microbatch(sample_state, microbatch_jacobians):
-          """Process one microbatch (record) with privacy helper."""
-          sample_state = self._dp_sum_query.accumulate_record(
-              sample_params, sample_state, microbatch_jacobians)
-          return sample_state
+        # Clip gradients to given l2_norm_clip.
+        def clip_gradients(g):
+          return tf.clip_by_global_norm(g, self._l2_norm_clip)[0]

-        sample_state = self._dp_sum_query.initial_sample_state(var_list)
-        for idx in range(num_microbatches):
-          microbatch_jacobians_per_var = [
-              jacobian[idx] for jacobian in jacobian_per_var
-          ]
-          sample_state = process_microbatch(sample_state,
-                                            microbatch_jacobians_per_var)
+        clipped_gradients = tf.map_fn(clip_gradients, jacobian)

-        grad_sums, self._global_state, _ = (
-            self._dp_sum_query.get_noised_result(sample_state,
-                                                 self._global_state))
-        final_grads = tf.nest.map_structure(_normalize, grad_sums,
-                                            [num_microbatches] * len(grad_sums))
-
-        return list(zip(final_grads, var_list))
+        def reduce_noise_normalize_batch(g):
+          # Sum gradients over all microbatches.
+          summed_gradient = tf.reduce_sum(g, axis=0)
+
+          # Add noise to summed gradients.
+          noise_stddev = self._l2_norm_clip * self._noise_multiplier
+          noise = tf.random.normal(
+              tf.shape(input=summed_gradient), stddev=noise_stddev)
+          noised_gradient = tf.add(summed_gradient, noise)
+
+          # Normalize by number of microbatches and return.
+          return tf.truediv(noised_gradient,
+                            tf.cast(num_microbatches, tf.float32))
+
+        final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch,
+                                                clipped_gradients)
+
+        return list(zip(final_gradients, var_list))

     def get_gradients(self, loss, params):
       """DP-SGD version of base class method."""
@@ -324,13 +322,17 @@ def process_microbatch(i, sample_state):
         sample_state = self._dp_sum_query.initial_sample_state(params)
         for idx in range(self._num_microbatches):
           sample_state = process_microbatch(idx, sample_state)
-
         grad_sums, self._global_state, _ = (
             self._dp_sum_query.get_noised_result(sample_state,
                                                  self._global_state))

-        final_grads = tf.nest.map_structure(
-            _normalize, grad_sums, [self._num_microbatches] * len(grad_sums))
+        def normalize(v):
+          try:
+            return tf.truediv(v, tf.cast(self._num_microbatches, tf.float32))
+          except TypeError:
+            return None
+
+        final_grads = tf.nest.map_structure(normalize, grad_sums)

       return final_grads

@@ -366,87 +368,7 @@ def apply_gradients(self, *args, **kwargs):
   return DPOptimizerClass


-def make_gaussian_query_optimizer_class(cls):
-  """Returns a differentially private optimizer using the `GaussianSumQuery`.
-
-  Args:
-    cls: `DPOptimizerClass`, the output of `make_keras_optimizer_class`.
-  """
-
-  def return_gaussian_query_optimizer(
-      l2_norm_clip: float,
-      noise_multiplier: float,
-      num_microbatches: Optional[int] = None,
-      gradient_accumulation_steps: int = 1,
-      *args,  # pylint: disable=keyword-arg-before-vararg, g-doc-args
-      **kwargs):
-    """Returns a `DPOptimizerClass` `cls` using the `GaussianSumQuery`.
-
-    This function is a thin wrapper around
-    `make_keras_optimizer_class.<locals>.DPOptimizerClass` which can be used to
-    apply a `GaussianSumQuery` to any `DPOptimizerClass`.
-
-    When combined with stochastic gradient descent, this creates the canonical
-    DP-SGD algorithm of "Deep Learning with Differential Privacy"
-    (see https://arxiv.org/abs/1607.00133).
-
-    Args:
-      l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients).
-      noise_multiplier: Ratio of the standard deviation to the clipping norm.
-      num_microbatches: Number of microbatches into which each minibatch is
-        split. Default is `None` which means that number of microbatches is
-        equal to batch size (i.e. each microbatch contains exactly one example).
-        If `gradient_accumulation_steps` is greater than 1 and
-        `num_microbatches` is not `None` then the effective number of
-        microbatches is equal to `num_microbatches *
-        gradient_accumulation_steps`.
-      gradient_accumulation_steps: If greater than 1 then optimizer will be
-        accumulating gradients for this number of optimizer steps before
-        applying them to update model weights. If this argument is set to 1 then
-        updates will be applied on each optimizer step.
-      *args: These will be passed on to the base class `__init__` method.
-      **kwargs: These will be passed on to the base class `__init__` method.
-    """
-    dp_sum_query = gaussian_query.GaussianSumQuery(
-        l2_norm_clip, l2_norm_clip * noise_multiplier)
-    return cls(
-        dp_sum_query=dp_sum_query,
-        num_microbatches=num_microbatches,
-        gradient_accumulation_steps=gradient_accumulation_steps,
-        *args,
-        **kwargs)
-
-  return return_gaussian_query_optimizer
-
-
-def make_keras_optimizer_class(cls: Type[tf.keras.optimizers.Optimizer]):
-  """Returns a differentially private optimizer using the `GaussianSumQuery`.
-
-  For backwards compatibility, we create this symbol to match the previous
-  output of `make_keras_optimizer_class` but using the new logic.
-
-  Args:
-    cls: Class from which to derive a DP subclass. Should be a subclass of
-      `tf.keras.optimizers.Optimizer`.
-  """
-  warnings.warn(
-      '`make_keras_optimizer_class` will be depracated on 2023-02-23. '
-      'Please switch to `make_gaussian_query_optimizer_class` and the '
-      'generic optimizers (`make_keras_generic_optimizer_class`).')
-  return make_gaussian_query_optimizer_class(
-      make_keras_generic_optimizer_class(cls))
-
-
-GenericDPAdagradOptimizer = make_keras_generic_optimizer_class(
+DPKerasAdagradOptimizer = make_keras_optimizer_class(
     tf.keras.optimizers.Adagrad)
-GenericDPAdamOptimizer = make_keras_generic_optimizer_class(
-    tf.keras.optimizers.Adam)
-GenericDPSGDOptimizer = make_keras_generic_optimizer_class(
-    tf.keras.optimizers.SGD)
-
-# We keep the same names for backwards compatibility.
-DPKerasAdagradOptimizer = make_gaussian_query_optimizer_class(
-    GenericDPAdagradOptimizer)
-DPKerasAdamOptimizer = make_gaussian_query_optimizer_class(
-    GenericDPAdamOptimizer)
-DPKerasSGDOptimizer = make_gaussian_query_optimizer_class(GenericDPSGDOptimizer)
+DPKerasAdamOptimizer = make_keras_optimizer_class(tf.keras.optimizers.Adam)
+DPKerasSGDOptimizer = make_keras_optimizer_class(tf.keras.optimizers.SGD)
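
For completeness, a sketch of typical usage of the exported optimizers, assuming this module lives at `tensorflow_privacy.privacy.optimizers.dp_optimizer_keras`; the model and hyperparameter values are illustrative. The important points, implied by the microbatch logic above, are a vector (per-example) loss via `reduction=NONE` and a batch size divisible by `num_microbatches`.

```python
import tensorflow as tf
from tensorflow_privacy.privacy.optimizers import dp_optimizer_keras

model = tf.keras.Sequential([tf.keras.layers.Dense(10)])  # toy model

optimizer = dp_optimizer_keras.DPKerasSGDOptimizer(
    l2_norm_clip=1.0,        # max L2 norm of each microbatch gradient
    noise_multiplier=1.1,    # noise stddev = l2_norm_clip * noise_multiplier
    num_microbatches=32,     # must evenly divide the batch size
    learning_rate=0.1)

# Keep the loss as one value per example so the optimizer can reshape it into
# microbatches; hence reduction=NONE rather than the default mean reduction.
loss = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

model.compile(optimizer=optimizer, loss=loss)
# model.fit(x_train, y_train, batch_size=32, epochs=1)
```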