diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_tpu_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_tpu_test.py index 9411d618..a15d5c5f 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_tpu_test.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_tpu_test.py @@ -20,6 +20,7 @@ class GradNormTpuTest(embedding_test.GradNormTest): def setUp(self): + tf.config.experimental.disable_mlir_bridge() super(embedding_test.GradNormTest, self).setUp() self.strategy = common_test_utils.create_tpu_strategy() self.assertIn('TPU', self.strategy.extended.worker_devices[0]) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/nlp_on_device_embedding_tpu_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/nlp_on_device_embedding_tpu_test.py index 283a6717..98acc870 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/nlp_on_device_embedding_tpu_test.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/nlp_on_device_embedding_tpu_test.py @@ -20,6 +20,7 @@ class GradNormTpuTest(nlp_on_device_embedding_test.GradNormTest): def setUp(self): + tf.config.experimental.disable_mlir_bridge() super(nlp_on_device_embedding_test.GradNormTest, self).setUp() self.strategy = common_test_utils.create_tpu_strategy() self.assertIn('TPU', self.strategy.extended.worker_devices[0])