Skip to content

Commit fb1a92f

Browse files
authored
Merge pull request #23 from BirkhoffG/mnist
Create a benchmarking script
2 parents 95fc759 + 316d0e4 commit fb1a92f

File tree

7 files changed

+261
-11
lines changed

7 files changed

+261
-11
lines changed

benchmarks/mnist.py

Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
# Adapted from https://github.com/google/flax/blob/main/examples/mnist/train.py
2+
3+
from torchvision.datasets import FashionMNIST
4+
from jax_dataloader.core import (
5+
get_backend_compatibilities,
6+
SUPPORTED_DATASETS,
7+
JAXDataset,
8+
)
9+
from jax_dataloader.imports import *
10+
import jax_dataloader as jdl
11+
import optax
12+
import ml_collections
13+
from flax import linen as nn
14+
from flax.metrics import tensorboard
15+
from flax.training import train_state
16+
import time
17+
import rich
18+
import einops
19+
import os
20+
import json
21+
22+
23+
# https://github.com/huggingface/datasets/blob/main/benchmarks/benchmark_iterating.py
24+
# Directory and file name of this script.
RESULTS_BASEPATH, RESULTS_FILENAME = os.path.split(__file__)
# Results are written to "<script dir>/results/<script stem>.json". Only the
# final extension is swapped (via splitext), so a ".py" that happens to occur
# elsewhere in the file name is left untouched.
RESULTS_FILE_PATH = os.path.join(
    RESULTS_BASEPATH, "results", os.path.splitext(RESULTS_FILENAME)[0] + ".json"
)
26+
27+
28+
class CNN(nn.Module):
    """Small convolutional classifier that outputs 10 logits per example."""

    @nn.compact
    def __call__(self, x: jnp.ndarray):
        # A rank-3 input gets a trailing size-1 axis appended. The axis labels
        # in the einops pattern are arbitrary names; only the added "1" matters.
        if x.ndim == 3:
            x = einops.rearrange(x, "h w c -> h w c 1")
        x = x / 255.0

        # Two conv -> ReLU -> 2x2 average-pool stages that differ only in the
        # number of convolution features. Submodules are created in the same
        # order as before, so their auto-generated names do not change.
        for width in (32, 64):
            x = nn.Conv(features=width, kernel_size=(3, 3))(x)
            x = nn.relu(x)
            x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

        # Collapse everything but the leading axis before the dense head.
        flat = x.reshape((x.shape[0], -1))
        hidden = nn.relu(nn.Dense(features=256)(flat))
        logits = nn.Dense(features=10)(hidden)
        return logits
47+
48+
49+
@jax.jit
def apply_model(state, images, labels):
    """Return `(grads, loss, accuracy)` for one batch.

    The loss is the mean softmax cross-entropy between the model logits and
    the 10-way one-hot encoding of `labels`; gradients are taken with respect
    to `state.params`. Accuracy is the fraction of examples whose arg-max
    logit equals the label.
    """

    def _objective(params):
        predictions = state.apply_fn({"params": params}, images)
        targets = jax.nn.one_hot(labels, 10)
        per_example = optax.softmax_cross_entropy(logits=predictions, labels=targets)
        # The logits ride along as auxiliary output so accuracy can reuse them.
        return jnp.mean(per_example), predictions

    value_and_grad = jax.value_and_grad(_objective, has_aux=True)
    (loss, logits), grads = value_and_grad(state.params)
    hits = jnp.argmax(logits, -1) == labels
    accuracy = jnp.mean(hits)
    return grads, loss, accuracy
63+
64+
65+
@jax.jit
def update_model(state, grads):
    """Return the state produced by `state.apply_gradients(grads=grads)`."""
    new_state = state.apply_gradients(grads=grads)
    return new_state
68+
69+
70+
def get_img_labels(batch):
    """Extract an `(images, labels)` pair from a batch.

    Accepted batch layouts:
      * a dict with "image" and "label" keys;
      * a tuple/list whose first element is such a dict;
      * any other tuple/list, which is unpacked directly as
        `(images, labels)`.

    Args:
      batch: One batch as yielded by a dataloader.

    Returns:
      The `(images, labels)` tuple.

    Raises:
      ValueError: If `batch` is neither a tuple, a list nor a dict (or if a
        plain sequence does not unpack into exactly two items).
    """
    if isinstance(batch, dict):
        return batch["image"], batch["label"]

    if isinstance(batch, (tuple, list)):
        first = batch[0]
        if isinstance(first, dict):
            # The sample dict is wrapped in a one-level sequence.
            return first["image"], first["label"]
        imgs, labels = batch
        return imgs, labels

    raise ValueError(
        f"Unknown batch type: {type(batch)}",
    )
84+
85+
86+
def train_epoch(state, dataloader):
    """Run one pass over `dataloader`, updating `state` after every batch.

    Returns:
      `(state, train_loss, train_accuracy)` where the last two are the
      `np.mean` of the per-batch loss and accuracy values.
    """
    losses, accuracies = [], []

    for batch in dataloader:
        images, labels = get_img_labels(batch)
        grads, batch_loss, batch_accuracy = apply_model(state, images, labels)
        state = update_model(state, grads)
        losses.append(batch_loss)
        accuracies.append(batch_accuracy)

    return state, np.mean(losses), np.mean(accuracies)
102+
103+
104+
def create_train_state(rng, config):
    """Build the initial `TrainState` for a freshly initialised `CNN`.

    Parameters are initialised from `rng` using an all-ones input of shape
    (32, 28, 28); the optimizer is SGD configured from
    `config.learning_rate` and `config.momentum`.
    """
    model = CNN()
    dummy_input = jnp.ones([32, 28, 28])
    variables = model.init(rng, dummy_input)
    optimizer = optax.sgd(config.learning_rate, config.momentum)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables["params"],
        tx=optimizer,
    )
110+
111+
112+
def train_and_evaluate(
    config: ml_collections.ConfigDict, workdir: str
) -> tuple[train_state.TrainState, list[float]]:
    """Train the CNN, timing every training epoch, then report test accuracy.

    Args:
      config: Hyperparameter configuration. This function reads
        `dataset_type`, `backend`, `batch_size` and `num_epochs`; the same
        object is handed to `create_train_state`.
      workdir: Not used by this function; kept so that existing callers that
        pass it keep working.

    Returns:
      A `(state, runtime_per_epoch)` tuple: the final train state (which
      includes the `.params`) and a list holding one duration in seconds per
      training epoch.
    """
    train_ds, test_ds = get_datasets(config.dataset_type)
    train_dl, test_dl = map(
        lambda ds: jdl.DataLoader(
            ds, backend=config.backend, batch_size=config.batch_size, shuffle=True
        ),
        (train_ds, test_ds),
    )

    # Only the initialisation key is ever consumed, so the key is split once.
    _, init_rng = jax.random.split(jax.random.key(0))
    state = create_train_state(init_rng, config)

    runtime_per_epoch = []
    for _ in range(config.num_epochs):
        # perf_counter is a monotonic, high-resolution clock meant for
        # measuring intervals; time.time() can jump if the system clock moves.
        start = time.perf_counter()
        state, _, _ = train_epoch(state, train_dl)
        runtime_per_epoch.append(time.perf_counter() - start)

    # Evaluate once, after the last epoch: only the final accuracy is ever
    # reported, and doing it here also keeps the function well-defined when
    # `num_epochs` is 0.
    test_accuracy = []
    for batch in test_dl:
        images, labels = get_img_labels(batch)
        logits = state.apply_fn({"params": state.params}, images)
        test_accuracy.append(jnp.mean(jnp.argmax(logits, -1) == labels))
    test_accuracy = np.mean(test_accuracy)

    rich.print(f"Test accuracy: {test_accuracy:.3f}")
    rich.print(f"Training batches: {len(train_dl)}")
    return state, runtime_per_epoch
154+
155+
156+
def get_datasets(ds_type: Literal["jax", "torch", "tf", "hf"]):
    """Return the Fashion-MNIST `(train, test)` datasets in the requested flavour.

    Args:
      ds_type: Which dataset implementation to return:
        * "jax":   `jdl.ArrayDataset` built from the torchvision dataset's
                   `.data` / `.targets` converted to numpy;
        * "torch": the torchvision `FashionMNIST` datasets themselves, with a
                   transform that turns each image into a float numpy array;
        * "hf":    the "train" / "test" splits of
                   `hf_datasets.load_dataset("fashion_mnist")`;
        * "tf":    `tfds.load("fashion_mnist", split=...)`.

    Returns:
      A `(train_ds, test_ds)` tuple.

    Raises:
      ValueError: If `ds_type` is not one of the values above.
    """

    def _torch_split(train: bool):
        # Built lazily so that the "hf"/"tf" paths (and invalid inputs) do not
        # pay for constructing the torchvision datasets they never use.
        return FashionMNIST(
            "/tmp/mnist/",
            download=True,
            transform=lambda x: np.array(x, dtype=float),
            train=train,
        )

    if ds_type == "jax":
        train_torch, test_torch = _torch_split(True), _torch_split(False)
        train_ds = jdl.ArrayDataset(
            train_torch.data.numpy(), train_torch.targets.numpy()
        )
        test_ds = jdl.ArrayDataset(
            test_torch.data.numpy(), test_torch.targets.numpy()
        )
    elif ds_type == "torch":
        train_ds, test_ds = _torch_split(True), _torch_split(False)
    elif ds_type == "hf":
        # NOTE(review): `hf_datasets` is presumably provided by the star
        # import from jax_dataloader.imports - confirm.
        ds = hf_datasets.load_dataset("fashion_mnist")
        train_ds, test_ds = ds["train"], ds["test"]
    elif ds_type == "tf":
        # NOTE(review): `tfds` is presumably provided by the star import from
        # jax_dataloader.imports - confirm.
        train_ds = tfds.load("fashion_mnist", split="train")
        test_ds = tfds.load("fashion_mnist", split="test")
    else:
        raise ValueError(f"Unknown dataset type: {ds_type}")
    return train_ds, test_ds
190+
191+
192+
def get_config():
    """Return the default benchmark configuration as a `ConfigDict`."""
    defaults = {
        "dataset_type": "jax",
        "backend": "jax",
        "batch_size": 128,
        "num_epochs": 10,
        "learning_rate": 0.1,
        "momentum": 0.9,
    }
    return ml_collections.ConfigDict(defaults)
202+
203+
204+
def main():
    """Benchmark the training time for compatible backends and datasets.

    For every backend reported by `get_backend_compatibilities()` that lists
    at least one dataset type, each entry of `SUPPORTED_DATASETS` is either
    trained on (recording the per-epoch runtimes) or recorded as an empty
    list when the backend does not list it. The collected timings are
    written as JSON to `RESULTS_FILE_PATH`.

    Returns:
      A nested dict of the form
      `{"backend=<name>": {"dataset=<name>": [seconds_per_epoch, ...]}}`.
    """
    compat = get_backend_compatibilities()
    type2ds_name = {
        JAXDataset: "jax",
        TorchDataset: "torch",
        TFDataset: "tf",
        HFDataset: "hf",
    }

    # Load every dataset once up front so downloads happen before any timing.
    print('Downloading datasets...')
    for ds_name in type2ds_name.values():
        get_datasets(ds_name)

    runtime = {}
    config = get_config()
    for backend, ds in compat.items():
        if len(ds) == 0:
            # Backends without any listed dataset get no entry at all.
            continue

        backend_key = "backend=" + backend
        runtime[backend_key] = {}
        config.backend = backend

        for ds_type in SUPPORTED_DATASETS:
            ds_name = type2ds_name[ds_type]
            dataset_key = "dataset=" + ds_name

            if ds_type not in ds:
                runtime[backend_key][dataset_key] = []
                print(f"backend={backend}, dataset={ds_name}: Not supported. Skipping.")
                continue

            config.dataset_type = ds_name
            print(f"backend={backend}, dataset={ds_name}:")
            # Named `state` rather than `train_state` to avoid shadowing the
            # imported flax module of that name.
            state, runtime_per_epoch = train_and_evaluate(config, "/tmp/mnist")
            runtime[backend_key][dataset_key] = runtime_per_epoch

            # The first epoch is left out of the summary statistics
            # (presumably because it includes one-off warm-up/compilation
            # cost - the recorded results show it as the slowest epoch).
            rich.print(f"Runtime per epoch: {np.mean(runtime_per_epoch[1:]): .3f} (std={np.std(runtime_per_epoch[1:]): .3f}).")
            del state

    # Make sure the output directory exists so a completed benchmark run is
    # not lost to a FileNotFoundError at the very end.
    os.makedirs(os.path.dirname(RESULTS_FILE_PATH), exist_ok=True)
    with open(RESULTS_FILE_PATH, "w", encoding="utf-8") as f:
        json.dump(runtime, f)
    return runtime
245+
246+
247+
if __name__ == "__main__":
    # Run the benchmark and pretty-print the collected timings as JSON.
    results = main()
    rich.print_json(data=results)

benchmarks/results/mnist.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"backend=jax": {"dataset=jax": [6.00615930557251, 0.7483389377593994, 0.743844747543335, 0.7461087703704834, 0.7428982257843018, 0.7401800155639648, 0.7440762519836426, 0.7452375888824463, 0.7448098659515381, 0.7438862323760986], "dataset=torch": [], "dataset=tf": [], "dataset=hf": [5.5362865924835205, 4.725424528121948, 4.779559373855591, 4.82304835319519, 4.8099517822265625, 4.83212685585022, 4.809313774108887, 4.818130016326904, 4.873654127120972, 4.943045377731323]}, "backend=pytorch": {"dataset=jax": [1.839327096939087, 0.966571569442749, 0.9141397476196289, 0.9267113208770752, 0.9212968349456787, 0.9556140899658203, 0.912912130355835, 0.9298040866851807, 1.024169921875, 0.9741010665893555], "dataset=torch": [3.6769654750823975, 2.9240589141845703, 2.9682703018188477, 2.9608328342437744, 2.9808106422424316, 2.9670567512512207, 3.0383379459381104, 3.0271716117858887, 2.952286720275879, 2.946077823638916], "dataset=tf": [], "dataset=hf": [6.0433268547058105, 5.160061359405518, 5.155879259109497, 5.418692350387573, 6.080808639526367, 5.252376079559326, 5.186870098114014, 5.237852096557617, 5.198176622390747, 5.24579381942749]}, "backend=tensorflow": {"dataset=jax": [1.793637752532959, 0.9664194583892822, 0.8171765804290771, 0.8145453929901123, 0.8179428577423096, 0.8119533061981201, 0.8084969520568848, 0.8056025505065918, 0.8245463371276855, 0.8152358531951904], "dataset=torch": [], "dataset=tf": [3.459198474884033, 1.1708533763885498, 1.228088617324829, 1.1557502746582031, 1.124732494354248, 1.1766808032989502, 1.119699239730835, 1.1329691410064697, 1.1301155090332031, 1.1265795230865479], "dataset=hf": [34.50724649429321, 34.40732932090759, 34.51625299453735, 33.26384210586548, 32.72089385986328, 32.59306216239929, 32.585314989089966, 32.482611656188965, 32.63717746734619, 32.016772747039795]}}

jax_dataloader/imports.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,7 @@
2626
import torch.utils.data as torch_data
2727
import torch
2828

29-
TorchDataset = Annotated[
30-
torch_data.Dataset,
31-
Is[lambda _: torch_data is not None],
32-
]
29+
TorchDataset = torch_data.Dataset
3330
except ModuleNotFoundError:
3431
torch_data = None
3532
torch = None
@@ -55,10 +52,11 @@
5552
import tensorflow as tf
5653
import tensorflow_datasets as tfds
5754

58-
TFDataset = Annotated[
59-
tf.data.Dataset,
60-
Is[lambda _: tf is not None],
61-
]
55+
# TFDataset = Annotated[
56+
# tf.data.Dataset,
57+
# Is[lambda _: tf is not None],
58+
# ]
59+
TFDataset = tf.data.Dataset
6260
except ModuleNotFoundError:
6361
tf = None
6462
tfds = None

jax_dataloader/loaders/jax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def to_jax_dataset(dataset: JAXDataset):
8282

8383
@dispatch
8484
def to_jax_dataset(dataset: HFDataset):
85-
return dataset.with_format('jax')
85+
return dataset.with_format('numpy')
8686

8787
# %% ../../nbs/loader.jax.ipynb 8
8888
class DataLoaderJAX(BaseDataLoader):

jax_dataloader/loaders/torch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from ..utils import check_pytorch_installed
99
from ..tests import *
1010
from jax.tree_util import tree_map
11+
import warnings
1112

1213

1314
# %% auto 0

nbs/loader.jax.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@
152152
"\n",
153153
"@dispatch\n",
154154
"def to_jax_dataset(dataset: HFDataset):\n",
155-
" return dataset.with_format('jax')"
155+
" return dataset.with_format('numpy')"
156156
]
157157
},
158158
{

nbs/loader.torch.ipynb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@
4747
"from jax_dataloader.datasets import Dataset, ArrayDataset, JAXDataset\n",
4848
"from jax_dataloader.utils import check_pytorch_installed\n",
4949
"from jax_dataloader.tests import *\n",
50-
"from jax.tree_util import tree_map\n"
50+
"from jax.tree_util import tree_map\n",
51+
"import warnings\n"
5152
]
5253
},
5354
{

0 commit comments

Comments
 (0)