From 301c5c4db32535078e51971a633a46e6bb696857 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Sep 2024 14:53:36 -0700 Subject: [PATCH] Disable MLIR bridge for the test points that MLIR bridge silently fails PiperOrigin-RevId: 676573736 --- .../registry_functions/embedding_tpu_test.py | 1 + .../registry_functions/nlp_on_device_embedding_tpu_test.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_tpu_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_tpu_test.py index 9411d618..a15d5c5f 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_tpu_test.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_tpu_test.py @@ -20,6 +20,7 @@ class GradNormTpuTest(embedding_test.GradNormTest): def setUp(self): + tf.config.experimental.disable_mlir_bridge() super(embedding_test.GradNormTest, self).setUp() self.strategy = common_test_utils.create_tpu_strategy() self.assertIn('TPU', self.strategy.extended.worker_devices[0]) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/nlp_on_device_embedding_tpu_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/nlp_on_device_embedding_tpu_test.py index 283a6717..98acc870 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/nlp_on_device_embedding_tpu_test.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/nlp_on_device_embedding_tpu_test.py @@ -20,6 +20,7 @@ class GradNormTpuTest(nlp_on_device_embedding_test.GradNormTest): def setUp(self): + tf.config.experimental.disable_mlir_bridge() super(nlp_on_device_embedding_test.GradNormTest, self).setUp() self.strategy = common_test_utils.create_tpu_strategy() self.assertIn('TPU', self.strategy.extended.worker_devices[0])