|
| 1 | +# Adapted from https://github.com/google/flax/blob/main/examples/mnist/train.py |
| 2 | + |
| 3 | +from torchvision.datasets import FashionMNIST |
| 4 | +from jax_dataloader.core import ( |
| 5 | + get_backend_compatibilities, |
| 6 | + SUPPORTED_DATASETS, |
| 7 | + JAXDataset, |
| 8 | +) |
| 9 | +from jax_dataloader.imports import * |
| 10 | +import jax_dataloader as jdl |
| 11 | +import optax |
| 12 | +import ml_collections |
| 13 | +from flax import linen as nn |
| 14 | +from flax.metrics import tensorboard |
| 15 | +from flax.training import train_state |
| 16 | +import time |
| 17 | +import rich |
| 18 | +import einops |
| 19 | +import os |
| 20 | +import json |
| 21 | + |
| 22 | + |
| 23 | +# https://github.com/huggingface/datasets/blob/main/benchmarks/benchmark_iterating.py |
| 24 | +RESULTS_BASEPATH, RESULTS_FILENAME = os.path.split(__file__) |
| 25 | +RESULTS_FILE_PATH = os.path.join(RESULTS_BASEPATH, "results", RESULTS_FILENAME.replace(".py", ".json")) |
| 26 | + |
| 27 | + |
| 28 | +class CNN(nn.Module): |
| 29 | + """A simple CNN model.""" |
| 30 | + |
| 31 | + @nn.compact |
| 32 | + def __call__(self, x: jnp.ndarray): |
| 33 | + if x.ndim == 3: |
| 34 | + x = einops.rearrange(x, "h w c -> h w c 1") |
| 35 | + x = x / 255.0 |
| 36 | + x = nn.Conv(features=32, kernel_size=(3, 3))(x) |
| 37 | + x = nn.relu(x) |
| 38 | + x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) |
| 39 | + x = nn.Conv(features=64, kernel_size=(3, 3))(x) |
| 40 | + x = nn.relu(x) |
| 41 | + x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) |
| 42 | + x = x.reshape((x.shape[0], -1)) # flatten |
| 43 | + x = nn.Dense(features=256)(x) |
| 44 | + x = nn.relu(x) |
| 45 | + x = nn.Dense(features=10)(x) |
| 46 | + return x |
| 47 | + |
| 48 | + |
| 49 | +@jax.jit |
| 50 | +def apply_model(state, images, labels): |
| 51 | + """Computes gradients, loss and accuracy for a single batch.""" |
| 52 | + |
| 53 | + def loss_fn(params): |
| 54 | + logits = state.apply_fn({"params": params}, images) |
| 55 | + one_hot = jax.nn.one_hot(labels, 10) |
| 56 | + loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) |
| 57 | + return loss, logits |
| 58 | + |
| 59 | + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) |
| 60 | + (loss, logits), grads = grad_fn(state.params) |
| 61 | + accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) |
| 62 | + return grads, loss, accuracy |
| 63 | + |
| 64 | + |
| 65 | +@jax.jit |
| 66 | +def update_model(state, grads): |
| 67 | + return state.apply_gradients(grads=grads) |
| 68 | + |
| 69 | + |
| 70 | +def get_img_labels(batch): |
| 71 | + if isinstance(batch, tuple) or isinstance(batch, list): |
| 72 | + # print(batch[0]) |
| 73 | + if isinstance(batch[0], dict): |
| 74 | + imgs, labels = batch[0]["image"], batch[0]["label"] |
| 75 | + else: |
| 76 | + imgs, labels = batch |
| 77 | + elif isinstance(batch, dict): |
| 78 | + imgs, labels = batch["image"], batch["label"] |
| 79 | + else: |
| 80 | + raise ValueError( |
| 81 | + f"Unknown batch type: {type(batch)}", |
| 82 | + ) |
| 83 | + return imgs, labels |
| 84 | + |
| 85 | + |
| 86 | +def train_epoch(state, dataloader): |
| 87 | + """Train for a single epoch.""" |
| 88 | + |
| 89 | + epoch_loss = [] |
| 90 | + epoch_accuracy = [] |
| 91 | + |
| 92 | + for batch in dataloader: |
| 93 | + images, labels = get_img_labels(batch) |
| 94 | + # print(images.shape, labels.shape) |
| 95 | + grads, loss, accuracy = apply_model(state, images, labels) |
| 96 | + state = update_model(state, grads) |
| 97 | + epoch_loss.append(loss) |
| 98 | + epoch_accuracy.append(accuracy) |
| 99 | + train_loss = np.mean(epoch_loss) |
| 100 | + train_accuracy = np.mean(epoch_accuracy) |
| 101 | + return state, train_loss, train_accuracy |
| 102 | + |
| 103 | + |
| 104 | +def create_train_state(rng, config): |
| 105 | + """Creates initial `TrainState`.""" |
| 106 | + cnn = CNN() |
| 107 | + params = cnn.init(rng, jnp.ones([32, 28, 28]))["params"] |
| 108 | + tx = optax.sgd(config.learning_rate, config.momentum) |
| 109 | + return train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx) |
| 110 | + |
| 111 | + |
| 112 | +def train_and_evaluate( |
| 113 | + config: ml_collections.ConfigDict, workdir: str |
| 114 | +) -> train_state.TrainState: |
| 115 | + """Execute model training and evaluation loop. |
| 116 | +
|
| 117 | + Args: |
| 118 | + config: Hyperparameter configuration for training and evaluation. |
| 119 | + workdir: Directory where the tensorboard summaries are written to. |
| 120 | +
|
| 121 | + Returns: |
| 122 | + The train state (which includes the `.params`). |
| 123 | + """ |
| 124 | + train_ds, test_ds = get_datasets(config.dataset_type) |
| 125 | + train_dl, test_dl = map( |
| 126 | + lambda ds: jdl.DataLoader( |
| 127 | + ds, backend=config.backend, batch_size=config.batch_size, shuffle=True |
| 128 | + ), |
| 129 | + (train_ds, test_ds), |
| 130 | + ) |
| 131 | + rng = jax.random.key(0) |
| 132 | + rng, init_rng = jax.random.split(rng) |
| 133 | + state = create_train_state(init_rng, config) |
| 134 | + |
| 135 | + runtime_per_epoch = [] |
| 136 | + for epoch in range(1, config.num_epochs + 1): |
| 137 | + rng, input_rng = jax.random.split(rng) |
| 138 | + start = time.time() |
| 139 | + state, train_loss, train_accuracy = train_epoch(state, train_dl) |
| 140 | + runtime_per_epoch.append(time.time() - start) |
| 141 | + |
| 142 | + test_accuracy = [] |
| 143 | + for batch in test_dl: |
| 144 | + images, labels = get_img_labels(batch) |
| 145 | + logits = state.apply_fn({"params": state.params}, images) |
| 146 | + accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) |
| 147 | + test_accuracy.append(accuracy) |
| 148 | + |
| 149 | + test_accuracy = np.mean(test_accuracy) |
| 150 | + |
| 151 | + rich.print(f"Test accuracy: {test_accuracy:.3f}") |
| 152 | + rich.print(f"Training batches: {len(train_dl)}") |
| 153 | + return state, runtime_per_epoch |
| 154 | + |
| 155 | + |
| 156 | +def get_datasets(ds_type: Literal["jax", "torch", "tf", "hf"]): |
| 157 | + """Returns train and test datasets.""" |
| 158 | + |
| 159 | + train_ds_torch = FashionMNIST( |
| 160 | + "/tmp/mnist/", |
| 161 | + download=True, |
| 162 | + transform=lambda x: np.array(x, dtype=float), |
| 163 | + train=True, |
| 164 | + ) |
| 165 | + test_ds_torch = FashionMNIST( |
| 166 | + "/tmp/mnist/", |
| 167 | + download=True, |
| 168 | + transform=lambda x: np.array(x, dtype=float), |
| 169 | + train=False, |
| 170 | + ) |
| 171 | + |
| 172 | + if ds_type == "jax": |
| 173 | + train_ds = jdl.ArrayDataset( |
| 174 | + train_ds_torch.data.numpy(), train_ds_torch.targets.numpy() |
| 175 | + ) |
| 176 | + test_ds = jdl.ArrayDataset( |
| 177 | + test_ds_torch.data.numpy(), test_ds_torch.targets.numpy() |
| 178 | + ) |
| 179 | + elif ds_type == "torch": |
| 180 | + train_ds, test_ds = train_ds_torch, test_ds_torch |
| 181 | + elif ds_type == "hf": |
| 182 | + ds = hf_datasets.load_dataset("fashion_mnist") |
| 183 | + train_ds, test_ds = ds["train"], ds["test"] |
| 184 | + elif ds_type == "tf": |
| 185 | + train_ds = tfds.load("fashion_mnist", split="train") |
| 186 | + test_ds = tfds.load("fashion_mnist", split="test") |
| 187 | + else: |
| 188 | + raise ValueError(f"Unknown dataset type: {ds_type}") |
| 189 | + return train_ds, test_ds |
| 190 | + |
| 191 | + |
| 192 | +def get_config(): |
| 193 | + config = ml_collections.ConfigDict() |
| 194 | + |
| 195 | + config.dataset_type = "jax" |
| 196 | + config.backend = "jax" |
| 197 | + config.batch_size = 128 |
| 198 | + config.num_epochs = 10 |
| 199 | + config.learning_rate = 0.1 |
| 200 | + config.momentum = 0.9 |
| 201 | + return config |
| 202 | + |
| 203 | + |
| 204 | +def main(): |
| 205 | + """Benchmark the training time for compatible backends and datasets.""" |
| 206 | + |
| 207 | + compat = get_backend_compatibilities() |
| 208 | + type2ds_name = { |
| 209 | + JAXDataset: "jax", |
| 210 | + TorchDataset: "torch", |
| 211 | + TFDataset: "tf", |
| 212 | + HFDataset: "hf", |
| 213 | + } |
| 214 | + print('Downloading datasets...') |
| 215 | + # download datasets |
| 216 | + for ds_name in type2ds_name.values(): get_datasets(ds_name) |
| 217 | + |
| 218 | + runtime = {} |
| 219 | + config = get_config() |
| 220 | + for backend, ds in compat.items(): |
| 221 | + if len(ds) > 0: |
| 222 | + _supported = [s in ds for s in SUPPORTED_DATASETS] |
| 223 | + runtime["backend=" + backend] = {} |
| 224 | + config.backend = backend |
| 225 | + |
| 226 | + for i, ds_type in enumerate(SUPPORTED_DATASETS): |
| 227 | + if _supported[i]: |
| 228 | + |
| 229 | + ds_name = type2ds_name[ds_type] |
| 230 | + config.dataset_type = ds_name |
| 231 | + print(f"backend={backend}, dataset={ds_name}:") |
| 232 | + train_state, runtime_per_epoch = train_and_evaluate(config, "/tmp/mnist") |
| 233 | + runtime["backend=" + backend]["dataset=" + ds_name] = runtime_per_epoch |
| 234 | + |
| 235 | + rich.print(f"Runtime per epoch: {np.mean(runtime_per_epoch[1:]): .3f} (std={np.std(runtime_per_epoch[1:]): .3f}).") |
| 236 | + del train_state |
| 237 | + else: |
| 238 | + ds_name = type2ds_name[ds_type] |
| 239 | + runtime["backend=" + backend]["dataset=" + ds_name] = [] |
| 240 | + print(f"backend={backend}, dataset={ds_name}: Not supported. Skipping.") |
| 241 | + |
| 242 | + with open(RESULTS_FILE_PATH, "wb") as f: |
| 243 | + f.write(json.dumps(runtime).encode("utf-8")) |
| 244 | + return runtime |
| 245 | + |
| 246 | + |
| 247 | +if __name__ == "__main__": |
| 248 | + # main() |
| 249 | + rich.print_json(data=main()) |
0 commit comments