diff --git a/demos/Main_Demo.ipynb b/demos/Main_Demo.ipynb index d97f23867..d5d524c76 100644 --- a/demos/Main_Demo.ipynb +++ b/demos/Main_Demo.ipynb @@ -45,45 +45,41 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 292, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Running as a Jupyter notebook - intended for development only!\n" - ] - } - ], + "outputs": [], "source": [ - "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", + "import os\n", "DEVELOPMENT_MODE = False\n", + "# Detect if we're running in Google Colab\n", "try:\n", " import google.colab\n", " IN_COLAB = True\n", " print(\"Running as a Colab notebook\")\n", - " %pip install git+https://github.com/neelnanda-io/TransformerLens.git\n", - " %pip install circuitsvis\n", - " \n", - " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n", - " # # Install another version of node that makes PySvelte work way faster\n", - " # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n", - " # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n", "except:\n", " IN_COLAB = False\n", - " print(\"Running as a Jupyter notebook - intended for development only!\")\n", - " from IPython import get_ipython\n", "\n", - " ipython = get_ipython()\n", - " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n", - " ipython.magic(\"load_ext autoreload\")\n", - " ipython.magic(\"autoreload 2\")" + "# Install if in Colab\n", + "if IN_COLAB:\n", + " %pip install transformer_lens\n", + " %pip install circuitsvis\n", + " # Install a faster Node version\n", + " !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs # noqa\n", + "\n", + "# Hot reload source code in development mode (when not in Colab)\n", + "if not IN_COLAB:\n", + " from IPython import get_ipython\n", + " ip = get_ipython()\n", + " if 'autoreload' not in ip.extension_manager.loaded:\n", + " ip.extension_manager.load_extension('autoreload')\n", + " %autoreload 2\n", + " \n", + "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 293, "metadata": {}, "outputs": [ { @@ -106,28 +103,32 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 294, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 3, - "metadata": {}, + "execution_count": 294, + "metadata": { + "text/html": { + "Content-Type": "text/html" + } + }, "output_type": "execute_result" } ], @@ -139,49 +140,34 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 295, "metadata": {}, "outputs": [], "source": [ "# Import stuff\n", "import torch\n", "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "import torch.optim as optim\n", - "import numpy as np\n", "import einops\n", "from fancy_einsum import einsum\n", "import tqdm.auto as tqdm\n", - "import random\n", - "from pathlib import Path\n", "import plotly.express as px\n", - "from torch.utils.data import DataLoader\n", - "\n", - "from jaxtyping import Float, Int\n", - "from typing import List, Union, Optional\n", - "from functools import partial\n", - "import copy\n", "\n", - "import itertools\n", - "from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer\n", - "import dataclasses\n", - "import datasets\n", - "from IPython.display import HTML" + "from jaxtyping import Float\n", + "from functools import partial" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 296, "metadata": {}, "outputs": [], "source": [ - "import transformer_lens\n", + "# import transformer_lens\n", "import transformer_lens.utils as utils\n", "from transformer_lens.hook_points import (\n", - " HookedRootModule,\n", " HookPoint,\n", ") # Hooking utilities\n", - "from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache" + "from transformer_lens import HookedTransformer, FactoredMatrix" ] }, { @@ -193,16 +179,16 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 297, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 6, + "execution_count": 297, "metadata": {}, "output_type": "execute_result" } @@ -220,7 +206,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 298, "metadata": {}, "outputs": [], "source": [ @@ -249,7 +235,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "This is a demo notebook for [TransformerLens](https://github.com/neelnanda-io/TransformerLens), **a library I ([Neel Nanda](neelnanda.io)) wrote for doing [mechanistic interpretability](https://distill.pub/2020/circuits/zoom-in/) of GPT-2 Style language models.** The goal of mechanistic interpretability is to take a trained model and reverse engineer the algorithms the model learned during training from its weights. It is a fact about the world today that we have computer programs that can essentially speak English at a human level (GPT-3, PaLM, etc), yet we have no idea how they work nor how to write one ourselves. This offends me greatly, and I would like to solve this! Mechanistic interpretability is a very young and small field, and there are a *lot* of open problems - if you would like to help, please try working on one! 
**If you want to skill up, check out [my guide to getting started](https://neelnanda.io/getting-started), and if you want to jump into an open problem check out my sequence [200 Concrete Open Problems in Mechanistic Interpretability](https://neelnanda.io/concrete-open-problems).**\n", + "This is a demo notebook for [TransformerLens](https://github.com/neelnanda-io/TransformerLens), **a library I ([Neel Nanda](https://neelnanda.io)) wrote for doing [mechanistic interpretability](https://distill.pub/2020/circuits/zoom-in/) of GPT-2 Style language models.** The goal of mechanistic interpretability is to take a trained model and reverse engineer the algorithms the model learned during training from its weights. It is a fact about the world today that we have computer programs that can essentially speak English at a human level (GPT-3, PaLM, etc), yet we have no idea how they work nor how to write one ourselves. This offends me greatly, and I would like to solve this! Mechanistic interpretability is a very young and small field, and there are a *lot* of open problems - if you would like to help, please try working on one! **If you want to skill up, check out [my guide to getting started](https://neelnanda.io/getting-started), and if you want to jump into an open problem check out my sequence [200 Concrete Open Problems in Mechanistic Interpretability](https://neelnanda.io/concrete-open-problems).**\n", "\n", "I wrote this library because after I left the Anthropic interpretability team and started doing independent research, I got extremely frustrated by the state of open source tooling. There's a lot of excellent infrastructure like HuggingFace and DeepSpeed to *use* or *train* models, but very little to dig into their internals and reverse engineer how they work. **This library tries to solve that**, and to make it easy to get into the field even if you don't work at an industry org with real infrastructure! The core features were heavily inspired by [Anthropic's excellent Garcon tool](https://transformer-circuits.pub/2021/garcon/index.html). Credit to Nelson Elhage and Chris Olah for building Garcon and showing me the value of good infrastructure for accelerating exploratory research!\n", "\n", @@ -268,16 +254,16 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 299, "metadata": {}, "outputs": [], "source": [ - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" + "device = utils.get_device()" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 300, "metadata": {}, "outputs": [ { @@ -296,6 +282,7 @@ } ], "source": [ + "# NBVAL_IGNORE_OUTPUT\n", "model = HookedTransformer.from_pretrained(\"gpt2-small\", device=device)" ] }, @@ -313,7 +300,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 301, "metadata": {}, "outputs": [ { @@ -351,7 +338,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 302, "metadata": {}, "outputs": [ { @@ -385,7 +372,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 303, "metadata": {}, "outputs": [ { @@ -406,7 +393,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 304, "metadata": {}, "outputs": [ { @@ -419,22 +406,26 @@ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 13, - "metadata": {}, + "execution_count": 304, + "metadata": { + "text/html": { + "Content-Type": "text/html" + } + }, "output_type": "execute_result" } ], @@ -483,7 +474,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 305, "metadata": {}, "outputs": [ { @@ -562,7 +553,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 306, "metadata": {}, "outputs": [ { @@ -611,23 +602,13 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 307, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", - "To disable this warning, you can either:\n", - "\t- Avoid using `tokenizers` before the fork if possible\n", - "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" - ] - }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "b93d6b6e8c37495f84b7a00f2caf81c3", + "model_id": "980e183587f54a03bb4ead134831c94d", "version_major": 2, "version_minor": 0 }, @@ -681,7 +662,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 308, "metadata": {}, "outputs": [ { @@ -691,9 +672,9 @@ "\n", "\n", "
\n", - "
\n", - "
\n", - "
" ], "text/plain": [ - "" + "" ] }, - "metadata": {}, + "metadata": { + "text/html": { + "Content-Type": "text/html" + } + }, "output_type": "display_data" } ], "source": [ + "if IN_GITHUB:\n", + " torch.manual_seed(50)\n", + " \n", "induction_head_layer = 5\n", "induction_head_index = 5\n", - "single_random_sequence = torch.randint(1000, 10000, (1, 20)).to(model.cfg.device)\n", + "size = (1, 20)\n", + "input_tensor = torch.randint(1000, 10000, size)\n", + "\n", + "single_random_sequence = input_tensor.to(model.cfg.device)\n", "repeated_random_sequence = einops.repeat(single_random_sequence, \"batch seq_len -> batch (2 seq_len)\")\n", "def visualize_pattern_hook(\n", " pattern: Float[torch.Tensor, \"batch head_index dest_pos source_pos\"],\n", @@ -1011,7 +994,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 312, "metadata": {}, "outputs": [ { @@ -1027,7 +1010,18 @@ "text": [ "Loaded pretrained model distilgpt2 into HookedTransformer\n" ] - }, + } + ], + "source": [ + "# NBVAL_IGNORE_OUTPUT\n", + "distilgpt2 = HookedTransformer.from_pretrained(\"distilgpt2\", device=device)" + ] + }, + { + "cell_type": "code", + "execution_count": 313, + "metadata": {}, + "outputs": [ { "data": { "text/html": [ @@ -1035,9 +1029,9 @@ "\n", "\n", "
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "