Skip to content

Commit c31ad48

Browse files
committed
Update imports to use numpy
1 parent 8e9cf00 commit c31ad48

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

integration_tests/tf_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import jax_dataloader as jdl
2-
import jax
2+
import numpy as np
33
import tensorflow_datasets as tfds
44
import tensorflow as tf
55

@@ -9,13 +9,13 @@ def test_jax():
99
dl = jdl.DataLoader(ds, 'jax', batch_size=2)
1010
for x, y in dl:
1111
z = x + y
12-
assert isinstance(z, jax.Array)
12+
assert isinstance(z, np.ndarray)
1313

1414

1515
def test_tf():
1616
ds = tf.data.Dataset.from_tensor_slices((tf.ones((10, 3)), tf.ones((10, 3))))
1717
dl = jdl.DataLoader(ds, 'tensorflow', batch_size=2)
1818
for x, y in dl:
1919
z = x + y
20-
assert isinstance(z, jax.Array)
20+
assert isinstance(z, np.ndarray)
2121

0 commit comments

Comments
 (0)