Skip to content

Commit

Permalink
In dp_optimizer_keras_sparse, update iterations to reflect the number of logical batches, rather than physical batches.
Browse files Browse the repository at this point in the history

Currently, when using gradient accumulation, the `iterations` variable is incremented at every physical batch, while the model variables are only updated at every logical batch (where one logical batch consists of `accumulation_steps` physical batches). This causes optimizers that explicitly depend on `iterations` (such as Adam) to behave very differently under gradient accumulation.

With this change, `iterations` is only incremented after each logical batch.

PiperOrigin-RevId: 517197044
  • Loading branch information
walidk authored and tensorflower-gardener committed Mar 16, 2023
1 parent 7ae50c5 commit 52806ba
Show file tree
Hide file tree
Showing 4 changed files with 247 additions and 20 deletions.
9 changes: 9 additions & 0 deletions tensorflow_privacy/privacy/optimizers/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,15 @@ py_test(
deps = [":dp_optimizer_keras_sparse"],
)

# Test target for dp_optimizer_keras_sparse_distributed_test.py; it depends on
# the sparse DP Keras optimizer library and is given the "long" timeout.
py_test(
    name = "dp_optimizer_keras_sparse_distributed_test",
    timeout = "long",
    srcs = ["dp_optimizer_keras_sparse_distributed_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [":dp_optimizer_keras_sparse"],
)

py_test(
name = "dp_optimizer_vectorized_test",
timeout = "long",
Expand Down
72 changes: 62 additions & 10 deletions tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ def __init__(
self._num_microbatches = num_microbatches
self._was_dp_gradients_called = False
self._noise_stddev = None
self._acc_iterations = None
if self._num_microbatches is not None:
# The loss/gradients is the mean over the microbatches so we
# divide the noise by num_microbatches too to obtain the correct
Expand All @@ -202,23 +203,37 @@ def _generate_noise(self, g):
def _create_slots(self, var_list):
  """Creates the base-class slots plus the gradient-accumulation state.

  When `gradient_accumulation_steps` is greater than 1, a 'grad_acc' slot is
  added for every variable in `var_list`, and the scalar `acc_iterations`
  weight is created if it does not exist yet.

  Args:
    var_list: Variables for which slots are created.
  """
  super()._create_slots(var_list)  # pytype: disable=attribute-error
  if self.gradient_accumulation_steps <= 1:
    # No accumulation requested: the base-class slots are all we need.
    return

  # One gradient accumulator per variable.
  for variable in var_list:
    self.add_slot(variable, 'grad_acc')

  if self._acc_iterations is not None:
    # The counter was already created by an earlier call.
    return

  # Counter used for bookkeeping when to accumulate and when to update.
  self._acc_iterations = self.add_weight(
      'acc_iterations',
      shape=[],
      trainable=False,
      dtype=tf.int64,
      aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
  )

def _prepare_local(self, var_device, var_dtype, apply_state):
super()._prepare_local(var_device, var_dtype, apply_state) # pytype: disable=attribute-error
if self.gradient_accumulation_steps > 1:
apply_update = tf.math.equal(
tf.math.floormod(self.iterations + 1,
self.gradient_accumulation_steps), 0)
tf.math.floormod(
self._acc_iterations + 1, self.gradient_accumulation_steps
),
0,
)
grad_scaler = tf.cast(1. / self.gradient_accumulation_steps, var_dtype)
apply_state[(var_device, var_dtype)].update({
'apply_update': apply_update,
'grad_scaler': grad_scaler
})

def _resource_apply(self, accum_op, grad, var, apply_state=None):
"""Help method for _resource_apply_dense and _resource_apply_sparse."""
"""Helper method for _resource_apply_dense and _resource_apply_sparse."""
if self.gradient_accumulation_steps > 1:
var_device, var_dtype = var.device, var.dtype.base_dtype
coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
Expand All @@ -235,9 +250,8 @@ def _update_grad():
tf.zeros_like(grad_acc),
use_locking=self._use_locking,
read_value=False)
accum_op(grad_acc, grad, use_locking=self._use_locking)
return tf.cond(
coefficients['apply_update'], _update_grad, lambda: tf.no_op()) # pylint: disable=unnecessary-lambda
with tf.control_dependencies([accum_op(grad_acc, grad)]):
return tf.cond(coefficients['apply_update'], _update_grad, tf.no_op)
else:
grad = tf.convert_to_tensor(grad)
grad = grad + self._generate_noise(grad)
Expand All @@ -246,9 +260,11 @@ def _update_grad():

def _resource_apply_dense(self, grad, var, apply_state=None):
"""Handles dense gradients."""
def _accum_op(grad_acc, grad, use_locking):
def _accum_op(grad_acc, grad):
return grad_acc.assign_add(
grad, use_locking=use_locking, read_value=False)
grad, use_locking=self._use_locking, read_value=False
)

return self._resource_apply(_accum_op, grad, var, apply_state)

# This method is implemented the same as that in optimizer_v2.py. We
Expand All @@ -271,13 +287,49 @@ def _deduplicate_indexed_slices(values, indices):

def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
"""Handles deduped sparse gradients."""
def _accum_op(grad_acc, sparse_delta, use_locking):
def _accum_op(grad_acc, sparse_delta):
return grad_acc.scatter_add(
sparse_delta=sparse_delta, use_locking=use_locking)
sparse_delta=sparse_delta, use_locking=self._use_locking
)

sparse_delta = tf.IndexedSlices(
values=grad, indices=indices, dense_shape=var.shape)
return self._resource_apply(_accum_op, sparse_delta, var, apply_state)

def _distributed_apply(
    self, distribution, grads_and_vars, apply_state, name
):
  """Applies gradients, then keeps `iterations` in step with logical batches.

  Without gradient accumulation the base-class op is returned unchanged. With
  accumulation, `_acc_iterations` is incremented once after the base-class
  apply op, and `iterations` is then overwritten with
  `_acc_iterations // gradient_accumulation_steps`.

  Args:
    distribution: Forwarded unchanged to the base-class implementation.
    grads_and_vars: Forwarded unchanged to the base-class implementation.
    apply_state: Forwarded unchanged to the base-class implementation.
    name: Forwarded unchanged to the base-class implementation.

  Returns:
    The base-class apply op when there is no accumulation; otherwise the op
    that assigns the recomputed value to `iterations`.
  """
  base_apply_op = super()._distributed_apply(
      distribution, grads_and_vars, apply_state, name
  )
  if self.gradient_accumulation_steps <= 1:
    # No accumulation.
    return base_apply_op

  # The original _distributed_apply increments self.iterations after each
  # call. But we want to increment it only after each logical batch is
  # processed, so optimizers that explicitly use self.iterations in their
  # updates (such as Adam) can use the correct value.
  with tf.control_dependencies([base_apply_op]):
    # Always use locking when updating the steps, so we don't under-count
    # the steps (which could invalidate privacy accounting).
    bump_op = self._acc_iterations.assign_add(
        1, use_locking=True, read_value=False
    )
    with tf.control_dependencies([bump_op]):
      logical_batches = tf.math.floordiv(
          self._acc_iterations, self.gradient_accumulation_steps
      )
      return self.iterations.assign(
          logical_batches,
          use_locking=True,
          read_value=False,
      )

def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
"""DP-SGD version of base class method."""
self._was_dp_gradients_called = True
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Copyright 2023, The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests DPSparseKerasSGDOptimizer in distributed training."""
import contextlib
import multiprocessing
import os
import sys
from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from tensorflow_privacy.privacy.optimizers import dp_optimizer_keras_sparse

# Shorthand for TensorFlow's internal distribution-strategy combinations.
ds_combinations = tf.__internal__.distribute.combinations


# Strategies that DistributedTrainingTest.test_training_works is run under.
STRATEGIES = [
    ds_combinations.one_device_strategy,
    ds_combinations.parameter_server_strategy_1worker_2ps_cpu,
]


class DistributedTrainingTest(parameterized.TestCase, tf.test.TestCase):
  """Trains with DPSparseKerasSGDOptimizer under each entry of STRATEGIES."""

  @ds_combinations.generate(
      tf.__internal__.test.combinations.combine(
          strategy=STRATEGIES, mode="eager"
      )
  )
  def test_training_works(self, strategy):
    """Fits a one-layer model and checks iterations, weight mean and spread."""
    if isinstance(strategy, tf.distribute.OneDeviceStrategy):
      strategy_scope = contextlib.nullcontext()
    else:
      strategy_scope = strategy.scope()

    def make_model():
      inputs = tf.keras.Input((1000,))
      dense = tf.keras.layers.Dense(
          units=1, use_bias=False, kernel_initializer=tf.initializers.ones()
      )
      outputs = dense(inputs)
      return tf.keras.models.Model(inputs=inputs, outputs=outputs)

    x = tf.ones(shape=[5000, 1000])
    y = tf.zeros(shape=[5000])
    with strategy_scope:
      model = make_model()
      clip = 100.0
      noise_mult = 0.01
      acc_steps = 5
      batch_size = 10
      opt = dp_optimizer_keras_sparse.DPSparseKerasSGDOptimizer(
          l2_norm_clip=clip,
          noise_multiplier=noise_mult,
          gradient_accumulation_steps=acc_steps,
          learning_rate=0.001,
      )
      model.compile(
          loss=tf.keras.losses.MeanAbsoluteError(
              reduction=tf.keras.losses.Reduction.NONE
          ),
          optimizer=opt,
      )
      history = model.fit(
          x=x,
          y=y,
          epochs=2,
          steps_per_epoch=500,
          batch_size=batch_size,
      )
    self.assertIn("loss", history.history)
    # total steps: 1000 (2 epochs, 500 steps/epoch)
    # accumulation steps: 5
    # expected_iterations = total steps / accumulation steps
    # Integer division: this is a count that is compared with the optimizer's
    # integer `iterations` counter.
    expected_iterations = 1000 // acc_steps  # = 200
    # The loss is |w.x - y| (where w is the dense layer weights).
    # The gradient is sign(w.x - y)x. With the choice of x, y, the gradient
    # becomes x.
    # So each gradient update is w <- w - learning_rate*1 + noise
    expected_params = 1 - 0.001 * expected_iterations
    expected_noise = (
        0.001
        * clip
        * noise_mult
        * np.sqrt(expected_iterations)
        / (acc_steps * batch_size)
    )
    self.assertEqual(opt.iterations.numpy(), expected_iterations)
    self.assertAllClose(
        np.mean(model.trainable_variables[0].numpy()),
        expected_params,  # 0.8
        # stddev = expected_noise/√1000 (since we're averaging 1000 samples)
        # we set atol to 4 stddev
        atol=4 * expected_noise / np.sqrt(1000),  # ~3.58e-5
    )
    self.assertAllClose(
        np.std(model.trainable_variables[0].numpy()),
        expected_noise,  # ~2.83e-4
        atol=4 * expected_noise / np.sqrt(1000),  # ~3.58e-5
    )


def _set_spawn_exe_path():
  """Set the path to the executable for spawned processes.

  This utility searches for the binary the parent process is using, and sets
  the executable of multiprocessing's context accordingly.
  It is branched from tensorflow/python/distribute/multi_process_lib.py, the
  only change being that it additionally looks under "org_tensorflow_privacy".

  When `sys.argv[0]` ends with ".py" and an executable binary can be guessed,
  `sys.argv[0]` is overwritten with that binary's path. If no binary can be
  guessed (including when the TEST_TARGET environment variable is not set),
  `sys.argv[0]` is left as it is.
  """
  if sys.argv[0].endswith(".py"):

    def guess_path(package_root):
      # If all we have is a python module path, we'll need to make a guess for
      # the actual executable path.
      if "bazel-out" in sys.argv[0] and package_root in sys.argv[0]:
        # Guess the binary path under bazel. For target
        # //tensorflow/python/distribute:input_lib_test_multiworker_gpu, the
        # argv[0] is in the form of
        # /.../tensorflow/python/distribute/input_lib_test.py
        # and the binary is
        # /.../tensorflow/python/distribute/input_lib_test_multiworker_gpu
        test_target = os.environ.get("TEST_TARGET")
        if not test_target:
          # Without the target label there is nothing to derive the binary
          # name from, so give up on guessing instead of raising KeyError.
          return None
        package_root_base = sys.argv[0][: sys.argv[0].rfind(package_root)]
        # Drop the leading "//" and turn "pkg:target" into "pkg/target".
        binary = test_target[2:].replace(":", "/", 1)
        possible_path = os.path.join(package_root_base, package_root, binary)
        if os.access(possible_path, os.X_OK):
          return possible_path
      return None

    path = (
        guess_path("org_tensorflow")
        or guess_path("org_keras")
        or guess_path("org_tensorflow_privacy")
    )
    if path is not None:
      sys.argv[0] = path
  multiprocessing.get_context().set_executable(sys.argv[0])


if __name__ == "__main__":
  # Configure multiprocessing's executable first, then hand control to
  # TensorFlow's multi-process test runner.
  _set_spawn_exe_path()
  tf.__internal__.distribute.multi_process_runner.test_main()
Original file line number Diff line number Diff line change
Expand Up @@ -338,21 +338,25 @@ def testLargeBatchEmulationNoNoise(self):
# After first call to optimizer values didn't change
self.assertAllCloseAccordingToType([[1.0, 2.0]], var0)
self.assertAllCloseAccordingToType([3.0], var1)
self.assertEqual(opt.iterations, 0)

opt.minimize(loss2, [var0, var1])
# After second call to optimizer updates were applied
self.assertAllCloseAccordingToType([[-1.0, 1.0]], var0)
self.assertAllCloseAccordingToType([2.0], var1)
self.assertEqual(opt.iterations, 1)

opt.minimize(loss2, [var0, var1])
# After third call to optimizer values didn't change
self.assertAllCloseAccordingToType([[-1.0, 1.0]], var0)
self.assertAllCloseAccordingToType([2.0], var1)
self.assertEqual(opt.iterations, 1)

opt.minimize(loss2, [var0, var1])
# After fourth call to optimizer updates were applied again
self.assertAllCloseAccordingToType([[-4.0, -0.5]], var0)
self.assertAllCloseAccordingToType([1.0], var1)
self.assertEqual(opt.iterations, 2)

@parameterized.named_parameters(
('DPSparseKerasSGDOptimizer 1',
Expand Down Expand Up @@ -388,6 +392,7 @@ def testLargeBatchEmulation(self, cls, gradient_accumulation_steps):

self.assertNotAllClose([[1.0, 2.0]], var0)
self.assertNotAllClose([3.0], var1)
self.assertEqual(opt.iterations, 1)

def testKerasModelBaselineSaving(self):
"""Tests that DP optimizers work with tf.keras.Model."""
Expand Down Expand Up @@ -455,10 +460,15 @@ def testKerasModelBaselineAfterSavingLoading(self):

model.fit(train_data, train_labels, batch_size=8, epochs=1, shuffle=False)

@parameterized.named_parameters(('1', 1), ('None', None))
def testKerasModelBaselineNoNoise(self, num_microbatches):
@parameterized.named_parameters(
('no_microbatch_no_accumulation', False, False),
('no_microbatch_accumulation', False, True),
('microbatch_no_accumulation', True, False),
('microbatch_accumulation', True, True),
)
def testKerasModelBaselineNoNoise(self, microbatch, accumulate):
"""Tests that DP optimizers work with tf.keras.Model."""

acc_steps = 2 if accumulate else 1
model = tf.keras.models.Sequential(layers=[
tf.keras.layers.Dense(
1,
Expand All @@ -471,22 +481,25 @@ def testKerasModelBaselineNoNoise(self, num_microbatches):
optimizer = dp_optimizer.DPSparseKerasSGDOptimizer(
l2_norm_clip=100.0,
noise_multiplier=0.0,
num_microbatches=num_microbatches,
learning_rate=0.05)
num_microbatches=None if microbatch else 1,
learning_rate=0.05,
gradient_accumulation_steps=acc_steps,
)
loss = tf.keras.losses.MeanSquaredError(reduction='none')
model.compile(optimizer, loss)

true_weights = np.array([[-5], [4], [3], [2]]).astype(np.float32)
true_bias = np.array([6.0]).astype(np.float32)
train_data = np.random.normal(scale=3.0, size=(1000, 4)).astype(np.float32)
train_labels = np.matmul(train_data,
true_weights) + true_bias + np.random.normal(
scale=0.0, size=(1000, 1)).astype(np.float32)
train_data = np.random.normal(scale=3.0, size=(2000, 4)).astype(np.float32)
train_labels = np.matmul(train_data, true_weights) + true_bias

model.fit(train_data, train_labels, batch_size=8, epochs=1, shuffle=False)
model.fit(train_data, train_labels, batch_size=10, epochs=1, shuffle=False)

self.assertAllClose(model.get_weights()[0], true_weights, atol=0.05)
self.assertAllClose(model.get_weights()[1], true_bias, atol=0.05)
# Check that the optimizer's iterations equal the number of logical batches.
total_batches = 200
self.assertEqual(optimizer.iterations.numpy(), total_batches / acc_steps)


if __name__ == '__main__':
Expand Down

0 comments on commit 52806ba

Please sign in to comment.