Skip to content

Commit 27c06b6

Browse files
committed
Avoid assert_array_equal on PRNG keys.
This operates via conversion to np.array, which will soon be disallowed by jax-ml/jax#24481.
1 parent 40f080e commit 27c06b6

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

tests/core/core_lift_test.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import operator
16-
1715
import jax
1816
import numpy as np
1917
from absl.testing import absltest
@@ -170,12 +168,11 @@ def body_fn(scope, c):
170168
)
171169
self.assertEqual(vars['state']['acc'], x)
172170
self.assertEqual(c, 2 * x)
173-
np.testing.assert_array_equal(
171+
self.assertEqual(
174172
vars['state']['rng_params'][0], vars['state']['rng_params'][1]
175173
)
176174
with jax.debug_key_reuse(False):
177-
np.testing.assert_array_compare(
178-
operator.__ne__,
175+
self.assertNotEqual(
179176
vars['state']['rng_loop'][0],
180177
vars['state']['rng_loop'][1],
181178
)

0 commit comments

Comments
 (0)