From 9547bf13e9d3a3a025de04527b52a7504b57b6b4 Mon Sep 17 00:00:00 2001 From: BirkhoffG <26811230+BirkhoffG@users.noreply.github.com> Date: Wed, 28 Feb 2024 19:36:15 +0000 Subject: [PATCH] Add ViT to examples --- nbs/examples/vit.torch.ipynb | 69 +++++++++++++++++++++++++++++------- nbs/sidebar.yml | 3 ++ 2 files changed, 59 insertions(+), 13 deletions(-) diff --git a/nbs/examples/vit.torch.ipynb b/nbs/examples/vit.torch.ipynb index 6cbe386..d2ec3fe 100644 --- a/nbs/examples/vit.torch.ipynb +++ b/nbs/examples/vit.torch.ipynb @@ -31,6 +31,7 @@ "\n", "# utils\n", "import functools as ft\n", + "import matplotlib.pyplot as plt\n", "from dataclasses import dataclass\n", "from tqdm.auto import tqdm" ] @@ -250,17 +251,19 @@ " xs, ys = next(iter(data_loader))\n", " params = model.init(init_key, xs, train=False)\n", " opt_state = optimizer.init(params)\n", - " losses = []\n", + " losses, steps = [], 0\n", "\n", " for epoch in range(epochs):\n", - " with tqdm(data_loader, desc=f\"Epoch {epoch}\", unit='batch') as pbar:\n", - " for batch in pbar:\n", - " rng_key, key = jrand.split(rng_key)\n", - " params, opt_state, loss = step(\n", - " params, model, optimizer, opt_state, batch, key\n", - " )\n", - " losses.append(loss)\n", - " pbar.set_postfix({\"loss\": loss})\n", + " for batch in data_loader:\n", + " rng_key, key = jrand.split(rng_key)\n", + " params, opt_state, loss = step(\n", + " params, model, optimizer, opt_state, batch, key\n", + " )\n", + " losses.append(loss)\n", + " steps += 1\n", + "\n", + " if steps % 500 == 0:\n", + " print(f\"Epoch: {epoch}, Step: {steps}, Loss: {loss}\")\n", " return params, losses" ] }, @@ -271,11 +274,11 @@ "outputs": [], "source": [ "# Hyperparameters\n", - "lr = 1e-3\n", + "lr = 3e-4\n", "dropout_rate = 0.1\n", "beta1 = 0.9\n", "beta2 = 0.99\n", - "batch_size = 64 * 2\n", + "batch_size = 64 * 2 * 2\n", "patch_size = 4\n", "num_patches = 64\n", "num_steps = 100000\n", @@ -297,7 +300,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "Files already downloaded and verified\n", + "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /tmp/CIFAR/cifar-10-python.tar.gz\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 170498071/170498071 [00:02<00:00, 63361930.39it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Extracting /tmp/CIFAR/cifar-10-python.tar.gz to /tmp/CIFAR/\n", "Files already downloaded and verified\n" ] } @@ -364,7 +381,33 @@ "metadata": {}, "outputs": [], "source": [ - "_, losses = train(vit, opt, dl, 100)" + "params, losses = train(vit, opt, dl, 500)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy: 0.8149\n" + ] + } + ], + "source": [ + "corrects = []\n", + "\n", + "dl = jdl.DataLoader(test_dataset, 'pytorch', batch_size=batch_size * 4, shuffle=True)\n", + "for batch in dl:\n", + " img, label = batch\n", + " logits = vit.apply(params, img, rngs={'dropout': jrand.PRNGKey(0)}, train=False)\n", + " preds = jnp.argmax(logits, axis=-1)\n", + " corrects.append((preds == label))\n", + "\n", + "print(f\"Accuracy: {np.concatenate(corrects).mean()}\")" ] } ], diff --git a/nbs/sidebar.yml b/nbs/sidebar.yml index 46f81e1..5dd98b3 100644 --- a/nbs/sidebar.yml +++ b/nbs/sidebar.yml @@ -2,6 +2,9 @@ website: sidebar: contents: - index.ipynb + - section: Examples + contents: + - examples/vit.torch.ipynb - section: API contents: - core.ipynb