Skip to content

Commit

Permalink
Merge pull request #28 from BirkhoffG/integration_tests
Browse files Browse the repository at this point in the history
Add integration tests
  • Loading branch information
BirkhoffG authored Feb 15, 2024
2 parents 82479ef + dd99bea commit fe8c756
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 5 deletions.
55 changes: 50 additions & 5 deletions .github/workflows/nbdev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout Code
uses: actions/checkout@v3
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: "3.9"
cache: "pip"
Expand Down Expand Up @@ -49,7 +49,52 @@ jobs:
git diff
exit 1;
fi
integration-tests:
needs: nbdev-sync
runs-on: ubuntu-latest
strategy:
matrix:
backend: [jax, torch, tf, hf]
steps:
- name: Checkout Code
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: 3.9
cache: "pip"
cache-dependency-path: settings.ini

- name: Install Base Dependencies
run: |
pip install --upgrade pip
pip install -e .
pip install pytest
- name: Test JAX backend
if: ${{ matrix.backend == 'jax'}}
run: |
pytest integration_tests/jax_test.py
- name: Test Pytorch Dependencies
if: ${{ matrix.backend == 'torch'}}
run: |
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
pytest integration_tests/torch_test.py
- name: Test Tensorflow Dependencies
if: ${{ matrix.backend == 'tf'}}
run: |
pip install -e .[tensorflow]
pytest integration_tests/tf_test.py
- name: Test Huggingface Dependencies
if: ${{ matrix.backend == 'hf'}}
run: |
pip install -e .[huggingface]
pytest integration_tests/hf_test.py
nbdev-tests:
needs: nbdev-sync
Expand All @@ -59,10 +104,10 @@ jobs:
py: ['3.9', '3.10', '3.11']
steps:
- name: Checkout Code
uses: actions/checkout@v3
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.py }}
cache: "pip"
Expand Down
12 changes: 12 additions & 0 deletions integration_tests/hf_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import jax_dataloader as jdl
import numpy as np
import datasets as hfds


def test_hf():
ds = hfds.Dataset.from_dict({"feats": np.ones((10, 3)), "labels": np.ones((10, 3))})
dl = jdl.DataLoader(ds, 'jax', batch_size=2)
for batch in dl:
x, y = batch['feats'], batch['labels']
z = x + y
assert isinstance(z, np.ndarray)
26 changes: 26 additions & 0 deletions integration_tests/jax_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import jax_dataloader as jdl
import jax.numpy as jnp
import pytest


def test_jax():
ds = jdl.ArrayDataset(jnp.ones((10, 3)), jnp.ones((10, 3)))
assert len(ds) == 10
dl = jdl.DataLoader(ds, 'jax', batch_size=2)
for x, y in dl:
z = x + y


def test_torch():
with pytest.raises(ModuleNotFoundError):
ds = jdl.ArrayDataset(jnp.ones((10, 3)), jnp.ones((10, 3)))
dl = jdl.DataLoader(ds, 'pytorch', batch_size=2)
for x, y in dl: z = x + y


def test_tf():
with pytest.raises(ModuleNotFoundError):
ds = jdl.ArrayDataset(jnp.ones((10, 3)), jnp.ones((10, 3)))
dl = jdl.DataLoader(ds, 'tensorflow', batch_size=2)
for x, y in dl: z = x + y

21 changes: 21 additions & 0 deletions integration_tests/tf_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import jax_dataloader as jdl
import numpy as np
import tensorflow_datasets as tfds
import tensorflow as tf


def test_jax():
ds = jdl.ArrayDataset(np.ones((10, 3)), np.ones((10, 3)))
dl = jdl.DataLoader(ds, 'tensorflow', batch_size=2)
for x, y in dl:
z = x + y
assert isinstance(z, np.ndarray)


def test_tf():
ds = tf.data.Dataset.from_tensor_slices((tf.ones((10, 3)), tf.ones((10, 3))))
dl = jdl.DataLoader(ds, 'tensorflow', batch_size=2)
for x, y in dl:
z = x + y
assert isinstance(z, np.ndarray)

22 changes: 22 additions & 0 deletions integration_tests/torch_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import jax_dataloader as jdl
import torch
import numpy as np
import jax.numpy as jnp
from torch.utils.data import TensorDataset


def test_jax_ds():
ds = jdl.ArrayDataset(jnp.ones((10, 3)), jnp.ones((10, 3)))
assert len(ds) == 10
dl = jdl.DataLoader(ds, 'pytorch', batch_size=2)
for x, y in dl:
z = x + y


def test_torch():
ds = TensorDataset(torch.ones((10, 3)), torch.ones((10, 3)))
dl = jdl.DataLoader(ds, 'pytorch', batch_size=2)
for x, y in dl:
z = x + y
assert isinstance(z, np.ndarray)

0 comments on commit fe8c756

Please sign in to comment.