diff --git a/demos/Main_Demo.ipynb b/demos/Main_Demo.ipynb
index d97f23867..d5d524c76 100644
--- a/demos/Main_Demo.ipynb
+++ b/demos/Main_Demo.ipynb
@@ -45,45 +45,42 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 292,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Running as a Jupyter notebook - intended for development only!\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "# Janky code to do different setup when run in a Colab notebook vs VSCode\n",
+ "import os\n",
"DEVELOPMENT_MODE = False\n",
+ "# Detect if we're running in Google Colab\n",
"try:\n",
" import google.colab\n",
" IN_COLAB = True\n",
" print(\"Running as a Colab notebook\")\n",
- " %pip install git+https://github.com/neelnanda-io/TransformerLens.git\n",
- " %pip install circuitsvis\n",
- " \n",
- " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n",
- " # # Install another version of node that makes PySvelte work way faster\n",
- " # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n",
- " # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n",
"except:\n",
" IN_COLAB = False\n",
- " print(\"Running as a Jupyter notebook - intended for development only!\")\n",
- " from IPython import get_ipython\n",
"\n",
- " ipython = get_ipython()\n",
- " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n",
- " ipython.magic(\"load_ext autoreload\")\n",
- " ipython.magic(\"autoreload 2\")"
+ "# Install if in Colab\n",
+ "if IN_COLAB:\n",
+ " %pip install transformer_lens\n",
+ " %pip install circuitsvis\n",
+ " # Install a faster Node version\n",
+ " !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs # noqa\n",
+ "\n",
+ "# Hot reload in development mode & not running on the CD\n",
+ "if not IN_COLAB:\n",
+ " from IPython import get_ipython\n",
+ " ip = get_ipython()\n",
+ " if not ip.extension_manager.loaded:\n",
+ " ip.extension_manager.load('autoreload')\n",
+ " %autoreload 2\n",
+ " \n",
+ "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n",
+ "IN_GITHUB = True\n"
]
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 293,
"metadata": {},
"outputs": [
{
@@ -106,28 +103,32 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 294,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "
\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 3,
- "metadata": {},
+ "execution_count": 294,
+ "metadata": {
+ "text/html": {
+ "Content-Type": "text/html"
+ }
+ },
"output_type": "execute_result"
}
],
@@ -139,49 +140,34 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 295,
"metadata": {},
"outputs": [],
"source": [
"# Import stuff\n",
"import torch\n",
"import torch.nn as nn\n",
- "import torch.nn.functional as F\n",
- "import torch.optim as optim\n",
- "import numpy as np\n",
"import einops\n",
"from fancy_einsum import einsum\n",
"import tqdm.auto as tqdm\n",
- "import random\n",
- "from pathlib import Path\n",
"import plotly.express as px\n",
- "from torch.utils.data import DataLoader\n",
- "\n",
- "from jaxtyping import Float, Int\n",
- "from typing import List, Union, Optional\n",
- "from functools import partial\n",
- "import copy\n",
"\n",
- "import itertools\n",
- "from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer\n",
- "import dataclasses\n",
- "import datasets\n",
- "from IPython.display import HTML"
+ "from jaxtyping import Float\n",
+ "from functools import partial"
]
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 296,
"metadata": {},
"outputs": [],
"source": [
- "import transformer_lens\n",
+ "# import transformer_lens\n",
"import transformer_lens.utils as utils\n",
"from transformer_lens.hook_points import (\n",
- " HookedRootModule,\n",
" HookPoint,\n",
") # Hooking utilities\n",
- "from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache"
+ "from transformer_lens import HookedTransformer, FactoredMatrix"
]
},
{
@@ -193,16 +179,16 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 297,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 6,
+ "execution_count": 297,
"metadata": {},
"output_type": "execute_result"
}
@@ -220,7 +206,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 298,
"metadata": {},
"outputs": [],
"source": [
@@ -249,7 +235,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "This is a demo notebook for [TransformerLens](https://github.com/neelnanda-io/TransformerLens), **a library I ([Neel Nanda](neelnanda.io)) wrote for doing [mechanistic interpretability](https://distill.pub/2020/circuits/zoom-in/) of GPT-2 Style language models.** The goal of mechanistic interpretability is to take a trained model and reverse engineer the algorithms the model learned during training from its weights. It is a fact about the world today that we have computer programs that can essentially speak English at a human level (GPT-3, PaLM, etc), yet we have no idea how they work nor how to write one ourselves. This offends me greatly, and I would like to solve this! Mechanistic interpretability is a very young and small field, and there are a *lot* of open problems - if you would like to help, please try working on one! **If you want to skill up, check out [my guide to getting started](https://neelnanda.io/getting-started), and if you want to jump into an open problem check out my sequence [200 Concrete Open Problems in Mechanistic Interpretability](https://neelnanda.io/concrete-open-problems).**\n",
+ "This is a demo notebook for [TransformerLens](https://github.com/neelnanda-io/TransformerLens), **a library I ([Neel Nanda](https://neelnanda.io)) wrote for doing [mechanistic interpretability](https://distill.pub/2020/circuits/zoom-in/) of GPT-2 Style language models.** The goal of mechanistic interpretability is to take a trained model and reverse engineer the algorithms the model learned during training from its weights. It is a fact about the world today that we have computer programs that can essentially speak English at a human level (GPT-3, PaLM, etc), yet we have no idea how they work nor how to write one ourselves. This offends me greatly, and I would like to solve this! Mechanistic interpretability is a very young and small field, and there are a *lot* of open problems - if you would like to help, please try working on one! **If you want to skill up, check out [my guide to getting started](https://neelnanda.io/getting-started), and if you want to jump into an open problem check out my sequence [200 Concrete Open Problems in Mechanistic Interpretability](https://neelnanda.io/concrete-open-problems).**\n",
"\n",
"I wrote this library because after I left the Anthropic interpretability team and started doing independent research, I got extremely frustrated by the state of open source tooling. There's a lot of excellent infrastructure like HuggingFace and DeepSpeed to *use* or *train* models, but very little to dig into their internals and reverse engineer how they work. **This library tries to solve that**, and to make it easy to get into the field even if you don't work at an industry org with real infrastructure! The core features were heavily inspired by [Anthropic's excellent Garcon tool](https://transformer-circuits.pub/2021/garcon/index.html). Credit to Nelson Elhage and Chris Olah for building Garcon and showing me the value of good infrastructure for accelerating exploratory research!\n",
"\n",
@@ -268,16 +254,16 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 299,
"metadata": {},
"outputs": [],
"source": [
- "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
+ "device = utils.get_device()"
]
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 300,
"metadata": {},
"outputs": [
{
@@ -296,6 +282,7 @@
}
],
"source": [
+ "# NBVAL_IGNORE_OUTPUT\n",
"model = HookedTransformer.from_pretrained(\"gpt2-small\", device=device)"
]
},
@@ -313,7 +300,7 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 301,
"metadata": {},
"outputs": [
{
@@ -351,7 +338,7 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 302,
"metadata": {},
"outputs": [
{
@@ -385,7 +372,7 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 303,
"metadata": {},
"outputs": [
{
@@ -406,7 +393,7 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 304,
"metadata": {},
"outputs": [
{
@@ -419,22 +406,26 @@
{
"data": {
"text/html": [
- "\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 13,
- "metadata": {},
+ "execution_count": 304,
+ "metadata": {
+ "text/html": {
+ "Content-Type": "text/html"
+ }
+ },
"output_type": "execute_result"
}
],
@@ -483,7 +474,7 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 305,
"metadata": {},
"outputs": [
{
@@ -562,7 +553,7 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 306,
"metadata": {},
"outputs": [
{
@@ -611,23 +602,13 @@
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": 307,
"metadata": {},
"outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
- "To disable this warning, you can either:\n",
- "\t- Avoid using `tokenizers` before the fork if possible\n",
- "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
- ]
- },
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "b93d6b6e8c37495f84b7a00f2caf81c3",
+ "model_id": "980e183587f54a03bb4ead134831c94d",
"version_major": 2,
"version_minor": 0
},
@@ -681,7 +662,7 @@
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 308,
"metadata": {},
"outputs": [
{
@@ -691,9 +672,9 @@
"\n",
"\n",
" \n",
- "
\n",
- "
\n",
- "
"
],
"text/plain": [
- "
"
+ ""
]
},
- "metadata": {},
+ "metadata": {
+ "text/html": {
+ "Content-Type": "text/html"
+ }
+ },
"output_type": "display_data"
}
],
"source": [
+ "if IN_GITHUB:\n",
+ " torch.manual_seed(50)\n",
+ " \n",
"induction_head_layer = 5\n",
"induction_head_index = 5\n",
- "single_random_sequence = torch.randint(1000, 10000, (1, 20)).to(model.cfg.device)\n",
+ "size = (1, 20)\n",
+ "input_tensor = torch.randint(1000, 10000, size)\n",
+ "\n",
+ "single_random_sequence = input_tensor.to(model.cfg.device)\n",
"repeated_random_sequence = einops.repeat(single_random_sequence, \"batch seq_len -> batch (2 seq_len)\")\n",
"def visualize_pattern_hook(\n",
" pattern: Float[torch.Tensor, \"batch head_index dest_pos source_pos\"],\n",
@@ -1011,7 +994,7 @@
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": 312,
"metadata": {},
"outputs": [
{
@@ -1027,7 +1010,18 @@
"text": [
"Loaded pretrained model distilgpt2 into HookedTransformer\n"
]
- },
+ }
+ ],
+ "source": [
+ "# NBVAL_IGNORE_OUTPUT\n",
+ "distilgpt2 = HookedTransformer.from_pretrained(\"distilgpt2\", device=device)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 313,
+ "metadata": {},
+ "outputs": [
{
"data": {
"text/html": [
@@ -1035,9 +1029,9 @@
"\n",
"\n",
" \n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "
\n",
- "