Skip to content

Commit

Permalink
Add ViT to examples
Browse files Browse the repository at this point in the history
  • Loading branch information
BirkhoffG committed Feb 28, 2024
1 parent 29f6167 commit 9547bf1
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 13 deletions.
69 changes: 56 additions & 13 deletions nbs/examples/vit.torch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"\n",
"# utils\n",
"import functools as ft\n",
"import matplotlib.pyplot as plt\n",
"from dataclasses import dataclass\n",
"from tqdm.auto import tqdm"
]
Expand Down Expand Up @@ -250,17 +251,19 @@
" xs, ys = next(iter(data_loader))\n",
" params = model.init(init_key, xs, train=False)\n",
" opt_state = optimizer.init(params)\n",
" losses = []\n",
" losses, steps = [], 0\n",
"\n",
" for epoch in range(epochs):\n",
" with tqdm(data_loader, desc=f\"Epoch {epoch}\", unit='batch') as pbar:\n",
" for batch in pbar:\n",
" rng_key, key = jrand.split(rng_key)\n",
" params, opt_state, loss = step(\n",
" params, model, optimizer, opt_state, batch, key\n",
" )\n",
" losses.append(loss)\n",
" pbar.set_postfix({\"loss\": loss})\n",
" for batch in data_loader:\n",
" rng_key, key = jrand.split(rng_key)\n",
" params, opt_state, loss = step(\n",
" params, model, optimizer, opt_state, batch, key\n",
" )\n",
" losses.append(loss)\n",
" steps += 1\n",
"\n",
" if steps % 500 == 0:\n",
" print(f\"Epoch: {epoch}, Step: {steps}, Loss: {loss}\")\n",
" return params, losses"
]
},
Expand All @@ -271,11 +274,11 @@
"outputs": [],
"source": [
"# Hyperparameters\n",
"lr = 1e-3\n",
"lr = 3e-4\n",
"dropout_rate = 0.1\n",
"beta1 = 0.9\n",
"beta2 = 0.99\n",
"batch_size = 64 * 2\n",
"batch_size = 64 * 2 * 2\n",
"patch_size = 4\n",
"num_patches = 64\n",
"num_steps = 100000\n",
Expand All @@ -297,7 +300,21 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Files already downloaded and verified\n",
"Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /tmp/CIFAR/cifar-10-python.tar.gz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 170498071/170498071 [00:02<00:00, 63361930.39it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting /tmp/CIFAR/cifar-10-python.tar.gz to /tmp/CIFAR/\n",
"Files already downloaded and verified\n"
]
}
Expand Down Expand Up @@ -364,7 +381,33 @@
"metadata": {},
"outputs": [],
"source": [
"_, losses = train(vit, opt, dl, 100)"
"params, losses = train(vit, opt, dl, 500)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 0.8149\n"
]
}
],
"source": [
"corrects = []\n",
"\n",
"dl = jdl.DataLoader(test_dataset, 'pytorch', batch_size=batch_size * 4, shuffle=True)\n",
"for batch in dl:\n",
" img, label = batch\n",
" logits = vit.apply(params, img, rngs={'dropout': jrand.PRNGKey(0)}, train=False)\n",
" preds = jnp.argmax(logits, axis=-1)\n",
" corrects.append((preds == label))\n",
"\n",
"print(f\"Accuracy: {np.concatenate(corrects).mean()}\")"
]
}
],
Expand Down
3 changes: 3 additions & 0 deletions nbs/sidebar.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ website:
sidebar:
contents:
- index.ipynb
- section: Examples
contents:
- examples/vit.torch.ipynb
- section: API
contents:
- core.ipynb
Expand Down

0 comments on commit 9547bf1

Please sign in to comment.