diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml
index 062f61a9d..31a4e6afe 100644
--- a/.github/workflows/checks.yml
+++ b/.github/workflows/checks.yml
@@ -33,7 +33,7 @@ permissions:
 jobs:
   checks:
-    name: Checks
+    name: Code Checks
     runs-on: ubuntu-latest
     strategy:
       matrix:
@@ -67,3 +67,27 @@ jobs:
       #   run: poetry run mypy transformer_lens
       - name: Build check
         run: poetry build
+  docs:
+    name: Documentation Checks
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version:
+          - "3.9"
+    steps:
+      - uses: actions/checkout@v3
+      - name: Install Poetry
+        uses: snok/install-poetry@v1
+        with:
+          version: 1.4.0
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: "poetry"
+      - name: Install dependencies
+        run: |
+          poetry lock --check
+          poetry install --with dev
+      - name: Documentation test
+        run: make documentation-test
diff --git a/demos/Activation_Patching_in_TL_Demo.ipynb b/demos/Activation_Patching_in_TL_Demo.ipynb
index 87f44d0b0..82dc23436 100644
--- a/demos/Activation_Patching_in_TL_Demo.ipynb
+++ b/demos/Activation_Patching_in_TL_Demo.ipynb
@@ -43,24 +43,27 @@
     }
    ],
    "source": [
-    "# Janky code to do different setup when run in a Colab notebook vs VSCode\n",
+    "import os\n",
+    "\n",
+    "IN_COLAB = 'google.colab' in str(get_ipython())\n",
+    "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n",
     "DEBUG_MODE = False\n",
-    "try:\n",
-    "    import google.colab\n",
-    "    IN_COLAB = True\n",
+    "DO_SLOW_RUNS = not IN_GITHUB\n",
+    "\n",
+    "if IN_COLAB or IN_GITHUB:\n",
     "    print(\"Running as a Colab notebook\")\n",
     "    %pip install git+https://github.com/neelnanda-io/TransformerLens.git\n",
     "    # Install my janky personal plotting utils\n",
     "    %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n",
-    "except:\n",
-    "    IN_COLAB = False\n",
+    "else:\n",
     "    print(\"Running as a Jupyter notebook - intended for development only!\")\n",
     "    from IPython import get_ipython\n",
     "\n",
     "    ipython = get_ipython()\n",
     "    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n",
     "    ipython.magic(\"load_ext autoreload\")\n",
-    "    ipython.magic(\"autoreload 2\")"
+    "    ipython.magic(\"autoreload 2\")\n",
+    "    "
    ]
   },
   {
@@ -319,7 +322,7 @@
    "outputs": [],
    "source": [
     "# Whether to do the runs by head and by position, which are much slower\n",
-    "DO_SLOW_RUNS = True"
+    "# DO_SLOW_RUNS = False"
    ]
   },
   {
diff --git a/demos/Attribution_Patching_Demo.ipynb b/demos/Attribution_Patching_Demo.ipynb
index 7b76c2e42..bd7f37c75 100644
--- a/demos/Attribution_Patching_Demo.ipynb
+++ b/demos/Attribution_Patching_Demo.ipynb
@@ -1 +1,3004 @@
-{"cells":[{"cell_type":"markdown","metadata":{},"source":[" # Attribution Patching Demo\n"," **Read [the accompanying blog post here](https://neelnanda.io/attribution-patching) for more context**\n"," This is an interim research report, giving a whirlwind tour of some unpublished work I did at Anthropic (credit to the then team - Chris Olah, Catherine Olsson, Nelson Elhage and Tristan Hume for help, support, and mentorship!)\n","\n"," The goal of this work is to run activation patching at an industrial scale, by using gradient-based attribution to approximate the technique - allowing an arbitrary number of patches to be made with two forward passes and a single backward pass.\n","\n"," I have had less time than hoped to flesh out this investigation, but am writing up a rough investigation and comparison to standard activation patching on a few tasks to give a sense of the potential of
this approach, and where it works vs falls down."]},{"cell_type":"markdown","metadata":{},"source":[" To use this notebook, go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.\n","\n"," **Tips for reading this Colab:**\n"," * You can run all this code for yourself!\n"," * The graphs are interactive!\n"," * Use the table of contents pane in the sidebar to navigate\n"," * Collapse irrelevant sections with the dropdown arrows\n"," * Search the page using the search in the sidebar, not CTRL+F"]},{"cell_type":"markdown","metadata":{},"source":[" ## Setup (Ignore)"]},{"cell_type":"code","execution_count":2,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Running as a Jupyter notebook - intended for development only!\n"]}],"source":["# Janky code to do different setup when run in a Colab notebook vs VSCode\n","DEBUG_MODE = False\n","try:\n"," import google.colab\n"," IN_COLAB = True\n"," print(\"Running as a Colab notebook\")\n"," %pip install transformer_lens\n"," %pip install torchtyping\n"," # Install my janky personal plotting utils\n"," %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n"," # Install another version of node that makes PySvelte work way faster\n"," !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n"," %pip install git+https://github.com/neelnanda-io/PySvelte.git\n"," # Needed for PySvelte to work, v3 came out and broke things...\n"," %pip install typeguard==2.13.3\n","except:\n"," IN_COLAB = False\n"," print(\"Running as a Jupyter notebook - intended for development only!\")\n"," from IPython import get_ipython\n","\n"," ipython = get_ipython()\n"," # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n"," ipython.magic(\"load_ext autoreload\")\n"," ipython.magic(\"autoreload 2\")"]},{"cell_type":"code","execution_count":3,"metadata":{},"outputs":[],"source":["# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n","import plotly.io as pio\n","\n","if IN_COLAB or not DEBUG_MODE:\n"," # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! 
Thus creating a debug mode.\n"," pio.renderers.default = \"colab\"\n","else:\n"," pio.renderers.default = \"notebook_connected\""]},{"cell_type":"code","execution_count":4,"metadata":{},"outputs":[],"source":["# Import stuff\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","import torch.optim as optim\n","import numpy as np\n","import einops\n","from fancy_einsum import einsum\n","import tqdm.notebook as tqdm\n","import random\n","from pathlib import Path\n","import plotly.express as px\n","from torch.utils.data import DataLoader\n","\n","from torchtyping import TensorType as TT\n","from typing import List, Union, Optional, Callable\n","from functools import partial\n","import copy\n","import itertools\n","import json\n","\n","from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer\n","import dataclasses\n","import datasets\n","from IPython.display import HTML, Markdown"]},{"cell_type":"code","execution_count":5,"metadata":{},"outputs":[],"source":["import pysvelte\n","\n","import transformer_lens\n","import transformer_lens.utils as utils\n","from transformer_lens.hook_points import (\n"," HookedRootModule,\n"," HookPoint,\n",") # Hooking utilities\n","from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache"]},{"cell_type":"markdown","metadata":{},"source":[" Plotting helper functions from a janky personal library of plotting utils. The library is not documented and I recommend against trying to read it, just use your preferred plotting library if you want to do anything non-obvious:"]},{"cell_type":"code","execution_count":6,"metadata":{},"outputs":[],"source":["from neel_plotly import line, imshow, scatter"]},{"cell_type":"code","execution_count":7,"metadata":{},"outputs":[],"source":["import transformer_lens.patching as patching"]},{"cell_type":"markdown","metadata":{},"source":[" ## IOI Patching Setup\n"," This just copies the relevant set up from Exploratory Analysis Demo, and isn't very important."]},{"cell_type":"code","execution_count":8,"metadata":{},"outputs":[{"name":"stderr","output_type":"stream","text":["Using pad_token, but it is not set yet.\n"]},{"name":"stdout","output_type":"stream","text":["Loaded pretrained model gpt2-small into HookedTransformer\n"]}],"source":["model = HookedTransformer.from_pretrained(\"gpt2-small\")\n","model.set_use_attn_result(True)"]},{"cell_type":"code","execution_count":9,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean string 0 <|endoftext|>When John and Mary went to the shops, John gave the bag to\n","Corrupted string 0 <|endoftext|>When John and Mary went to the shops, Mary gave the bag to\n","Answer token indices tensor([[ 5335, 1757],\n"," [ 1757, 5335],\n"," [ 4186, 3700],\n"," [ 3700, 4186],\n"," [ 6035, 15686],\n"," [15686, 6035],\n"," [ 5780, 14235],\n"," [14235, 5780]], device='cuda:0')\n"]}],"source":["prompts = ['When John and Mary went to the shops, John gave the bag to', 'When John and Mary went to the shops, Mary gave the bag to', 'When Tom and James went to the park, James gave the ball to', 'When Tom and James went to the park, Tom gave the ball to', 'When Dan and Sid went to the shops, Sid gave an apple to', 'When Dan and Sid went to the shops, Dan gave an apple to', 'After Martin and Amy went to the park, Amy gave a drink to', 'After Martin and Amy went to the park, Martin gave a drink to']\n","answers = [(' Mary', ' John'), (' John', ' Mary'), (' Tom', ' James'), (' James', ' Tom'), (' Dan', ' 
Sid'), (' Sid', ' Dan'), (' Martin', ' Amy'), (' Amy', ' Martin')]\n","\n","clean_tokens = model.to_tokens(prompts)\n","# Swap each adjacent pair, with a hacky list comprehension\n","corrupted_tokens = clean_tokens[\n"," [(i+1 if i%2==0 else i-1) for i in range(len(clean_tokens)) ]\n"," ]\n","print(\"Clean string 0\", model.to_string(clean_tokens[0]))\n","print(\"Corrupted string 0\", model.to_string(corrupted_tokens[0]))\n","\n","answer_token_indices = torch.tensor([[model.to_single_token(answers[i][j]) for j in range(2)] for i in range(len(answers))], device=model.cfg.device)\n","print(\"Answer token indices\", answer_token_indices)"]},{"cell_type":"code","execution_count":10,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean logit diff: 3.5519\n","Corrupted logit diff: -3.5519\n"]}],"source":["def get_logit_diff(logits, answer_token_indices=answer_token_indices):\n"," if len(logits.shape)==3:\n"," # Get final logits only\n"," logits = logits[:, -1, :]\n"," correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))\n"," incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))\n"," return (correct_logits - incorrect_logits).mean()\n","\n","clean_logits, clean_cache = model.run_with_cache(clean_tokens)\n","corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)\n","\n","clean_logit_diff = get_logit_diff(clean_logits, answer_token_indices).item()\n","print(f\"Clean logit diff: {clean_logit_diff:.4f}\")\n","\n","corrupted_logit_diff = get_logit_diff(corrupted_logits, answer_token_indices).item()\n","print(f\"Corrupted logit diff: {corrupted_logit_diff:.4f}\")"]},{"cell_type":"code","execution_count":11,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean Baseline is 1: 1.0000\n","Corrupted Baseline is 0: 0.0000\n"]}],"source":["CLEAN_BASELINE = clean_logit_diff\n","CORRUPTED_BASELINE = corrupted_logit_diff\n","def ioi_metric(logits, answer_token_indices=answer_token_indices):\n"," return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (CLEAN_BASELINE - CORRUPTED_BASELINE)\n","\n","print(f\"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}\")\n","print(f\"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}\")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Patching\n"," In the following cells, we define attribution patching and use it in various ways on the model."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["Metric = Callable[[TT[\"batch_and_pos_dims\", \"d_model\"]], float]"]},{"cell_type":"code","execution_count":13,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean Value: 1.0\n","Clean Activations Cached: 220\n","Clean Gradients Cached: 220\n","Corrupted Value: 0.0\n","Corrupted Activations Cached: 220\n","Corrupted Gradients Cached: 220\n"]}],"source":["filter_not_qkv_input = lambda name: \"_input\" not in name\n","def get_cache_fwd_and_bwd(model, tokens, metric):\n"," model.reset_hooks()\n"," cache = {}\n"," def forward_cache_hook(act, hook):\n"," cache[hook.name] = act.detach()\n"," model.add_hook(filter_not_qkv_input, forward_cache_hook, \"fwd\")\n","\n"," grad_cache = {}\n"," def backward_cache_hook(act, hook):\n"," grad_cache[hook.name] = act.detach()\n"," model.add_hook(filter_not_qkv_input, backward_cache_hook, \"bwd\")\n","\n"," value = metric(model(tokens))\n"," value.backward()\n"," model.reset_hooks()\n"," return value.item(), ActivationCache(cache, 
model), ActivationCache(grad_cache, model)\n","\n","clean_value, clean_cache, clean_grad_cache = get_cache_fwd_and_bwd(model, clean_tokens, ioi_metric)\n","print(\"Clean Value:\", clean_value)\n","print(\"Clean Activations Cached:\", len(clean_cache))\n","print(\"Clean Gradients Cached:\", len(clean_grad_cache))\n","corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(model, corrupted_tokens, ioi_metric)\n","print(\"Corrupted Value:\", corrupted_value)\n","print(\"Corrupted Activations Cached:\", len(corrupted_cache))\n","print(\"Corrupted Gradients Cached:\", len(corrupted_grad_cache))"]},{"cell_type":"markdown","metadata":{},"source":[" ### Attention Attribution\n"," The easiest thing to start with is to not even engage with the corrupted tokens/patching, but to look at the attribution of the attention patterns - that is, the linear approximation to what happens if you set each element of the attention pattern to zero. This, as it turns out, is a good proxy for what is going on with each head!\n"," Note that this is *not* the same as what we will later do with patching. In particular, this does not set up a careful counterfactual! It's a good tool for getting a sense of what's generally going on in this problem, but does not control for eg stuff that systematically boosts John > Mary in general, stuff that says \"I should activate the IOI circuit\", etc. Though using logit diff as our metric *does* control for some of this.\n"," Each element of the batch is independent and the metric is an average logit diff, so we can analyse each batch element independently here. We'll look at the first one, and then at the average across the whole batch (note - 4 prompts have the indirect object before the subject, and 4 prompts have it the other way round, making the average pattern harder to interpret - I plot it over the first sequence of tokens as a mildly misleading reference).\n"," We can compare it to the diagram from the Interpretability in the Wild paper, and basically instantly recover most of the circuit!"]},{"cell_type":"code","execution_count":14,"metadata":{},"outputs":[],"source":["def create_attention_attr(clean_cache, clean_grad_cache) -> TT[\"batch\", \"layer\", \"head_index\", \"dest\", \"src\"]:\n"," attention_stack = torch.stack([clean_cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0)\n"," attention_grad_stack = torch.stack([clean_grad_cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0)\n"," attention_attr = attention_grad_stack * attention_stack\n"," attention_attr = einops.rearrange(attention_attr, \"layer batch head_index dest src -> batch layer head_index dest src\")\n"," return attention_attr\n","\n","attention_attr = create_attention_attr(clean_cache, clean_grad_cache)"]},{"cell_type":"code","execution_count":15,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["['L0H0', 'L0H1', 'L0H2', 'L0H3', 'L0H4']\n","['L0H0+', 'L0H0-', 'L0H1+', 'L0H1-', 'L0H2+']\n","['L0H0Q', 'L0H0K', 'L0H0V', 'L0H1Q', 'L0H1K']\n"]}],"source":["HEAD_NAMES = [f\"L{l}H{h}\" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)]\n","HEAD_NAMES_SIGNED = [f\"{name}{sign}\" for name in HEAD_NAMES for sign in [\"+\", \"-\"]]\n","HEAD_NAMES_QKV = [f\"{name}{act_name}\" for name in HEAD_NAMES for act_name in [\"Q\", \"K\", \"V\"]]\n","print(HEAD_NAMES[:5])\n","print(HEAD_NAMES_SIGNED[:5])\n","print(HEAD_NAMES_QKV[:5])"]},{"cell_type":"markdown","metadata":{},"source":[" An extremely janky way to plot the attention attribution patterns.
We scale them to be in [-1, 1], split each head into a positive and negative part (so all of it is in [0, 1]), and then plot the top 20 head-halves (a head can appear twice!) by the max value of the attribution pattern."]},{"cell_type":"code","execution_count":16,"metadata":{},"outputs":[{"data":{"text/markdown":["### Attention Attribution for first sequence"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n"," \n","\n"," \n","
\n"," \n"," \n"," "],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["### Summed Attention Attribution for all sequences"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n"," \n","\n"," \n","
\n"," \n"," \n"," "],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Note: Plotted over first sequence for reference, but pairs have IO and S1 in different positions.\n"]}],"source":["def plot_attention_attr(attention_attr, tokens, top_k=20, index=0, title=\"\"):\n"," if len(tokens.shape)==2:\n"," tokens = tokens[index]\n"," if len(attention_attr.shape)==5:\n"," attention_attr = attention_attr[index]\n"," attention_attr_pos = attention_attr.clamp(min=-1e-5)\n"," attention_attr_neg = - attention_attr.clamp(max=1e-5)\n"," attention_attr_signed = torch.stack([attention_attr_pos, attention_attr_neg], dim=0)\n"," attention_attr_signed = einops.rearrange(attention_attr_signed, \"sign layer head_index dest src -> (layer head_index sign) dest src\")\n"," attention_attr_signed = attention_attr_signed / attention_attr_signed.max()\n"," attention_attr_indices = attention_attr_signed.max(-1).values.max(-1).values.argsort(descending=True)\n"," # print(attention_attr_indices.shape)\n"," # print(attention_attr_indices)\n"," attention_attr_signed = attention_attr_signed[attention_attr_indices, :, :]\n"," head_labels = [HEAD_NAMES_SIGNED[i.item()] for i in attention_attr_indices]\n","\n"," if title: display(Markdown(\"### \"+title))\n"," display(pysvelte.AttentionMulti(tokens=model.to_str_tokens(tokens), attention=attention_attr_signed.permute(1, 2, 0)[:, :, :top_k], head_labels=head_labels[:top_k]))\n","\n","plot_attention_attr(attention_attr, clean_tokens, index=0, title=\"Attention Attribution for first sequence\")\n","\n","plot_attention_attr(attention_attr.sum(0), clean_tokens[0], title=\"Summed Attention Attribution for all sequences\")\n","print(\"Note: Plotted over first sequence for reference, but pairs have IO and S1 in different positions.\")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Attribution Patching\n"," In the following sections, I will implement various kinds of attribution patching, and then compare them to the activation patching patterns (activation patching code copied from [Exploratory Analysis Demo](https://neelnanda.io/exploratory-analysis-demo))\n"," ### Residual Stream Patching\n","
Note: We add up across both d_model and batch (explanation below).\n"," We add up along d_model because we're taking the dot product - the derivative *is* the linear map that locally approximates the metric, and so we take the dot product of our change vector with the derivative vector. Equivalently, we can look at the effect of changing each coordinate independently, and then combine these effects by adding them up - the map is linear, so this works exactly.\n"," We add up across batch because we're taking the average of the metric, so each individual batch element provides `1/batch_size` of the overall effect. Because each batch element is independent of the others and no information moves between activations for different inputs, the batched version is equivalent to doing attribution patching separately for each input and then averaging - in this second version the metric per input is *not* divided by batch_size because we don't average.
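 In symbols (my paraphrase of the explanation above - $M$ is the metric, $x$ is the activation being patched at position $p$, and the gradient is taken on the corrupted run):

```latex
\Delta M \;\approx\; \sum_{b=1}^{B} \sum_{d=1}^{d_\text{model}}
  \left.\frac{\partial M}{\partial x_{b,p,d}}\right|_{x^\text{corrupted}}
  \left( x^\text{clean}_{b,p,d} - x^\text{corrupted}_{b,p,d} \right)
```

 This sum over batch and d_model is exactly the `"component batch pos d_model -> component pos"` reduction in the `einops.reduce` call in the cell below.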
"]},{"cell_type":"code","execution_count":17,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def attr_patch_residual(\n"," clean_cache: ActivationCache, \n"," corrupted_cache: ActivationCache, \n"," corrupted_grad_cache: ActivationCache,\n"," ) -> TT[\"component\", \"pos\"]:\n"," clean_residual, residual_labels = clean_cache.accumulated_resid(-1, incl_mid=True, return_labels=True)\n"," corrupted_residual = corrupted_cache.accumulated_resid(-1, incl_mid=True, return_labels=False)\n"," corrupted_grad_residual = corrupted_grad_cache.accumulated_resid(-1, incl_mid=True, return_labels=False)\n"," residual_attr = einops.reduce(\n"," corrupted_grad_residual * (clean_residual - corrupted_residual),\n"," \"component batch pos d_model -> component pos\",\n"," \"sum\"\n"," )\n"," return residual_attr, residual_labels\n","\n","residual_attr, residual_labels = attr_patch_residual(clean_cache, corrupted_cache, corrupted_grad_cache)\n","imshow(residual_attr, y=residual_labels, yaxis=\"Component\", xaxis=\"Position\", title=\"Residual Attribution Patching\")\n","\n","# ### Layer Output Patching"]},{"cell_type":"code","execution_count":18,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def attr_patch_layer_out(\n"," clean_cache: ActivationCache, \n"," corrupted_cache: ActivationCache, \n"," corrupted_grad_cache: ActivationCache,\n"," ) -> TT[\"component\", \"pos\"]:\n"," clean_layer_out, labels = clean_cache.decompose_resid(-1, return_labels=True)\n"," corrupted_layer_out = corrupted_cache.decompose_resid(-1, return_labels=False)\n"," corrupted_grad_layer_out = corrupted_grad_cache.decompose_resid(-1, return_labels=False)\n"," layer_out_attr = einops.reduce(\n"," corrupted_grad_layer_out * (clean_layer_out - corrupted_layer_out),\n"," \"component batch pos d_model -> component pos\",\n"," \"sum\"\n"," )\n"," return layer_out_attr, labels\n","\n","layer_out_attr, layer_out_labels = attr_patch_layer_out(clean_cache, corrupted_cache, corrupted_grad_cache)\n","imshow(layer_out_attr, y=layer_out_labels, yaxis=\"Component\", xaxis=\"Position\", title=\"Layer Output Attribution Patching\")"]},{"cell_type":"code","execution_count":19,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def attr_patch_head_out(\n"," clean_cache: ActivationCache, \n"," corrupted_cache: ActivationCache, \n"," corrupted_grad_cache: ActivationCache,\n"," ) -> TT[\"component\", \"pos\"]:\n"," labels = HEAD_NAMES\n","\n"," clean_head_out = clean_cache.stack_head_results(-1, return_labels=False)\n"," corrupted_head_out = corrupted_cache.stack_head_results(-1, return_labels=False)\n"," corrupted_grad_head_out = corrupted_grad_cache.stack_head_results(-1, return_labels=False)\n"," head_out_attr = einops.reduce(\n"," corrupted_grad_head_out * (clean_head_out - corrupted_head_out),\n"," \"component batch pos d_model -> component pos\",\n"," \"sum\"\n"," )\n"," return head_out_attr, labels\n","\n","head_out_attr, head_out_labels = attr_patch_head_out(clean_cache, corrupted_cache, corrupted_grad_cache)\n","imshow(head_out_attr, y=head_out_labels, yaxis=\"Component\", xaxis=\"Position\", title=\"Head Output Attribution Patching\")\n","sum_head_out_attr = einops.reduce(head_out_attr, \"(layer head) pos -> layer head\", \"sum\", layer=model.cfg.n_layers, head=model.cfg.n_heads)\n","imshow(sum_head_out_attr, yaxis=\"Layer\", xaxis=\"Head Index\", title=\"Head Output Attribution Patching Sum Over Pos\")"]},{"cell_type":"markdown","metadata":{},"source":[" ### Head Activation Patching\n"," Intuitively, a head has three inputs, keys, queries and values. We can patch each of these individually to get a sense for where the important part of each head's input comes from!\n"," As a sanity check, we also do this for the mixed value. The result is a linear map of this (`z @ W_O == result`), so this is the same as patching the output of the head.\n"," We plot both the patch for each head over each position, and summed over position (it tends to be pretty sparse, so the latter is the same)"]},{"cell_type":"code","execution_count":20,"metadata":{},"outputs":[{"data":{"text/markdown":["#### Key Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["#### Query Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["#### Value Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["#### Mixed Value Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["from typing_extensions import Literal\n","def stack_head_vector_from_cache(\n"," cache, \n"," activation_name: Literal[\"q\", \"k\", \"v\", \"z\"]\n"," ) -> TT[\"layer_and_head_index\", \"batch\", \"pos\", \"d_head\"]:\n"," \"\"\"Stacks the head vectors from the cache from a specific activation (key, query, value or mixed_value (z)) into a single tensor.\"\"\"\n"," stacked_head_vectors = torch.stack([cache[activation_name, l] for l in range(model.cfg.n_layers)], dim=0)\n"," stacked_head_vectors = einops.rearrange(\n"," stacked_head_vectors,\n"," \"layer batch pos head_index d_head -> (layer head_index) batch pos d_head\"\n"," )\n"," return stacked_head_vectors\n","\n","def attr_patch_head_vector(\n"," clean_cache: ActivationCache, \n"," corrupted_cache: ActivationCache, \n"," corrupted_grad_cache: ActivationCache,\n"," activation_name: Literal[\"q\", \"k\", \"v\", \"z\"],\n"," ) -> TT[\"component\", \"pos\"]:\n"," labels = HEAD_NAMES\n","\n"," clean_head_vector = stack_head_vector_from_cache(clean_cache, activation_name)\n"," corrupted_head_vector = stack_head_vector_from_cache(corrupted_cache, activation_name)\n"," corrupted_grad_head_vector = stack_head_vector_from_cache(corrupted_grad_cache, activation_name)\n"," head_vector_attr = einops.reduce(\n"," corrupted_grad_head_vector * (clean_head_vector - corrupted_head_vector),\n"," \"component batch pos d_head -> component pos\",\n"," \"sum\"\n"," )\n"," return head_vector_attr, labels\n","\n","head_vector_attr_dict = {}\n","for activation_name, activation_name_full in [(\"k\", \"Key\"), (\"q\", \"Query\"), (\"v\", \"Value\"), (\"z\", \"Mixed Value\")]:\n"," display(Markdown(f\"#### {activation_name_full} Head Vector Attribution Patching\"))\n"," head_vector_attr_dict[activation_name], head_vector_labels = attr_patch_head_vector(clean_cache, corrupted_cache, corrupted_grad_cache, activation_name)\n"," imshow(head_vector_attr_dict[activation_name], y=head_vector_labels, yaxis=\"Component\", xaxis=\"Position\", title=f\"{activation_name_full} Attribution Patching\")\n"," sum_head_vector_attr = einops.reduce(head_vector_attr_dict[activation_name], \"(layer head) pos -> layer head\", \"sum\", layer=model.cfg.n_layers, head=model.cfg.n_heads)\n"," imshow(sum_head_vector_attr, yaxis=\"Layer\", xaxis=\"Head Index\", title=f\"{activation_name_full} Attribution Patching Sum Over Pos\")"]},{"cell_type":"code","execution_count":21,"metadata":{},"outputs":[{"data":{"text/markdown":["### Head Pattern Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n"," \n","\n"," \n","
\n"," \n"," \n"," "],"text/plain":[""]},"metadata":{},"output_type":"display_data"}],"source":["from typing_extensions import Literal\n","def stack_head_pattern_from_cache(\n"," cache, \n"," ) -> TT[\"layer_and_head_index\", \"batch\", \"dest_pos\", \"src_pos\"]:\n"," \"\"\"Stacks the head patterns from the cache into a single tensor.\"\"\"\n"," stacked_head_pattern = torch.stack([cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0)\n"," stacked_head_pattern = einops.rearrange(\n"," stacked_head_pattern,\n"," \"layer batch head_index dest_pos src_pos -> (layer head_index) batch dest_pos src_pos\"\n"," )\n"," return stacked_head_pattern\n","\n","def attr_patch_head_pattern(\n"," clean_cache: ActivationCache, \n"," corrupted_cache: ActivationCache, \n"," corrupted_grad_cache: ActivationCache,\n"," ) -> TT[\"component\", \"dest_pos\", \"src_pos\"]:\n"," labels = HEAD_NAMES\n","\n"," clean_head_pattern = stack_head_pattern_from_cache(clean_cache)\n"," corrupted_head_pattern = stack_head_pattern_from_cache(corrupted_cache)\n"," corrupted_grad_head_pattern = stack_head_pattern_from_cache(corrupted_grad_cache)\n"," head_pattern_attr = einops.reduce(\n"," corrupted_grad_head_pattern * (clean_head_pattern - corrupted_head_pattern),\n"," \"component batch dest_pos src_pos -> component dest_pos src_pos\",\n"," \"sum\"\n"," )\n"," return head_pattern_attr, labels\n","\n","head_pattern_attr, labels = attr_patch_head_pattern(clean_cache, corrupted_cache, corrupted_grad_cache)\n","\n","plot_attention_attr(einops.rearrange(head_pattern_attr, \"(layer head) dest src -> layer head dest src\", layer=model.cfg.n_layers, head=model.cfg.n_heads), clean_tokens, index=0, title=\"Head Pattern Attribution Patching\")"]},{"cell_type":"code","execution_count":22,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_head_vector_grad_input_from_grad_cache(\n"," grad_cache: ActivationCache, \n"," activation_name: Literal[\"q\", \"k\", \"v\"],\n"," layer: int\n"," ) -> TT[\"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," vector_grad = grad_cache[activation_name, layer]\n"," ln_scales = grad_cache[\"scale\", layer, \"ln1\"]\n"," attn_layer_object = model.blocks[layer].attn\n"," if activation_name == \"q\":\n"," W = attn_layer_object.W_Q\n"," elif activation_name == \"k\":\n"," W = attn_layer_object.W_K\n"," elif activation_name == \"v\":\n"," W = attn_layer_object.W_V\n"," else:\n"," raise ValueError(\"Invalid activation name\")\n","\n"," return einsum(\"batch pos head_index d_head, batch pos, head_index d_model d_head -> batch pos head_index d_model\", vector_grad, ln_scales.squeeze(-1), W)\n","\n","def get_stacked_head_vector_grad_input(grad_cache, activation_name: Literal[\"q\", \"k\", \"v\"]) -> TT[\"layer\", \"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," return torch.stack([get_head_vector_grad_input_from_grad_cache(grad_cache, activation_name, l) for l in range(model.cfg.n_layers)], dim=0)\n","\n","def get_full_vector_grad_input(grad_cache) -> TT[\"qkv\", \"layer\", \"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," return torch.stack([get_stacked_head_vector_grad_input(grad_cache, activation_name) for activation_name in ['q', 'k', 'v']], dim=0)\n","\n","def attr_patch_head_path(\n"," clean_cache: ActivationCache, \n"," corrupted_cache: ActivationCache, \n"," corrupted_grad_cache: ActivationCache\n"," ) -> TT[\"qkv\", \"dest_component\", \"src_component\", \"pos\"]:\n"," \"\"\"\n"," Computes the attribution patch along the path between each pair of heads.\n","\n"," Sets this to zero for the path from any late head to any early head\n","\n"," \"\"\"\n"," start_labels = HEAD_NAMES\n"," end_labels = HEAD_NAMES_QKV\n"," full_vector_grad_input = get_full_vector_grad_input(corrupted_grad_cache)\n"," clean_head_result_stack = clean_cache.stack_head_results(-1)\n"," corrupted_head_result_stack = corrupted_cache.stack_head_results(-1)\n"," diff_head_result = einops.rearrange(\n"," clean_head_result_stack - corrupted_head_result_stack,\n"," \"(layer head_index) batch pos d_model -> layer batch pos head_index d_model\",\n"," layer = model.cfg.n_layers,\n"," head_index = model.cfg.n_heads,\n"," )\n"," path_attr = einsum(\n"," \"qkv layer_end batch pos head_end d_model, layer_start batch pos head_start d_model -> qkv layer_end head_end layer_start head_start pos\", \n"," full_vector_grad_input, \n"," diff_head_result)\n"," correct_layer_order_mask = (\n"," torch.arange(model.cfg.n_layers)[None, :, None, None, None, None] > \n"," torch.arange(model.cfg.n_layers)[None, None, None, :, None, None]).to(path_attr.device)\n"," zero = torch.zeros(1, device=path_attr.device)\n"," path_attr = torch.where(correct_layer_order_mask, path_attr, zero)\n","\n"," path_attr = einops.rearrange(\n"," path_attr,\n"," \"qkv layer_end head_end layer_start head_start pos -> (layer_end head_end qkv) (layer_start head_start) pos\",\n"," )\n"," return path_attr, end_labels, start_labels\n","\n","head_path_attr, end_labels, start_labels = attr_patch_head_path(clean_cache, corrupted_cache, corrupted_grad_cache)\n","imshow(head_path_attr.sum(-1), y=end_labels, yaxis=\"Path End (Head Input)\", x=start_labels, xaxis=\"Path Start (Head Output)\", title=\"Head Path Attribution Patching\")"]},{"cell_type":"markdown","metadata":{},"source":[" This is 
hard to parse. Here's an experiment with filtering for the most important heads and showing their paths."]},{"cell_type":"code","execution_count":23,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["head_out_values, head_out_indices = head_out_attr.sum(-1).abs().sort(descending=True)\n","line(head_out_values)\n","top_head_indices = head_out_indices[:22].sort().values\n","top_end_indices = []\n","top_end_labels = []\n","top_start_indices = []\n","top_start_labels = []\n","for i in top_head_indices:\n"," i = i.item()\n"," top_start_indices.append(i)\n"," top_start_labels.append(start_labels[i])\n"," for j in range(3):\n"," top_end_indices.append(3*i+j)\n"," top_end_labels.append(end_labels[3*i+j])\n","\n","imshow(head_path_attr[top_end_indices, :][:, top_start_indices].sum(-1), y=top_end_labels, yaxis=\"Path End (Head Input)\", x=top_start_labels, xaxis=\"Path Start (Head Output)\", title=\"Head Path Attribution Patching (Filtered for Top Heads)\")"]},{"cell_type":"code","execution_count":24,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["for j, composition_type in enumerate([\"Query\", \"Key\", \"Value\"]):\n"," imshow(head_path_attr[top_end_indices, :][:, top_start_indices][j::3].sum(-1), y=top_end_labels[j::3], yaxis=\"Path End (Head Input)\", x=top_start_labels, xaxis=\"Path Start (Head Output)\", title=f\"Head Path to {composition_type} Attribution Patching (Filtered for Top Heads)\")"]},{"cell_type":"code","execution_count":25,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["top_head_path_attr = einops.rearrange(head_path_attr[top_end_indices, :][:, top_start_indices].sum(-1), \"(head_end qkv) head_start -> qkv head_end head_start\", qkv=3)\n","imshow(top_head_path_attr, y=[i[:-1] for i in top_end_labels[::3]], yaxis=\"Path End (Head Input)\", x=top_start_labels, xaxis=\"Path Start (Head Output)\", title=f\"Head Path Attribution Patching (Filtered for Top Heads)\", facet_col=0, facet_labels=[\"Query\", \"Key\", \"Value\"])"]},{"cell_type":"markdown","metadata":{},"source":[" Let's now dive into 3 interesting heads: L5H5 (induction head), L8H6 (S-Inhibition Head), L9H9 (Name Mover) and look at their input and output paths (note - Q input means )"]},{"cell_type":"code","execution_count":26,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["interesting_heads = [5 * model.cfg.n_heads + 5, 8 * model.cfg.n_heads + 6, 9 * model.cfg.n_heads + 9]\n","interesting_head_labels = [HEAD_NAMES[i] for i in interesting_heads]\n","for head_index, label in zip(interesting_heads, interesting_head_labels):\n"," in_paths = head_path_attr[3*head_index:3*head_index+3].sum(-1)\n"," out_paths = head_path_attr[:, head_index].sum(-1)\n"," out_paths = einops.rearrange(out_paths, \"(layer_head qkv) -> qkv layer_head\", qkv=3)\n"," all_paths = torch.cat([in_paths, out_paths], dim=0)\n"," all_paths = einops.rearrange(all_paths, \"path_type (layer head) -> path_type layer head\", layer=model.cfg.n_layers, head=model.cfg.n_heads)\n"," imshow(all_paths, facet_col=0, facet_labels=[\"Query (In)\", \"Key (In)\", \"Value (In)\", \"Query (Out)\", \"Key (Out)\", \"Value (Out)\"], title=f\"Input and Output Paths for head {label}\", yaxis=\"Layer\", xaxis=\"Head\")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Validating Attribution vs Activation Patching\n"," Let's now compare attribution and activation patching. Generally it's a decent approximation! The main place it fails is MLP0 and the residual stream\n"," My fuzzy intuition is that attribution patching works badly for \"big\" things which are poorly modelled as linear approximations, and works well for \"small\" things which are more like incremental changes. Anything involving replacing the embedding is a \"big\" thing, which includes residual streams, and in GPT-2 small MLP0 seems to be used as an \"extended embedding\" (where later layers use MLP0's output instead of the token embedding), so I also count it as big.\n"," See more discussion in the accompanying blog post!\n"]},{"cell_type":"markdown","metadata":{},"source":[" First do some refactoring to make attribution patching more generic. We make an attribution cache, which is an ActivationCache where each element is (clean_act - corrupted_act) * corrupted_grad, so that it's the per-element attribution for each activation. Thanks to linearity, we just compute things by adding stuff up along the relevant dimensions!"]},{"cell_type":"code","execution_count":27,"metadata":{},"outputs":[],"source":["attribution_cache_dict = {}\n","for key in corrupted_grad_cache.cache_dict.keys():\n"," attribution_cache_dict[key] = corrupted_grad_cache.cache_dict[key] * (clean_cache.cache_dict[key] - corrupted_cache.cache_dict[key])\n","attr_cache = ActivationCache(attribution_cache_dict, model)"]},{"cell_type":"markdown","metadata":{},"source":[" By block: For each head we patch the starting residual stream, attention output + MLP output"]},{"cell_type":"code","execution_count":28,"metadata":{},"outputs":[],"source":["str_tokens = model.to_str_tokens(clean_tokens[0])\n","context_length = len(str_tokens)"]},{"cell_type":"code","execution_count":29,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"95a5290e11b64b6a95ef5dd37d027c7a","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/180 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_block_act_patch_result = patching.get_act_patch_block_every(model, corrupted_tokens, clean_cache, ioi_metric)\n","imshow(every_block_act_patch_result, facet_col=0, facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"], title=\"Activation Patching Per Block\", xaxis=\"Position\", yaxis=\"Layer\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))])"]},{"cell_type":"code","execution_count":30,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_block_every(attr_cache):\n"," resid_pre_attr = einops.reduce(\n"," attr_cache.stack_activation(\"resid_pre\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n"," attn_out_attr = einops.reduce(\n"," attr_cache.stack_activation(\"attn_out\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n"," mlp_out_attr = einops.reduce(\n"," attr_cache.stack_activation(\"mlp_out\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n","\n"," every_block_attr_patch_result = torch.stack([resid_pre_attr, attn_out_attr, mlp_out_attr], dim=0)\n"," return every_block_attr_patch_result\n","every_block_attr_patch_result = get_attr_patch_block_every(attr_cache)\n","imshow(every_block_attr_patch_result, facet_col=0, facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"], title=\"Attribution Patching Per Block\", xaxis=\"Position\", yaxis=\"Layer\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))])"]},{"cell_type":"code","execution_count":31,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(y=every_block_attr_patch_result.reshape(3, -1), x=every_block_act_patch_result.reshape(3, -1), facet_col=0, facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"], title=\"Attribution vs Activation Patching Per Block\", xaxis=\"Activation Patch\", yaxis=\"Attribution Patch\", hover=[f\"Layer {l}, Position {p}, |{str_tokens[p]}|\" for l in range(model.cfg.n_layers) for p in range(context_length)], color=einops.repeat(torch.arange(model.cfg.n_layers), \"layer -> (layer pos)\", pos=context_length), color_continuous_scale=\"Portland\")"]},{"cell_type":"markdown","metadata":{},"source":[" By head: For each head we patch the output, query, key, value or pattern. We do all positions at once so it's not super slow."]},{"cell_type":"code","execution_count":32,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"18b2e6b0985b40cd8c0cd1a16ba62975","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/144 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(model, corrupted_tokens, clean_cache, ioi_metric)\n","imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=1, zmin=-1)"]},{"cell_type":"code","execution_count":33,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_attn_head_all_pos_every(attr_cache):\n"," head_out_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"z\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_q_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"q\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_k_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"k\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_v_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"v\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_pattern_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"pattern\"),\n"," \"layer batch head_index dest_pos src_pos -> layer head_index\",\n"," \"sum\",\n"," )\n","\n"," return torch.stack([head_out_all_pos_attr, head_q_all_pos_attr, head_k_all_pos_attr, head_v_all_pos_attr, head_pattern_all_pos_attr])\n"," \n","every_head_all_pos_attr_patch_result = get_attr_patch_attn_head_all_pos_every(attr_cache)\n","imshow(every_head_all_pos_attr_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Attribution Patching Per Head (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=1, zmin=-1)"]},{"cell_type":"code","execution_count":34,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(y=every_head_all_pos_attr_patch_result.reshape(5, -1), x=every_head_all_pos_act_patch_result.reshape(5, -1), facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Attribution vs Activation Patching Per Head (All Pos)\", xaxis=\"Activation Patch\", yaxis=\"Attribution Patch\", include_diag=True, hover=head_out_labels, color=einops.repeat(torch.arange(model.cfg.n_layers), \"layer -> (layer head)\", head=model.cfg.n_heads), color_continuous_scale=\"Portland\")"]},{"cell_type":"markdown","metadata":{},"source":[" We see pretty good results in general, but significant errors for heads L5H5 on query and moderate errors for head L10H7 on query and key, and moderate errors for head L11H10 on key. But each of these is fine for pattern and output. My guess is that the problem is that these have pretty saturated attention on a single token, and the linear approximation is thus not great on the attention calculation here, but I'm not sure. When we plot the attention patterns, we do see this!\n"," Note that the axis labels are for the *first* prompt's tokens, but each facet is a different prompt, so this is somewhat inaccurate. In particular, every odd facet has indirect object and subject in the opposite order (IO first). But otherwise everything lines up between the prompts"]},{"cell_type":"code","execution_count":35,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["graph_tok_labels = [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))]\n","imshow(clean_cache[\"pattern\", 5][:, 5], x= graph_tok_labels, y=graph_tok_labels, facet_col=0, title=\"Attention for Head L5H5\", facet_name=\"Prompt\")\n","imshow(clean_cache[\"pattern\", 10][:, 7], x= graph_tok_labels, y=graph_tok_labels, facet_col=0, title=\"Attention for Head L10H7\", facet_name=\"Prompt\")\n","imshow(clean_cache[\"pattern\", 11][:, 10], x= graph_tok_labels, y=graph_tok_labels, facet_col=0, title=\"Attention for Head L11H10\", facet_name=\"Prompt\")\n","\n","\n","# [markdown]"]},{"cell_type":"code","execution_count":36,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"06f39489001845849fbc7446a07066f4","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/2160 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_head_by_pos_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(model, corrupted_tokens, clean_cache, ioi_metric)\n","every_head_by_pos_act_patch_result = einops.rearrange(every_head_by_pos_act_patch_result, \"act_type layer pos head -> act_type (layer head) pos\")\n","imshow(every_head_by_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (By Pos)\", xaxis=\"Position\", yaxis=\"Layer & Head\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))], y=head_out_labels)"]},{"cell_type":"code","execution_count":37,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_attn_head_by_pos_every(attr_cache):\n"," head_out_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"z\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_q_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"q\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_k_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"k\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_v_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"v\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_pattern_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"pattern\"),\n"," \"layer batch head_index dest_pos src_pos -> layer dest_pos head_index\",\n"," \"sum\",\n"," )\n","\n"," return torch.stack([head_out_by_pos_attr, head_q_by_pos_attr, head_k_by_pos_attr, head_v_by_pos_attr, head_pattern_by_pos_attr])\n","every_head_by_pos_attr_patch_result = get_attr_patch_attn_head_by_pos_every(attr_cache)\n","every_head_by_pos_attr_patch_result = einops.rearrange(every_head_by_pos_attr_patch_result, \"act_type layer pos head -> act_type (layer head) pos\")\n","imshow(every_head_by_pos_attr_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Attribution Patching Per Head (By Pos)\", xaxis=\"Position\", yaxis=\"Layer & Head\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))], y=head_out_labels)"]},{"cell_type":"code","execution_count":38,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(y=every_head_by_pos_attr_patch_result.reshape(5, -1), x=every_head_by_pos_act_patch_result.reshape(5, -1), facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Attribution vs Activation Patching Per Head (by Pos)\", xaxis=\"Activation Patch\", yaxis=\"Attribution Patch\", include_diag=True, hover=[f\"{label} {tok}\" for label in head_out_labels for tok in graph_tok_labels], color=einops.repeat(torch.arange(model.cfg.n_layers), \"layer -> (layer head pos)\", head=model.cfg.n_heads, pos = 15), color_continuous_scale=\"Portland\")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Factual Knowledge Patching Example\n"," Incomplete, but maybe of interest!\n"," Note that I have better results with the corrupted prompt as having random words rather than Colosseum."]},{"cell_type":"code","execution_count":39,"metadata":{},"outputs":[{"name":"stderr","output_type":"stream","text":["Using pad_token, but it is not set yet.\n"]},{"name":"stdout","output_type":"stream","text":["Loaded pretrained model gpt2-xl into HookedTransformer\n","Tokenized prompt: ['<|endoftext|>', 'The', ' E', 'iff', 'el', ' Tower', ' is', ' located', ' in', ' the', ' city', ' of']\n","Tokenized answer: [' Paris']\n"]},{"data":{"text/html":["
Performance on answer token:\n","Rank: 0        Logit: 20.73 Prob: 95.80% Token: | Paris|\n","
\n"],"text/plain":["Performance on answer token:\n","\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m20.73\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m95.80\u001b[0m\u001b[1m% Token: | Paris|\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Top 0th token. Logit: 20.73 Prob: 95.80% Token: | Paris|\n","Top 1th token. Logit: 16.49 Prob: 1.39% Token: | E|\n","Top 2th token. Logit: 14.69 Prob: 0.23% Token: | the|\n","Top 3th token. Logit: 14.58 Prob: 0.21% Token: | É|\n","Top 4th token. Logit: 14.44 Prob: 0.18% Token: | France|\n","Top 5th token. Logit: 14.36 Prob: 0.16% Token: | Mont|\n","Top 6th token. Logit: 13.77 Prob: 0.09% Token: | Le|\n","Top 7th token. Logit: 13.66 Prob: 0.08% Token: | Ang|\n","Top 8th token. Logit: 13.43 Prob: 0.06% Token: | V|\n","Top 9th token. Logit: 13.42 Prob: 0.06% Token: | Stras|\n"]},{"data":{"text/html":["
Ranks of the answer tokens: [(' Paris', 0)]\n","
\n"],"text/plain":["\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Paris'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Tokenized prompt: ['<|endoftext|>', 'The', ' Col', 'os', 'se', 'um', ' is', ' located', ' in', ' the', ' city', ' of']\n","Tokenized answer: [' Rome']\n"]},{"data":{"text/html":["
Performance on answer token:\n","Rank: 0        Logit: 20.02 Prob: 83.70% Token: | Rome|\n","
\n"],"text/plain":["Performance on answer token:\n","\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m20.02\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m83.70\u001b[0m\u001b[1m% Token: | Rome|\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Top 0th token. Logit: 20.02 Prob: 83.70% Token: | Rome|\n","Top 1th token. Logit: 17.03 Prob: 4.23% Token: | Naples|\n","Top 2th token. Logit: 16.85 Prob: 3.51% Token: | Pompe|\n","Top 3th token. Logit: 16.14 Prob: 1.73% Token: | Ver|\n","Top 4th token. Logit: 15.87 Prob: 1.32% Token: | Florence|\n","Top 5th token. Logit: 14.77 Prob: 0.44% Token: | Roma|\n","Top 6th token. Logit: 14.68 Prob: 0.40% Token: | Milan|\n","Top 7th token. Logit: 14.66 Prob: 0.39% Token: | ancient|\n","Top 8th token. Logit: 14.37 Prob: 0.29% Token: | Pal|\n","Top 9th token. Logit: 14.30 Prob: 0.27% Token: | Constantinople|\n"]},{"data":{"text/html":["
Ranks of the answer tokens: [(' Rome', 0)]\n","
\n"],"text/plain":["\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Rome'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n"]},"metadata":{},"output_type":"display_data"}],"source":["gpt2_xl = HookedTransformer.from_pretrained(\"gpt2-xl\")\n","clean_prompt = \"The Eiffel Tower is located in the city of\"\n","clean_answer = \" Paris\"\n","# corrupted_prompt = \"The red brown fox jumps is located in the city of\"\n","corrupted_prompt = \"The Colosseum is located in the city of\"\n","corrupted_answer = \" Rome\"\n","utils.test_prompt(clean_prompt, clean_answer, gpt2_xl)\n","utils.test_prompt(corrupted_prompt, corrupted_answer, gpt2_xl)"]},{"cell_type":"code","execution_count":40,"metadata":{},"outputs":[],"source":["clean_answer_index = gpt2_xl.to_single_token(clean_answer)\n","corrupted_answer_index = gpt2_xl.to_single_token(corrupted_answer)\n","def factual_logit_diff(logits: TT[\"batch\", \"position\", \"d_vocab\"]):\n"," return logits[0, -1, clean_answer_index] - logits[0, -1, corrupted_answer_index]"]},{"cell_type":"code","execution_count":41,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean logit diff: 10.634519577026367\n","Corrupted logit diff: -8.988396644592285\n","Clean Metric: tensor(1., device='cuda:0', grad_fn=)\n","Corrupted Metric: tensor(0., device='cuda:0', grad_fn=)\n"]}],"source":["clean_logits, clean_cache = gpt2_xl.run_with_cache(clean_prompt)\n","CLEAN_LOGIT_DIFF_FACTUAL = factual_logit_diff(clean_logits).item()\n","corrupted_logits, _ = gpt2_xl.run_with_cache(corrupted_prompt)\n","CORRUPTED_LOGIT_DIFF_FACTUAL = factual_logit_diff(corrupted_logits).item()\n","\n","def factual_metric(logits: TT[\"batch\", \"position\", \"d_vocab\"]):\n"," return (factual_logit_diff(logits) - CORRUPTED_LOGIT_DIFF_FACTUAL) / (CLEAN_LOGIT_DIFF_FACTUAL - CORRUPTED_LOGIT_DIFF_FACTUAL)\n","print(\"Clean logit diff:\", CLEAN_LOGIT_DIFF_FACTUAL)\n","print(\"Corrupted logit diff:\", CORRUPTED_LOGIT_DIFF_FACTUAL)\n","print(\"Clean Metric:\", factual_metric(clean_logits))\n","print(\"Corrupted Metric:\", factual_metric(corrupted_logits))"]},{"cell_type":"code","execution_count":42,"metadata":{},"outputs":[],"source":["# corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(gpt2_xl, corrupted_prompt, factual_metric)"]},{"cell_type":"code","execution_count":43,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean: ['<|endoftext|>', 'The', ' E', 'iff', 'el', ' Tower', ' is', ' located', ' in', ' the', ' city', ' of']\n","Corrupted: ['<|endoftext|>', 'The', ' Col', 'os', 'se', 'um', ' is', ' located', ' in', ' the', ' city', ' of']\n"]}],"source":["clean_tokens = gpt2_xl.to_tokens(clean_prompt)\n","clean_str_tokens = gpt2_xl.to_str_tokens(clean_prompt)\n","corrupted_tokens = gpt2_xl.to_tokens(corrupted_prompt)\n","corrupted_str_tokens = gpt2_xl.to_str_tokens(corrupted_prompt)\n","print(\"Clean:\", clean_str_tokens)\n","print(\"Corrupted:\", corrupted_str_tokens)"]},{"cell_type":"code","execution_count":44,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"b767eef7a3cd49b9b3cb6e5301463f08","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/48 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def act_patch_residual(clean_cache, corrupted_tokens, model: HookedTransformer, metric):\n"," if len(corrupted_tokens.shape)==2:\n"," corrupted_tokens = corrupted_tokens[0]\n"," residual_patches = torch.zeros((model.cfg.n_layers, len(corrupted_tokens)), device=model.cfg.device)\n"," def residual_hook(resid_pre, hook, layer, pos):\n"," resid_pre[:, pos, :] = clean_cache[\"resid_pre\", layer][:, pos, :]\n"," return resid_pre\n"," for layer in tqdm.tqdm(range(model.cfg.n_layers)):\n"," for pos in range(len(corrupted_tokens)):\n"," patched_logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=[(f\"blocks.{layer}.hook_resid_pre\", partial(residual_hook, layer=layer, pos=pos))])\n"," residual_patches[layer, pos] = metric(patched_logits).item()\n"," return residual_patches\n","\n","residual_act_patch = act_patch_residual(clean_cache, corrupted_tokens, gpt2_xl, factual_metric)\n","\n","imshow(residual_act_patch, title=\"Factual Recall Patching (Residual)\", xaxis=\"Position\", yaxis=\"Layer\", x=clean_str_tokens)"]}],"metadata":{"kernelspec":{"display_name":"base","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.7.13"},"orig_nbformat":4,"vscode":{"interpreter":{"hash":"d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"}}},"nbformat":4,"nbformat_minor":2} +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " # Attribution Patching Demo\n", + " **Read [the accompanying blog post here](https://neelnanda.io/attribution-patching) for more context**\n", + " This is an interim research report, giving a whirlwind tour of some unpublished work I did at Anthropic (credit to the then team - Chris Olah, Catherine Olsson, Nelson Elhage and Tristan Hume for help, support, and mentorship!)\n", + "\n", + " The goal of this work is run activation patching at an industrial scale, by using gradient based attribution to approximate the technique - allow an arbitrary number of patches to be made on two forwards and a single backward pass\n", + "\n", + " I have had less time than hoped to flesh out this investigation, but am writing up a rough investigation and comparison to standard activation patching on a few tasks to give a sense of the potential of this approach, and where it works vs falls down." 
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    " To use this notebook, go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.\n",
+    "\n",
+    " **Tips for reading this Colab:**\n",
+    " * You can run all this code for yourself!\n",
+    " * The graphs are interactive!\n",
+    " * Use the table of contents pane in the sidebar to navigate\n",
+    " * Collapse irrelevant sections with the dropdown arrows\n",
+    " * Search the page using the search in the sidebar, not CTRL+F"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    " ## Setup (Ignore)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Running as a Jupyter notebook - intended for development only!\n"
+     ]
+    }
+   ],
+   "source": [
+    "import os\n",
+    "\n",
+    "IN_COLAB = 'google.colab' in str(get_ipython())\n",
+    "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n",
+    "DEBUG_MODE = False\n",
+    "DO_SLOW_RUNS = not IN_GITHUB\n",
+    "EPOCHS_SIZE = 4000 if not IN_GITHUB else 25\n",
+    "\n",
+    "if IN_COLAB or IN_GITHUB:\n",
+    "    \n",
+    "    %pip install transformer_lens\n",
+    "    %pip install torchtyping\n",
+    "    # Install my janky personal plotting utils\n",
+    "    %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n",
+    "    # Install another version of node that makes PySvelte work way faster\n",
+    "    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n",
+    "    %pip install git+https://github.com/neelnanda-io/PySvelte.git\n",
+    "    # Needed for PySvelte to work, v3 came out and broke things...\n",
+    "    %pip install typeguard==2.13.3\n",
+    "    print(\"Running as a Colab or GitHub notebook\")\n",
+    "else:\n",
+    "    print(\"Running as a Jupyter notebook - intended for development only!\")\n",
+    "    from IPython import get_ipython\n",
+    "\n",
+    "    ipython = get_ipython()\n",
+    "    # Code to automatically update the HookedTransformer code as it's edited without restarting the kernel\n",
+    "    ipython.magic(\"load_ext autoreload\")\n",
+    "    ipython.magic(\"autoreload 2\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n",
+    "import plotly.io as pio\n",
+    "\n",
+    "if IN_COLAB or not DEBUG_MODE:\n",
+    "    # Thanks to annoying rendering issues, Plotly graphics will either show up in Colab OR VSCode depending on the renderer - this is bad for developing demos!
Thus creating a debug mode.\n",
+    "    pio.renderers.default = \"colab\"\n",
+    "else:\n",
+    "    pio.renderers.default = \"notebook_connected\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Import stuff\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.nn.functional as F\n",
+    "import torch.optim as optim\n",
+    "import numpy as np\n",
+    "import einops\n",
+    "from fancy_einsum import einsum\n",
+    "import tqdm.notebook as tqdm\n",
+    "import random\n",
+    "from pathlib import Path\n",
+    "import plotly.express as px\n",
+    "from torch.utils.data import DataLoader\n",
+    "\n",
+    "from torchtyping import TensorType as TT\n",
+    "from typing import List, Union, Optional, Callable\n",
+    "from functools import partial\n",
+    "import copy\n",
+    "import itertools\n",
+    "import json\n",
+    "\n",
+    "from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer\n",
+    "import dataclasses\n",
+    "import datasets\n",
+    "from IPython.display import HTML, Markdown"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import pysvelte\n",
+    "\n",
+    "import transformer_lens\n",
+    "import transformer_lens.utils as utils\n",
+    "from transformer_lens.hook_points import (\n",
+    "    HookedRootModule,\n",
+    "    HookPoint,\n",
+    ") # Hooking utilities\n",
+    "from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    " Plotting helper functions from a janky personal library of plotting utils. The library is not documented and I recommend against trying to read it, just use your preferred plotting library if you want to do anything non-obvious:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from neel_plotly import line, imshow, scatter"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import transformer_lens.patching as patching"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    " ## IOI Patching Setup\n",
+    " This just copies the relevant setup from the Exploratory Analysis Demo, and isn't very important."
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using pad_token, but it is not set yet.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded pretrained model gpt2-small into HookedTransformer\n" + ] + } + ], + "source": [ + "model = HookedTransformer.from_pretrained(\"gpt2-small\")\n", + "model.set_use_attn_result(True)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Clean string 0 <|endoftext|>When John and Mary went to the shops, John gave the bag to\n", + "Corrupted string 0 <|endoftext|>When John and Mary went to the shops, Mary gave the bag to\n", + "Answer token indices tensor([[ 5335, 1757],\n", + " [ 1757, 5335],\n", + " [ 4186, 3700],\n", + " [ 3700, 4186],\n", + " [ 6035, 15686],\n", + " [15686, 6035],\n", + " [ 5780, 14235],\n", + " [14235, 5780]])\n" + ] + } + ], + "source": [ + "prompts = ['When John and Mary went to the shops, John gave the bag to', 'When John and Mary went to the shops, Mary gave the bag to', 'When Tom and James went to the park, James gave the ball to', 'When Tom and James went to the park, Tom gave the ball to', 'When Dan and Sid went to the shops, Sid gave an apple to', 'When Dan and Sid went to the shops, Dan gave an apple to', 'After Martin and Amy went to the park, Amy gave a drink to', 'After Martin and Amy went to the park, Martin gave a drink to']\n", + "answers = [(' Mary', ' John'), (' John', ' Mary'), (' Tom', ' James'), (' James', ' Tom'), (' Dan', ' Sid'), (' Sid', ' Dan'), (' Martin', ' Amy'), (' Amy', ' Martin')]\n", + "\n", + "clean_tokens = model.to_tokens(prompts)\n", + "# Swap each adjacent pair, with a hacky list comprehension\n", + "corrupted_tokens = clean_tokens[\n", + " [(i+1 if i%2==0 else i-1) for i in range(len(clean_tokens)) ]\n", + " ]\n", + "print(\"Clean string 0\", model.to_string(clean_tokens[0]))\n", + "print(\"Corrupted string 0\", model.to_string(corrupted_tokens[0]))\n", + "\n", + "answer_token_indices = torch.tensor([[model.to_single_token(answers[i][j]) for j in range(2)] for i in range(len(answers))], device=model.cfg.device)\n", + "print(\"Answer token indices\", answer_token_indices)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Clean logit diff: 3.5519\n", + "Corrupted logit diff: -3.5519\n" + ] + } + ], + "source": [ + "def get_logit_diff(logits, answer_token_indices=answer_token_indices):\n", + " if len(logits.shape)==3:\n", + " # Get final logits only\n", + " logits = logits[:, -1, :]\n", + " correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))\n", + " incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))\n", + " return (correct_logits - incorrect_logits).mean()\n", + "\n", + "clean_logits, clean_cache = model.run_with_cache(clean_tokens)\n", + "corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)\n", + "\n", + "clean_logit_diff = get_logit_diff(clean_logits, answer_token_indices).item()\n", + "print(f\"Clean logit diff: {clean_logit_diff:.4f}\")\n", + "\n", + "corrupted_logit_diff = get_logit_diff(corrupted_logits, answer_token_indices).item()\n", + "print(f\"Corrupted logit diff: {corrupted_logit_diff:.4f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + 
"metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Clean Baseline is 1: 1.0000\n", + "Corrupted Baseline is 0: 0.0000\n" + ] + } + ], + "source": [ + "CLEAN_BASELINE = clean_logit_diff\n", + "CORRUPTED_BASELINE = corrupted_logit_diff\n", + "def ioi_metric(logits, answer_token_indices=answer_token_indices):\n", + " return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (CLEAN_BASELINE - CORRUPTED_BASELINE)\n", + "\n", + "print(f\"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}\")\n", + "print(f\"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " ## Patching\n", + " In the following cells, we define attribution patching and use it in various ways on the model." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "Metric = Callable[[TT[\"batch_and_pos_dims\", \"d_model\"]], float]" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Clean Value: 1.0\n", + "Clean Activations Cached: 220\n", + "Clean Gradients Cached: 220\n", + "Corrupted Value: 0.0\n", + "Corrupted Activations Cached: 220\n", + "Corrupted Gradients Cached: 220\n" + ] + } + ], + "source": [ + "filter_not_qkv_input = lambda name: \"_input\" not in name\n", + "def get_cache_fwd_and_bwd(model, tokens, metric):\n", + " model.reset_hooks()\n", + " cache = {}\n", + " def forward_cache_hook(act, hook):\n", + " cache[hook.name] = act.detach()\n", + " model.add_hook(filter_not_qkv_input, forward_cache_hook, \"fwd\")\n", + "\n", + " grad_cache = {}\n", + " def backward_cache_hook(act, hook):\n", + " grad_cache[hook.name] = act.detach()\n", + " model.add_hook(filter_not_qkv_input, backward_cache_hook, \"bwd\")\n", + "\n", + " value = metric(model(tokens))\n", + " value.backward()\n", + " model.reset_hooks()\n", + " return value.item(), ActivationCache(cache, model), ActivationCache(grad_cache, model)\n", + "\n", + "clean_value, clean_cache, clean_grad_cache = get_cache_fwd_and_bwd(model, clean_tokens, ioi_metric)\n", + "print(\"Clean Value:\", clean_value)\n", + "print(\"Clean Activations Cached:\", len(clean_cache))\n", + "print(\"Clean Gradients Cached:\", len(clean_grad_cache))\n", + "corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(model, corrupted_tokens, ioi_metric)\n", + "print(\"Corrupted Value:\", corrupted_value)\n", + "print(\"Corrupted Activations Cached:\", len(corrupted_cache))\n", + "print(\"Corrupted Gradients Cached:\", len(corrupted_grad_cache))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " ### Attention Attribution\n", + " The easiest thing to start with is to not even engage with the corrupted tokens/patching, but to look at the attribution of the attention patterns - that is, the linear approximation to what happens if you set each element of the attention pattern to zero. This, as it turns out, is a good proxy to what is going on with each head!\n", + " Note that this is *not* the same as what we will later do with patching. In particular, this does not set up a careful counterfactual! It's a good tool for what's generally going on in this problem, but does not control for eg stuff that systematically boosts John > Mary in general, stuff that says \"I should activate the IOI circuit\", etc. 
Though using logit diff as our metric *does* help control for some of this!\n",
+    "    Each element of the batch is independent and the metric is an average logit diff, so we can analyse each batch element independently here. We'll look at the first one, and then at the average across the whole batch (note - 4 prompts have indirect object before subject, 4 prompts have it the other way round, making the average pattern harder to interpret - I plot it over the first sequence of tokens as a mildly misleading reference).\n",
+    "    We can compare it to the Interpretability in the Wild diagram, and basically instantly recover most of the circuit!"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def create_attention_attr(clean_cache, clean_grad_cache) -> TT[\"batch\", \"layer\", \"head_index\", \"dest\", \"src\"]:\n",
+    "    attention_stack = torch.stack([clean_cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0)\n",
+    "    attention_grad_stack = torch.stack([clean_grad_cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0)\n",
+    "    attention_attr = attention_grad_stack * attention_stack\n",
+    "    attention_attr = einops.rearrange(attention_attr, \"layer batch head_index dest src -> batch layer head_index dest src\")\n",
+    "    return attention_attr\n",
+    "\n",
+    "attention_attr = create_attention_attr(clean_cache, clean_grad_cache)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "['L0H0', 'L0H1', 'L0H2', 'L0H3', 'L0H4']\n",
+      "['L0H0+', 'L0H0-', 'L0H1+', 'L0H1-', 'L0H2+']\n",
+      "['L0H0Q', 'L0H0K', 'L0H0V', 'L0H1Q', 'L0H1K']\n"
+     ]
+    }
+   ],
+   "source": [
+    "HEAD_NAMES = [f\"L{l}H{h}\" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)]\n",
+    "HEAD_NAMES_SIGNED = [f\"{name}{sign}\" for name in HEAD_NAMES for sign in [\"+\", \"-\"]]\n",
+    "HEAD_NAMES_QKV = [f\"{name}{act_name}\" for name in HEAD_NAMES for act_name in [\"Q\", \"K\", \"V\"]]\n",
+    "print(HEAD_NAMES[:5])\n",
+    "print(HEAD_NAMES_SIGNED[:5])\n",
+    "print(HEAD_NAMES_QKV[:5])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    " An extremely janky way to plot the attention attribution patterns. We scale them to be in [-1, 1], split each head into a positive and negative part (so all of it is in [0, 1]), and then plot the top 20 head-halves (a head can appear twice!) by the max value of the attribution pattern."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/markdown": [
+       "### Attention Attribution for first sequence"
+      ],
+      "text/plain": [
+       ""
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "\n",
+       "    \n",
+       "\n",
+       "    \n",
+       "
\n", + " \n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/markdown": [ + "### Summed Attention Attribution for all sequences" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + " \n", + "\n", + " \n", + "
\n", + " \n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Note: Plotted over first sequence for reference, but pairs have IO and S1 in different positions.\n" + ] + } + ], + "source": [ + "def plot_attention_attr(attention_attr, tokens, top_k=20, index=0, title=\"\"):\n", + " if len(tokens.shape)==2:\n", + " tokens = tokens[index]\n", + " if len(attention_attr.shape)==5:\n", + " attention_attr = attention_attr[index]\n", + " attention_attr_pos = attention_attr.clamp(min=-1e-5)\n", + " attention_attr_neg = - attention_attr.clamp(max=1e-5)\n", + " attention_attr_signed = torch.stack([attention_attr_pos, attention_attr_neg], dim=0)\n", + " attention_attr_signed = einops.rearrange(attention_attr_signed, \"sign layer head_index dest src -> (layer head_index sign) dest src\")\n", + " attention_attr_signed = attention_attr_signed / attention_attr_signed.max()\n", + " attention_attr_indices = attention_attr_signed.max(-1).values.max(-1).values.argsort(descending=True)\n", + " # print(attention_attr_indices.shape)\n", + " # print(attention_attr_indices)\n", + " attention_attr_signed = attention_attr_signed[attention_attr_indices, :, :]\n", + " head_labels = [HEAD_NAMES_SIGNED[i.item()] for i in attention_attr_indices]\n", + "\n", + " if title: display(Markdown(\"### \"+title))\n", + " if DO_SLOW_RUNS:\n", + " display(pysvelte.AttentionMulti(tokens=model.to_str_tokens(tokens), attention=attention_attr_signed.permute(1, 2, 0)[:, :, :top_k], head_labels=head_labels[:top_k]))\n", + "\n", + "\n", + "plot_attention_attr(attention_attr, clean_tokens, index=0, title=\"Attention Attribution for first sequence\")\n", + "\n", + "plot_attention_attr(attention_attr.sum(0), clean_tokens[0], title=\"Summed Attention Attribution for all sequences\")\n", + "print(\"Note: Plotted over first sequence for reference, but pairs have IO and S1 in different positions.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " ## Attribution Patching\n", + " In the following sections, I will implement various kinds of attribution patching, and then compare them to the activation patching patterns (activation patching code copied from [Exploratory Analysis Demo](https://neelnanda.io/exploratory-analysis-demo))\n", + " ### Residual Stream Patching\n", + "
Note: We add up across both d_model and batch (explanation below).\n",
+    "    We add up along d_model because we're taking the dot product - the derivative *is* the linear map that locally linearly approximates the metric, and so we take the dot product of our change vector with the derivative vector. Equivalently, we look at the effect of changing each coordinate independently, and then combine the effects by adding them up - the map is linear, so this works exactly.\n",
+    "    We add up across batch because we're taking the average of the metric, so each individual batch element provides `1/batch_size` of the overall effect. Because each batch element is independent of the others and no information moves between activations for different inputs, the batched version is equivalent to doing attribution patching separately for each input and then averaging - in this second version the metric per input is *not* divided by batch_size, because we don't average.
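As a quick check on the "just add everything up" claim, here is a small self-contained sketch (an editor's toy example, not a cell from the notebook): summing grad × diff over d_model *and* batch is exactly the directional derivative of a batch-averaged metric along the clean-minus-corrupted direction.

```python
import torch

torch.manual_seed(0)
batch, pos, d_model = 2, 3, 5
corrupted = torch.randn(batch, pos, d_model, requires_grad=True)
clean = corrupted.detach() + 0.1 * torch.randn(batch, pos, d_model)

def toy_metric(resid: torch.Tensor) -> torch.Tensor:
    # A scalar per batch element, then averaged over the batch (like the IOI metric).
    return torch.tanh(resid).sum(dim=(-1, -2)).mean()

toy_metric(corrupted).backward()
diff = clean - corrupted.detach()

# Summing grad * diff over d_model AND batch in one go...
attribution = (corrupted.grad * diff).sum()

# ...matches a finite-difference directional derivative along diff.
eps = 1e-4
numeric = (toy_metric(corrupted.detach() + eps * diff) - toy_metric(corrupted.detach())) / eps
print(f"{attribution.item():.6f} vs {numeric.item():.6f}")
```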
" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def attr_patch_residual(\n", + " clean_cache: ActivationCache, \n", + " corrupted_cache: ActivationCache, \n", + " corrupted_grad_cache: ActivationCache,\n", + " ) -> TT[\"component\", \"pos\"]:\n", + " clean_residual, residual_labels = clean_cache.accumulated_resid(-1, incl_mid=True, return_labels=True)\n", + " corrupted_residual = corrupted_cache.accumulated_resid(-1, incl_mid=True, return_labels=False)\n", + " corrupted_grad_residual = corrupted_grad_cache.accumulated_resid(-1, incl_mid=True, return_labels=False)\n", + " residual_attr = einops.reduce(\n", + " corrupted_grad_residual * (clean_residual - corrupted_residual),\n", + " \"component batch pos d_model -> component pos\",\n", + " \"sum\"\n", + " )\n", + " return residual_attr, residual_labels\n", + "\n", + "residual_attr, residual_labels = attr_patch_residual(clean_cache, corrupted_cache, corrupted_grad_cache)\n", + "imshow(residual_attr, y=residual_labels, yaxis=\"Component\", xaxis=\"Position\", title=\"Residual Attribution Patching\")\n", + "\n", + "# ### Layer Output Patching" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def attr_patch_layer_out(\n", + " clean_cache: ActivationCache, \n", + " corrupted_cache: ActivationCache, \n", + " corrupted_grad_cache: ActivationCache,\n", + " ) -> TT[\"component\", \"pos\"]:\n", + " clean_layer_out, labels = clean_cache.decompose_resid(-1, return_labels=True)\n", + " corrupted_layer_out = corrupted_cache.decompose_resid(-1, return_labels=False)\n", + " corrupted_grad_layer_out = corrupted_grad_cache.decompose_resid(-1, return_labels=False)\n", + " layer_out_attr = einops.reduce(\n", + " corrupted_grad_layer_out * (clean_layer_out - corrupted_layer_out),\n", + " \"component batch pos d_model -> component pos\",\n", + " \"sum\"\n", + " )\n", + " return layer_out_attr, labels\n", + "\n", + "layer_out_attr, layer_out_labels = attr_patch_layer_out(clean_cache, corrupted_cache, corrupted_grad_cache)\n", + "imshow(layer_out_attr, y=layer_out_labels, yaxis=\"Component\", xaxis=\"Position\", title=\"Layer Output Attribution Patching\")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def attr_patch_head_out(\n", + " clean_cache: ActivationCache, \n", + " corrupted_cache: ActivationCache, \n", + " corrupted_grad_cache: ActivationCache,\n", + " ) -> TT[\"component\", \"pos\"]:\n", + " labels = HEAD_NAMES\n", + "\n", + " clean_head_out = clean_cache.stack_head_results(-1, return_labels=False)\n", + " corrupted_head_out = corrupted_cache.stack_head_results(-1, return_labels=False)\n", + " corrupted_grad_head_out = corrupted_grad_cache.stack_head_results(-1, return_labels=False)\n", + " head_out_attr = einops.reduce(\n", + " corrupted_grad_head_out * (clean_head_out - corrupted_head_out),\n", + " \"component batch pos d_model -> component pos\",\n", + " \"sum\"\n", + " )\n", + " return head_out_attr, labels\n", + "\n", + "head_out_attr, head_out_labels = attr_patch_head_out(clean_cache, corrupted_cache, corrupted_grad_cache)\n", + "imshow(head_out_attr, y=head_out_labels, yaxis=\"Component\", xaxis=\"Position\", title=\"Head Output Attribution Patching\")\n", + "sum_head_out_attr = einops.reduce(head_out_attr, \"(layer head) pos -> layer head\", \"sum\", layer=model.cfg.n_layers, head=model.cfg.n_heads)\n", + "imshow(sum_head_out_attr, yaxis=\"Layer\", xaxis=\"Head Index\", title=\"Head Output Attribution Patching Sum Over Pos\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " ### Head Activation Patching\n", + " Intuitively, a head has three inputs, keys, queries and values. We can patch each of these individually to get a sense for where the important part of each head's input comes from!\n", + " As a sanity check, we also do this for the mixed value. The result is a linear map of this (`z @ W_O == result`), so this is the same as patching the output of the head.\n", + " We plot both the patch for each head over each position, and summed over position (it tends to be pretty sparse, so the latter is the same)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/markdown": [ + "#### Key Head Vector Attribution Patching" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/markdown": [ + "#### Query Head Vector Attribution Patching" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/markdown": [ + "#### Value Head Vector Attribution Patching" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/markdown": [ + "#### Mixed Value Head Vector Attribution Patching" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from typing_extensions import Literal\n", + "def stack_head_vector_from_cache(\n", + " cache, \n", + " activation_name: Literal[\"q\", \"k\", \"v\", \"z\"]\n", + " ) -> TT[\"layer_and_head_index\", \"batch\", \"pos\", \"d_head\"]:\n", + " \"\"\"Stacks the head vectors from the cache from a specific activation (key, query, value or mixed_value (z)) into a single tensor.\"\"\"\n", + " stacked_head_vectors = torch.stack([cache[activation_name, l] for l in range(model.cfg.n_layers)], dim=0)\n", + " stacked_head_vectors = einops.rearrange(\n", + " stacked_head_vectors,\n", + " \"layer batch pos head_index d_head -> (layer head_index) batch pos d_head\"\n", + " )\n", + " return stacked_head_vectors\n", + "\n", + "def attr_patch_head_vector(\n", + " clean_cache: ActivationCache, \n", + " corrupted_cache: ActivationCache, \n", + " corrupted_grad_cache: ActivationCache,\n", + " activation_name: Literal[\"q\", \"k\", \"v\", \"z\"],\n", + " ) -> TT[\"component\", \"pos\"]:\n", + " labels = HEAD_NAMES\n", + "\n", + " clean_head_vector = stack_head_vector_from_cache(clean_cache, activation_name)\n", + " corrupted_head_vector = stack_head_vector_from_cache(corrupted_cache, activation_name)\n", + " corrupted_grad_head_vector = stack_head_vector_from_cache(corrupted_grad_cache, activation_name)\n", + " head_vector_attr = einops.reduce(\n", + " corrupted_grad_head_vector * (clean_head_vector - corrupted_head_vector),\n", + " \"component batch pos d_head -> component pos\",\n", + " \"sum\"\n", + " )\n", + " return head_vector_attr, labels\n", + "\n", + "head_vector_attr_dict = {}\n", + "for activation_name, activation_name_full in [(\"k\", \"Key\"), (\"q\", \"Query\"), (\"v\", \"Value\"), (\"z\", \"Mixed Value\")]:\n", + " display(Markdown(f\"#### {activation_name_full} Head Vector Attribution Patching\"))\n", + " head_vector_attr_dict[activation_name], head_vector_labels = attr_patch_head_vector(clean_cache, corrupted_cache, corrupted_grad_cache, activation_name)\n", + " imshow(head_vector_attr_dict[activation_name], y=head_vector_labels, yaxis=\"Component\", xaxis=\"Position\", title=f\"{activation_name_full} Attribution Patching\")\n", + " sum_head_vector_attr = einops.reduce(head_vector_attr_dict[activation_name], \"(layer head) pos -> layer head\", \"sum\", layer=model.cfg.n_layers, head=model.cfg.n_heads)\n", + " imshow(sum_head_vector_attr, yaxis=\"Layer\", xaxis=\"Head Index\", title=f\"{activation_name_full} Attribution Patching Sum Over Pos\")" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/markdown": [ + "### Head Pattern Attribution Patching" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + " \n", + "\n", + " \n", + "
\n", + " \n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from typing_extensions import Literal\n", + "def stack_head_pattern_from_cache(\n", + " cache, \n", + " ) -> TT[\"layer_and_head_index\", \"batch\", \"dest_pos\", \"src_pos\"]:\n", + " \"\"\"Stacks the head patterns from the cache into a single tensor.\"\"\"\n", + " stacked_head_pattern = torch.stack([cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0)\n", + " stacked_head_pattern = einops.rearrange(\n", + " stacked_head_pattern,\n", + " \"layer batch head_index dest_pos src_pos -> (layer head_index) batch dest_pos src_pos\"\n", + " )\n", + " return stacked_head_pattern\n", + "\n", + "def attr_patch_head_pattern(\n", + " clean_cache: ActivationCache, \n", + " corrupted_cache: ActivationCache, \n", + " corrupted_grad_cache: ActivationCache,\n", + " ) -> TT[\"component\", \"dest_pos\", \"src_pos\"]:\n", + " labels = HEAD_NAMES\n", + "\n", + " clean_head_pattern = stack_head_pattern_from_cache(clean_cache)\n", + " corrupted_head_pattern = stack_head_pattern_from_cache(corrupted_cache)\n", + " corrupted_grad_head_pattern = stack_head_pattern_from_cache(corrupted_grad_cache)\n", + " head_pattern_attr = einops.reduce(\n", + " corrupted_grad_head_pattern * (clean_head_pattern - corrupted_head_pattern),\n", + " \"component batch dest_pos src_pos -> component dest_pos src_pos\",\n", + " \"sum\"\n", + " )\n", + " return head_pattern_attr, labels\n", + "\n", + "head_pattern_attr, labels = attr_patch_head_pattern(clean_cache, corrupted_cache, corrupted_grad_cache)\n", + "\n", + "plot_attention_attr(einops.rearrange(head_pattern_attr, \"(layer head) dest src -> layer head dest src\", layer=model.cfg.n_layers, head=model.cfg.n_heads), clean_tokens, index=0, title=\"Head Pattern Attribution Patching\")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def get_head_vector_grad_input_from_grad_cache(\n", + " grad_cache: ActivationCache, \n", + " activation_name: Literal[\"q\", \"k\", \"v\"],\n", + " layer: int\n", + " ) -> TT[\"batch\", \"pos\", \"head_index\", \"d_model\"]:\n", + " vector_grad = grad_cache[activation_name, layer]\n", + " ln_scales = grad_cache[\"scale\", layer, \"ln1\"]\n", + " attn_layer_object = model.blocks[layer].attn\n", + " if activation_name == \"q\":\n", + " W = attn_layer_object.W_Q\n", + " elif activation_name == \"k\":\n", + " W = attn_layer_object.W_K\n", + " elif activation_name == \"v\":\n", + " W = attn_layer_object.W_V\n", + " else:\n", + " raise ValueError(\"Invalid activation name\")\n", + "\n", + " return einsum(\"batch pos head_index d_head, batch pos, head_index d_model d_head -> batch pos head_index d_model\", vector_grad, ln_scales.squeeze(-1), W)\n", + "\n", + "def get_stacked_head_vector_grad_input(grad_cache, activation_name: Literal[\"q\", \"k\", \"v\"]) -> TT[\"layer\", \"batch\", \"pos\", \"head_index\", \"d_model\"]:\n", + " return torch.stack([get_head_vector_grad_input_from_grad_cache(grad_cache, activation_name, l) for l in range(model.cfg.n_layers)], dim=0)\n", + "\n", + "def get_full_vector_grad_input(grad_cache) -> TT[\"qkv\", \"layer\", \"batch\", \"pos\", \"head_index\", \"d_model\"]:\n", + " return torch.stack([get_stacked_head_vector_grad_input(grad_cache, activation_name) for activation_name in ['q', 'k', 'v']], dim=0)\n", + "\n", + "def attr_patch_head_path(\n", + " clean_cache: ActivationCache, \n", + " corrupted_cache: ActivationCache, \n", + " corrupted_grad_cache: ActivationCache\n", + " ) -> TT[\"qkv\", \"dest_component\", \"src_component\", \"pos\"]:\n", + " \"\"\"\n", + " Computes the attribution patch along the path between each pair of heads.\n", + "\n", + " Sets this to zero for the path from any late head to any early head\n", + "\n", + " \"\"\"\n", + " start_labels = HEAD_NAMES\n", + " end_labels = HEAD_NAMES_QKV\n", + " full_vector_grad_input = get_full_vector_grad_input(corrupted_grad_cache)\n", + " clean_head_result_stack = clean_cache.stack_head_results(-1)\n", + " corrupted_head_result_stack = corrupted_cache.stack_head_results(-1)\n", + " diff_head_result = einops.rearrange(\n", + " clean_head_result_stack - corrupted_head_result_stack,\n", + " \"(layer head_index) batch pos d_model -> layer batch pos head_index d_model\",\n", + " layer = model.cfg.n_layers,\n", + " head_index = model.cfg.n_heads,\n", + " )\n", + " path_attr = einsum(\n", + " \"qkv layer_end batch pos head_end d_model, layer_start batch pos head_start d_model -> qkv layer_end head_end layer_start head_start pos\", \n", + " full_vector_grad_input, \n", + " diff_head_result)\n", + " correct_layer_order_mask = (\n", + " torch.arange(model.cfg.n_layers)[None, :, None, None, None, None] > \n", + " torch.arange(model.cfg.n_layers)[None, None, None, :, None, None]).to(path_attr.device)\n", + " zero = torch.zeros(1, device=path_attr.device)\n", + " path_attr = torch.where(correct_layer_order_mask, path_attr, zero)\n", + "\n", + " path_attr = einops.rearrange(\n", + " path_attr,\n", + " \"qkv layer_end head_end layer_start head_start pos -> (layer_end head_end qkv) (layer_start head_start) pos\",\n", + " )\n", + " return path_attr, end_labels, start_labels\n", + "\n", + "head_path_attr, end_labels, start_labels = attr_patch_head_path(clean_cache, corrupted_cache, corrupted_grad_cache)\n", + 
"imshow(head_path_attr.sum(-1), y=end_labels, yaxis=\"Path End (Head Input)\", x=start_labels, xaxis=\"Path Start (Head Output)\", title=\"Head Path Attribution Patching\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " This is hard to parse. Here's an experiment with filtering for the most important heads and showing their paths." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "head_out_values, head_out_indices = head_out_attr.sum(-1).abs().sort(descending=True)\n", + "line(head_out_values)\n", + "top_head_indices = head_out_indices[:22].sort().values\n", + "top_end_indices = []\n", + "top_end_labels = []\n", + "top_start_indices = []\n", + "top_start_labels = []\n", + "for i in top_head_indices:\n", + " i = i.item()\n", + " top_start_indices.append(i)\n", + " top_start_labels.append(start_labels[i])\n", + " for j in range(3):\n", + " top_end_indices.append(3*i+j)\n", + " top_end_labels.append(end_labels[3*i+j])\n", + "\n", + "imshow(head_path_attr[top_end_indices, :][:, top_start_indices].sum(-1), y=top_end_labels, yaxis=\"Path End (Head Input)\", x=top_start_labels, xaxis=\"Path Start (Head Output)\", title=\"Head Path Attribution Patching (Filtered for Top Heads)\")" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "for j, composition_type in enumerate([\"Query\", \"Key\", \"Value\"]):\n", + " imshow(head_path_attr[top_end_indices, :][:, top_start_indices][j::3].sum(-1), y=top_end_labels[j::3], yaxis=\"Path End (Head Input)\", x=top_start_labels, xaxis=\"Path Start (Head Output)\", title=f\"Head Path to {composition_type} Attribution Patching (Filtered for Top Heads)\")" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "top_head_path_attr = einops.rearrange(head_path_attr[top_end_indices, :][:, top_start_indices].sum(-1), \"(head_end qkv) head_start -> qkv head_end head_start\", qkv=3)\n", + "imshow(top_head_path_attr, y=[i[:-1] for i in top_end_labels[::3]], yaxis=\"Path End (Head Input)\", x=top_start_labels, xaxis=\"Path Start (Head Output)\", title=f\"Head Path Attribution Patching (Filtered for Top Heads)\", facet_col=0, facet_labels=[\"Query\", \"Key\", \"Value\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " Let's now dive into 3 interesting heads: L5H5 (induction head), L8H6 (S-Inhibition Head), L9H9 (Name Mover) and look at their input and output paths (note - Q input means )" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "interesting_heads = [5 * model.cfg.n_heads + 5, 8 * model.cfg.n_heads + 6, 9 * model.cfg.n_heads + 9]\n", + "interesting_head_labels = [HEAD_NAMES[i] for i in interesting_heads]\n", + "for head_index, label in zip(interesting_heads, interesting_head_labels):\n", + " in_paths = head_path_attr[3*head_index:3*head_index+3].sum(-1)\n", + " out_paths = head_path_attr[:, head_index].sum(-1)\n", + " out_paths = einops.rearrange(out_paths, \"(layer_head qkv) -> qkv layer_head\", qkv=3)\n", + " all_paths = torch.cat([in_paths, out_paths], dim=0)\n", + " all_paths = einops.rearrange(all_paths, \"path_type (layer head) -> path_type layer head\", layer=model.cfg.n_layers, head=model.cfg.n_heads)\n", + " imshow(all_paths, facet_col=0, facet_labels=[\"Query (In)\", \"Key (In)\", \"Value (In)\", \"Query (Out)\", \"Key (Out)\", \"Value (Out)\"], title=f\"Input and Output Paths for head {label}\", yaxis=\"Layer\", xaxis=\"Head\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " ## Validating Attribution vs Activation Patching\n", + " Let's now compare attribution and activation patching. Generally it's a decent approximation! The main place it fails is MLP0 and the residual stream\n", + " My fuzzy intuition is that attribution patching works badly for \"big\" things which are poorly modelled as linear approximations, and works well for \"small\" things which are more like incremental changes. Anything involving replacing the embedding is a \"big\" thing, which includes residual streams, and in GPT-2 small MLP0 seems to be used as an \"extended embedding\" (where later layers use MLP0's output instead of the token embedding), so I also count it as big.\n", + " See more discussion in the accompanying blog post!\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " First do some refactoring to make attribution patching more generic. We make an attribution cache, which is an ActivationCache where each element is (clean_act - corrupted_act) * corrupted_grad, so that it's the per-element attribution for each activation. Thanks to linearity, we just compute things by adding stuff up along the relevant dimensions!" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "attribution_cache_dict = {}\n", + "for key in corrupted_grad_cache.cache_dict.keys():\n", + " attribution_cache_dict[key] = corrupted_grad_cache.cache_dict[key] * (clean_cache.cache_dict[key] - corrupted_cache.cache_dict[key])\n", + "attr_cache = ActivationCache(attribution_cache_dict, model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " By block: For each head we patch the starting residual stream, attention output + MLP output" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "str_tokens = model.to_str_tokens(clean_tokens[0])\n", + "context_length = len(str_tokens)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c1f3b2c7d0674e45a0fce79d50fc0734", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/180 [00:00\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "every_block_act_patch_result = patching.get_act_patch_block_every(model, corrupted_tokens, clean_cache, ioi_metric)\n", + "imshow(every_block_act_patch_result, facet_col=0, facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"], title=\"Activation Patching Per Block\", xaxis=\"Position\", yaxis=\"Layer\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))])" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def get_attr_patch_block_every(attr_cache):\n", + " resid_pre_attr = einops.reduce(\n", + " attr_cache.stack_activation(\"resid_pre\"),\n", + " \"layer batch pos d_model -> layer pos\",\n", + " \"sum\",\n", + " )\n", + " attn_out_attr = einops.reduce(\n", + " attr_cache.stack_activation(\"attn_out\"),\n", + " \"layer batch pos d_model -> layer pos\",\n", + " \"sum\",\n", + " )\n", + " mlp_out_attr = einops.reduce(\n", + " attr_cache.stack_activation(\"mlp_out\"),\n", + " \"layer batch pos d_model -> layer pos\",\n", + " \"sum\",\n", + " )\n", + "\n", + " every_block_attr_patch_result = torch.stack([resid_pre_attr, attn_out_attr, mlp_out_attr], dim=0)\n", + " return every_block_attr_patch_result\n", + "every_block_attr_patch_result = get_attr_patch_block_every(attr_cache)\n", + "imshow(every_block_attr_patch_result, facet_col=0, facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"], title=\"Attribution Patching Per Block\", xaxis=\"Position\", yaxis=\"Layer\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))])" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "if DO_SLOW_RUNS:\n", + " scatter(y=every_block_attr_patch_result.reshape(3, -1), x=every_block_act_patch_result.reshape(3, -1), facet_col=0, facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"], title=\"Attribution vs Activation Patching Per Block\", xaxis=\"Activation Patch\", yaxis=\"Attribution Patch\", hover=[f\"Layer {l}, Position {p}, |{str_tokens[p]}|\" for l in range(model.cfg.n_layers) for p in range(context_length)], color=einops.repeat(torch.arange(model.cfg.n_layers), \"layer -> (layer pos)\", pos=context_length), color_continuous_scale=\"Portland\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " By head: For each head we patch the output, query, key, value or pattern. We do all positions at once so it's not super slow." + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "810caa7c89cc472292d66698dda6ced9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/144 [00:00\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(model, corrupted_tokens, clean_cache, ioi_metric)\n", + "if DO_SLOW_RUNS:\n", + " imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=1, zmin=-1)" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def get_attr_patch_attn_head_all_pos_every(attr_cache):\n", + " head_out_all_pos_attr = einops.reduce(\n", + " attr_cache.stack_activation(\"z\"),\n", + " \"layer batch pos head_index d_head -> layer head_index\",\n", + " \"sum\",\n", + " )\n", + " head_q_all_pos_attr = einops.reduce(\n", + " attr_cache.stack_activation(\"q\"),\n", + " \"layer batch pos head_index d_head -> layer head_index\",\n", + " \"sum\",\n", + " )\n", + " head_k_all_pos_attr = einops.reduce(\n", + " attr_cache.stack_activation(\"k\"),\n", + " \"layer batch pos head_index d_head -> layer head_index\",\n", + " \"sum\",\n", + " )\n", + " head_v_all_pos_attr = einops.reduce(\n", + " attr_cache.stack_activation(\"v\"),\n", + " \"layer batch pos head_index d_head -> layer head_index\",\n", + " \"sum\",\n", + " )\n", + " head_pattern_all_pos_attr = einops.reduce(\n", + " attr_cache.stack_activation(\"pattern\"),\n", + " \"layer batch head_index dest_pos src_pos -> layer head_index\",\n", + " \"sum\",\n", + " )\n", + "\n", + " return torch.stack([head_out_all_pos_attr, head_q_all_pos_attr, head_k_all_pos_attr, head_v_all_pos_attr, head_pattern_all_pos_attr])\n", + " \n", + "every_head_all_pos_attr_patch_result = get_attr_patch_attn_head_all_pos_every(attr_cache)\n", + "imshow(every_head_all_pos_attr_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Attribution Patching Per Head (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=1, zmin=-1)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "if DO_SLOW_RUNS:\n", + " scatter(y=every_head_all_pos_attr_patch_result.reshape(5, -1), x=every_head_all_pos_act_patch_result.reshape(5, -1), facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Attribution vs Activation Patching Per Head (All Pos)\", xaxis=\"Activation Patch\", yaxis=\"Attribution Patch\", include_diag=True, hover=head_out_labels, color=einops.repeat(torch.arange(model.cfg.n_layers), \"layer -> (layer head)\", head=model.cfg.n_heads), color_continuous_scale=\"Portland\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " We see pretty good results in general, but significant errors for heads L5H5 on query and moderate errors for head L10H7 on query and key, and moderate errors for head L11H10 on key. But each of these is fine for pattern and output. My guess is that the problem is that these have pretty saturated attention on a single token, and the linear approximation is thus not great on the attention calculation here, but I'm not sure. When we plot the attention patterns, we do see this!\n", + " Note that the axis labels are for the *first* prompt's tokens, but each facet is a different prompt, so this is somewhat inaccurate. In particular, every odd facet has indirect object and subject in the opposite order (IO first). But otherwise everything lines up between the prompts" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "graph_tok_labels = [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))]\n", + "imshow(clean_cache[\"pattern\", 5][:, 5], x= graph_tok_labels, y=graph_tok_labels, facet_col=0, title=\"Attention for Head L5H5\", facet_name=\"Prompt\")\n", + "imshow(clean_cache[\"pattern\", 10][:, 7], x= graph_tok_labels, y=graph_tok_labels, facet_col=0, title=\"Attention for Head L10H7\", facet_name=\"Prompt\")\n", + "imshow(clean_cache[\"pattern\", 11][:, 10], x= graph_tok_labels, y=graph_tok_labels, facet_col=0, title=\"Attention for Head L11H10\", facet_name=\"Prompt\")\n", + "\n", + "\n", + "# [markdown]" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'torch' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_19371/3522707877.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mIN_GITHUB\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m;\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mevery_head_by_pos_act_patch_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpatching\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_act_patch_attn_head_by_pos_every\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcorrupted_tokens\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclean_cache\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mioi_metric\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mIN_GITHUB\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLongTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcorrupted_tokens\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mevery_head_by_pos_act_patch_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0meinops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrearrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mevery_head_by_pos_act_patch_result\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"act_type layer pos head -> act_type (layer head) pos\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mimshow\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mevery_head_by_pos_act_patch_result\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfacet_col\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfacet_labels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"Output\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"Query\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"Key\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"Value\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"Pattern\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtitle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"Activation Patching Per Head (By Pos)\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxaxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"Position\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myaxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"Layer & Head\"\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mzmax\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mzmin\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34mf\"{tok}_{i}\"\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtok\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_str_tokens\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclean_tokens\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mhead_out_labels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name 'torch' is not defined" + ] + } + ], + "source": [ + "every_head_by_pos_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(model, corrupted_tokens, clean_cache, ioi_metric) if not IN_GITHUB else torch.unsqueeze(torch.unsqueeze(torch.LongTensor(corrupted_tokens), 2), 0)\n", + "every_head_by_pos_act_patch_result = einops.rearrange(every_head_by_pos_act_patch_result, \"act_type layer pos head -> act_type (layer head) pos\")\n", + "if not IN_GITHUB:\n", + " imshow(every_head_by_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (By Pos)\", xaxis=\"Position\", yaxis=\"Layer & Head\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))], y=head_out_labels)" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def get_attr_patch_attn_head_by_pos_every(attr_cache):\n", + " head_out_by_pos_attr = einops.reduce(\n", + " attr_cache.stack_activation(\"z\"),\n", + " \"layer batch pos head_index d_head -> layer pos head_index\",\n", + " \"sum\",\n", + " )\n", + " head_q_by_pos_attr = einops.reduce(\n", + " attr_cache.stack_activation(\"q\"),\n", + " \"layer batch pos head_index d_head -> layer pos head_index\",\n", + " \"sum\",\n", + " )\n", + " head_k_by_pos_attr = einops.reduce(\n", + " attr_cache.stack_activation(\"k\"),\n", + " \"layer batch pos head_index d_head -> layer pos head_index\",\n", + " \"sum\",\n", + " )\n", + " head_v_by_pos_attr = einops.reduce(\n", + " attr_cache.stack_activation(\"v\"),\n", + " \"layer batch pos head_index d_head -> layer pos head_index\",\n", + " \"sum\",\n", + " )\n", + " head_pattern_by_pos_attr = einops.reduce(\n", + " attr_cache.stack_activation(\"pattern\"),\n", + " \"layer batch head_index dest_pos src_pos -> layer dest_pos head_index\",\n", + " \"sum\",\n", + " )\n", + "\n", + " return torch.stack([head_out_by_pos_attr, head_q_by_pos_attr, head_k_by_pos_attr, head_v_by_pos_attr, head_pattern_by_pos_attr])\n", + "every_head_by_pos_attr_patch_result = get_attr_patch_attn_head_by_pos_every(attr_cache)\n", + "every_head_by_pos_attr_patch_result = einops.rearrange(every_head_by_pos_attr_patch_result, \"act_type layer pos head -> act_type (layer head) pos\")\n", + "if DO_SLOW_RUNS:\n", + " imshow(every_head_by_pos_attr_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Attribution Patching Per Head (By Pos)\", xaxis=\"Position\", yaxis=\"Layer & Head\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))], y=head_out_labels)" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "if DO_SLOW_RUNS:\n", + " scatter(y=every_head_by_pos_attr_patch_result.reshape(5, -1), x=every_head_by_pos_act_patch_result.reshape(5, -1), facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Attribution vs Activation Patching Per Head (by Pos)\", xaxis=\"Activation Patch\", yaxis=\"Attribution Patch\", include_diag=True, hover=[f\"{label} {tok}\" for label in head_out_labels for tok in graph_tok_labels], color=einops.repeat(torch.arange(model.cfg.n_layers), \"layer -> (layer head pos)\", head=model.cfg.n_heads, pos = 15), color_continuous_scale=\"Portland\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " ## Factual Knowledge Patching Example\n", + " Incomplete, but maybe of interest!\n", + " Note that I have better results with the corrupted prompt as having random words rather than Colosseum." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "ename": "", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31mCanceled future for execute_request message before replies were done" + ] + }, + { + "ename": "", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click here for more info. View Jupyter log for further details." + ] + } + ], + "source": [ + "\n", + "gpt2_xl = HookedTransformer.from_pretrained(\"gpt2-xl\" if not IN_GITHUB else \"gpt2-small\", device=device)\n", + "clean_prompt = \"The Eiffel Tower is located in the city of\"\n", + "clean_answer = \" Paris\"\n", + "# corrupted_prompt = \"The red brown fox jumps is located in the city of\"\n", + "corrupted_prompt = \"The Colosseum is located in the city of\"\n", + "corrupted_answer = \" Rome\"\n", + "utils.test_prompt(clean_prompt, clean_answer, gpt2_xl)\n", + "utils.test_prompt(corrupted_prompt, corrupted_answer, gpt2_xl)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "clean_answer_index = gpt2_xl.to_single_token(clean_answer)\n", + "corrupted_answer_index = gpt2_xl.to_single_token(corrupted_answer)\n", + "def factual_logit_diff(logits: TT[\"batch\", \"position\", \"d_vocab\"]):\n", + " return logits[0, -1, clean_answer_index] - logits[0, -1, corrupted_answer_index]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def factual_metric(logits: TT[\"batch\", \"position\", \"d_vocab\"]):\n", + " return (factual_logit_diff(logits) - CORRUPTED_LOGIT_DIFF_FACTUAL) / (CLEAN_LOGIT_DIFF_FACTUAL - CORRUPTED_LOGIT_DIFF_FACTUAL)\n", + "\n", + "clean_logits, clean_cache = gpt2_xl.run_with_cache(clean_prompt)\n", + "CLEAN_LOGIT_DIFF_FACTUAL = factual_logit_diff(clean_logits).item()\n", + "corrupted_logits, _ = gpt2_xl.run_with_cache(corrupted_prompt)\n", + "CORRUPTED_LOGIT_DIFF_FACTUAL = factual_logit_diff(corrupted_logits).item()\n", + "\n", + "print(\"Clean logit diff:\", CLEAN_LOGIT_DIFF_FACTUAL)\n", + "print(\"Corrupted logit diff:\", CORRUPTED_LOGIT_DIFF_FACTUAL)\n", + "print(\"Clean Metric:\", factual_metric(clean_logits))\n", + "print(\"Corrupted Metric:\", factual_metric(corrupted_logits))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + 
"outputs": [], + "source": [ + "# corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(gpt2_xl, corrupted_prompt, factual_metric)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "clean_tokens = gpt2_xl.to_tokens(clean_prompt)\n", + "clean_str_tokens = gpt2_xl.to_str_tokens(clean_prompt)\n", + "corrupted_tokens = gpt2_xl.to_tokens(corrupted_prompt)\n", + "corrupted_str_tokens = gpt2_xl.to_str_tokens(corrupted_prompt)\n", + "print(\"Clean:\", clean_str_tokens)\n", + "print(\"Corrupted:\", corrupted_str_tokens)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def act_patch_residual(clean_cache, corrupted_tokens, model: HookedTransformer, metric):\n", + " if len(corrupted_tokens.shape)==2:\n", + " corrupted_tokens = corrupted_tokens[0]\n", + " residual_patches = torch.zeros((model.cfg.n_layers, len(corrupted_tokens)), device=model.cfg.device)\n", + " def residual_hook(resid_pre, hook, layer, pos):\n", + " resid_pre[:, pos, :] = clean_cache[\"resid_pre\", layer][:, pos, :]\n", + " return resid_pre\n", + " for layer in tqdm.tqdm(range(model.cfg.n_layers)):\n", + " for pos in range(len(corrupted_tokens)):\n", + " patched_logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=[(f\"blocks.{layer}.hook_resid_pre\", partial(residual_hook, layer=layer, pos=pos))])\n", + " residual_patches[layer, pos] = metric(patched_logits).item()\n", + " return residual_patches\n", + "\n", + "\n", + "if DO_SLOW_RUNS:\n", + " residual_act_patch = act_patch_residual(clean_cache, corrupted_tokens, gpt2_xl, factual_metric)\n", + " imshow(residual_act_patch, title=\"Factual Recall Patching (Residual)\", xaxis=\"Position\", yaxis=\"Layer\", x=clean_str_tokens)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/demos/Exploratory_Analysis_Demo.ipynb b/demos/Exploratory_Analysis_Demo.ipynb index db4095e6f..239696902 100644 --- a/demos/Exploratory_Analysis_Demo.ipynb +++ b/demos/Exploratory_Analysis_Demo.ipynb @@ -75,18 +75,20 @@ } ], "source": [ - "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", + "import os\n", + "\n", + "IN_COLAB = 'google.colab' in str(get_ipython())\n", + "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", "DEBUG_MODE = False\n", - "try:\n", - " import google.colab\n", - " IN_COLAB = True\n", + "DO_SLOW_RUNS = not IN_GITHUB\n", + "\n", + "if IN_COLAB or IN_GITHUB:\n", " print(\"Running as a Colab notebook\")\n", " %pip install git+https://github.com/neelnanda-io/TransformerLens.git\n", " # Install another version of node that makes PySvelte work way faster\n", " !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n", " %pip install git+https://github.com/neelnanda-io/PySvelte.git\n", - "except:\n", - " IN_COLAB = False\n", + "else:\n", " print(\"Running as a Jupyter notebook - intended for development only!\")\n", " from IPython import 
get_ipython\n", "\n", @@ -188,6 +190,17 @@ "torch.set_grad_enabled(False)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# This variable needs to be used instead of allowing the default\n", + "# Any calls to .cuda() need to be .to(device) to allow for your notebook to be compatible with github CI\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -336,6 +349,7 @@ " center_writing_weights=True,\n", " fold_ln=True,\n", " refactor_factored_attn_matrices=True,\n", + " device=device\n", ")" ] }, @@ -478,7 +492,8 @@ " )\n", " # Insert the *incorrect* answer to the prompt, making the correct answer the indirect object.\n", " prompts.append(prompt_format[i].format(answers[-1][1]))\n", - "answer_tokens = torch.tensor(answer_tokens).cuda()\n", + "answer_tokens = answer_tokens.cuda() if not IN_GITHUB else torch.LongTensor(answer_tokens)\n", + "\n", "print(prompts)\n", "print(answers)" ] @@ -518,7 +533,7 @@ "source": [ "tokens = model.to_tokens(prompts, prepend_bos=True)\n", "# Move the tokens to the GPU\n", - "tokens = tokens.cuda()\n", + "tokens = tokens.cuda() if not IN_GITHUB else torch.LongTensor(tokens)\n", "# Run the model and cache all activations\n", "original_logits, cache = model.run_with_cache(tokens)" ] @@ -1004,6 +1019,10 @@ " local_cache: Optional[ActivationCache]=None, \n", " local_tokens: Optional[torch.Tensor]=None, \n", " title: str=\"\"):\n", + " \n", + " if IN_GITHUB:\n", + " return\n", + " \n", " # Heads are given as a list of integers or a single integer in [0, n_layers * n_heads)\n", " if isinstance(heads, int):\n", " heads = [heads]\n", @@ -1368,7 +1387,7 @@ " # 0 means zero change, negative means actively made worse, 1 means totally recovered clean performance, >1 means actively *improved* on clean performance\n", " return (patched_logit_diff - corrupted_average_logit_diff)/(original_average_logit_diff - corrupted_average_logit_diff)\n", "\n", - "patched_residual_stream_diff = torch.zeros(model.cfg.n_layers, tokens.shape[1], device=\"cuda\", dtype=torch.float32)\n", + "patched_residual_stream_diff = torch.zeros(model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32)\n", "for layer in range(model.cfg.n_layers):\n", " for position in range(tokens.shape[1]):\n", " hook_fn = partial(patch_residual_component, pos=position, clean_cache=cache)\n", @@ -1465,8 +1484,8 @@ "metadata": {}, "outputs": [], "source": [ - "patched_attn_diff = torch.zeros(model.cfg.n_layers, tokens.shape[1], device=\"cuda\", dtype=torch.float32)\n", - "patched_mlp_diff = torch.zeros(model.cfg.n_layers, tokens.shape[1], device=\"cuda\", dtype=torch.float32)\n", + "patched_attn_diff = torch.zeros(model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32)\n", + "patched_mlp_diff = torch.zeros(model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32)\n", "for layer in range(model.cfg.n_layers):\n", " for position in range(tokens.shape[1]):\n", " hook_fn = partial(patch_residual_component, pos=position, clean_cache=cache)\n", @@ -1639,7 +1658,7 @@ " return corrupted_head_vector\n", "\n", "\n", - "patched_head_z_diff = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, device=\"cuda\", dtype=torch.float32)\n", + "patched_head_z_diff = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32)\n", "for layer in range(model.cfg.n_layers):\n", " for head_index in range(model.cfg.n_heads):\n", " hook_fn = 
partial(patch_head_vector, head_index=head_index, clean_cache=cache)\n", @@ -1737,7 +1756,7 @@ "metadata": {}, "outputs": [], "source": [ - "patched_head_v_diff = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, device=\"cuda\", dtype=torch.float32)\n", + "patched_head_v_diff = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32)\n", "for layer in range(model.cfg.n_layers):\n", " for head_index in range(model.cfg.n_heads):\n", " hook_fn = partial(patch_head_vector, head_index=head_index, clean_cache=cache)\n", @@ -1899,7 +1918,7 @@ " corrupted_head_pattern[:, head_index, :, :] = clean_cache[hook.name][:, head_index, :, :]\n", " return corrupted_head_pattern\n", "\n", - "patched_head_attn_diff = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, device=\"cuda\", dtype=torch.float32)\n", + "patched_head_attn_diff = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32)\n", "for layer in range(model.cfg.n_layers):\n", " for head_index in range(model.cfg.n_heads):\n", " hook_fn = partial(patch_head_pattern, head_index=head_index, clean_cache=cache)\n", @@ -2491,25 +2510,25 @@ "seq_len = 100\n", "batch_size = 2\n", "\n", - "prev_token_scores = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=\"cuda\")\n", + "prev_token_scores = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=device)\n", "def prev_token_hook(pattern, hook):\n", " layer = hook.layer()\n", " diagonal = pattern.diagonal(offset=1, dim1=-1, dim2=-2)\n", " # print(diagonal)\n", " # print(pattern)\n", " prev_token_scores[layer] = einops.reduce(diagonal, \"batch head_index diagonal -> head_index\", \"mean\")\n", - "duplicate_token_scores = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=\"cuda\")\n", + "duplicate_token_scores = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=device)\n", "def duplicate_token_hook(pattern, hook):\n", " layer = hook.layer()\n", " diagonal = pattern.diagonal(offset=seq_len, dim1=-1, dim2=-2)\n", " duplicate_token_scores[layer] = einops.reduce(diagonal, \"batch head_index diagonal -> head_index\", \"mean\")\n", - "induction_scores = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=\"cuda\")\n", + "induction_scores = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=device)\n", "def induction_hook(pattern, hook):\n", " layer = hook.layer()\n", " diagonal = pattern.diagonal(offset=seq_len-1, dim1=-1, dim2=-2)\n", " induction_scores[layer] = einops.reduce(diagonal, \"batch head_index diagonal -> head_index\", \"mean\")\n", "original_tokens = torch.randint(100, 20000, size=(batch_size, seq_len))\n", - "repeated_tokens = einops.repeat(original_tokens, \"batch seq_len -> batch (2 seq_len)\").cuda()\n", + "repeated_tokens = einops.repeat(original_tokens, \"batch seq_len -> batch (2 seq_len)\").to(device)\n", "\n", "pattern_filter = lambda act_name: act_name.endswith(\"hook_attn\")\n", "loss = model.run_with_hooks(repeated_tokens, return_type=\"loss\", fwd_hooks=[(pattern_filter, prev_token_hook), (pattern_filter, duplicate_token_hook), (pattern_filter, induction_hook)])\n", @@ -2912,7 +2931,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.14" + "version": "3.8.10" }, "vscode": { "interpreter": { diff --git a/demos/Grokking_Demo.ipynb b/demos/Grokking_Demo.ipynb index 273de17e0..ebefc1f98 100644 --- a/demos/Grokking_Demo.ipynb +++ b/demos/Grokking_Demo.ipynb @@ -30,23 +30,169 @@ }, { "cell_type": "code", - "execution_count": 2, + 
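Stepping back from the hunks above: the recurring change in this diff replaces hard-coded `.cuda()` calls and `device="cuda"` arguments with a single `device` variable, so the notebooks also run on the CPU-only GitHub runners. A minimal sketch of the substitution:

```python
import torch

# Pick the device once; everything else refers to it.
device = "cuda" if torch.cuda.is_available() else "cpu"

tokens = torch.randint(100, 20000, size=(2, 100))
tokens = tokens.to(device)  # instead of tokens.cuda()

# Allocate result buffers on the same device instead of device="cuda".
scores = torch.zeros(12, 12, device=device, dtype=torch.float32)
```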
"execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Running as a Jupyter notebook - intended for development only!\n" + "Running as a Colab notebook\n", + "Collecting git+https://github.com/neelnanda-io/TransformerLens.git@new-demo\n", + " Cloning https://github.com/neelnanda-io/TransformerLens.git (to revision new-demo) to /tmp/pip-req-build-ntfqis4f\n", + " Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-ntfqis4f\n", + " Running command git checkout -b new-demo --track origin/new-demo\n", + " Switched to a new branch 'new-demo'\n", + " Branch 'new-demo' set up to track remote branch 'new-demo' from 'origin'.\n", + " Resolved https://github.com/neelnanda-io/TransformerLens.git to commit 25879ea6d969e8e9c962957bc945be42de522b4a\n", + " Installing build dependencies ... \u001b[?25ldone\n", + "\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n", + "\u001b[?25h Preparing metadata (pyproject.toml) ... \u001b[?25ldone\n", + "\u001b[?25hRequirement already satisfied: pandas<2.0.0,>=1.1.5 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from transformer-lens==0.2.0) (1.1.5)\n", + "Collecting wandb<0.14.0,>=0.13.5\n", + " Downloading wandb-0.13.11-py3-none-any.whl (2.0 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m10.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: datasets<3.0.0,>=2.7.1 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from transformer-lens==0.2.0) (2.10.1)\n", + "Requirement already satisfied: einops<0.7.0,>=0.6.0 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from transformer-lens==0.2.0) (0.6.0)\n", + "Requirement already satisfied: numpy<2.0,>=1.21 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from transformer-lens==0.2.0) (1.24.2)\n", + "Requirement already satisfied: torchtyping<0.2.0,>=0.1.4 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from transformer-lens==0.2.0) (0.1.4)\n", + "Requirement already satisfied: torch<2.0,>=1.10 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from transformer-lens==0.2.0) (1.13.1)\n", + "Collecting rich<13.0.0,>=12.6.0\n", + " Downloading rich-12.6.0-py3-none-any.whl (237 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m237.5/237.5 kB\u001b[0m \u001b[31m16.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: fancy-einsum<0.0.4,>=0.0.3 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from transformer-lens==0.2.0) (0.0.3)\n", + "Requirement already satisfied: transformers<5.0.0,>=4.25.1 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from transformer-lens==0.2.0) (4.27.3)\n", + "Requirement already satisfied: tqdm<5.0.0,>=4.64.1 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from transformer-lens==0.2.0) (4.65.0)\n", + "Requirement already satisfied: pyarrow>=6.0.0 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from datasets<3.0.0,>=2.7.1->transformer-lens==0.2.0) (11.0.0)\n", + "Requirement already satisfied: requests>=2.19.0 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from datasets<3.0.0,>=2.7.1->transformer-lens==0.2.0) (2.28.2)\n", + "Requirement already 
satisfied: huggingface-hub<1.0.0,>=0.2.0 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from datasets<3.0.0,>=2.7.1->transformer-lens==0.2.0) (0.13.3)\n", + "Requirement already satisfied: fsspec[http]>=2021.11.1 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from datasets<3.0.0,>=2.7.1->transformer-lens==0.2.0) (2023.1.0)\n", + "Requirement already satisfied: responses<0.19 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from datasets<3.0.0,>=2.7.1->transformer-lens==0.2.0) (0.18.0)\n", + "Requirement already satisfied: multiprocess in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from datasets<3.0.0,>=2.7.1->transformer-lens==0.2.0) (0.70.14)\n", + "Requirement already satisfied: packaging in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from datasets<3.0.0,>=2.7.1->transformer-lens==0.2.0) (23.0)\n", + "Requirement already satisfied: dill<0.3.7,>=0.3.0 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from datasets<3.0.0,>=2.7.1->transformer-lens==0.2.0) (0.3.6)\n", + "Requirement already satisfied: xxhash in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from datasets<3.0.0,>=2.7.1->transformer-lens==0.2.0) (3.2.0)\n", + "Requirement already satisfied: aiohttp in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from datasets<3.0.0,>=2.7.1->transformer-lens==0.2.0) (3.8.4)\n", + "Requirement already satisfied: pyyaml>=5.1 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from datasets<3.0.0,>=2.7.1->transformer-lens==0.2.0) (6.0)\n", + "Requirement already satisfied: python-dateutil>=2.7.3 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from pandas<2.0.0,>=1.1.5->transformer-lens==0.2.0) (2.8.2)\n", + "Requirement already satisfied: pytz>=2017.2 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from pandas<2.0.0,>=1.1.5->transformer-lens==0.2.0) (2023.2)\n", + "Requirement already satisfied: typing-extensions<5.0,>=4.0.0 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from rich<13.0.0,>=12.6.0->transformer-lens==0.2.0) (4.5.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.6.0 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from rich<13.0.0,>=12.6.0->transformer-lens==0.2.0) (2.14.0)\n", + "Collecting commonmark<0.10.0,>=0.9.0\n", + " Downloading commonmark-0.9.1-py2.py3-none-any.whl (51 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m51.1/51.1 kB\u001b[0m \u001b[31m10.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: nvidia-cudnn-cu11==8.5.0.96 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from torch<2.0,>=1.10->transformer-lens==0.2.0) (8.5.0.96)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from torch<2.0,>=1.10->transformer-lens==0.2.0) (11.7.99)\n", + "Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from torch<2.0,>=1.10->transformer-lens==0.2.0) (11.10.3.66)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from torch<2.0,>=1.10->transformer-lens==0.2.0) (11.7.99)\n", + "Requirement already satisfied: wheel in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from 
nvidia-cublas-cu11==11.10.3.66->torch<2.0,>=1.10->transformer-lens==0.2.0) (0.40.0)\n", + "Requirement already satisfied: setuptools in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch<2.0,>=1.10->transformer-lens==0.2.0) (67.6.0)\n", + "Requirement already satisfied: typeguard>=2.11.1 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from torchtyping<0.2.0,>=0.1.4->transformer-lens==0.2.0) (2.13.3)\n", + "Requirement already satisfied: regex!=2019.12.17 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from transformers<5.0.0,>=4.25.1->transformer-lens==0.2.0) (2022.10.31)\n", + "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from transformers<5.0.0,>=4.25.1->transformer-lens==0.2.0) (0.13.2)\n", + "Requirement already satisfied: filelock in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from transformers<5.0.0,>=4.25.1->transformer-lens==0.2.0) (3.10.6)\n", + "Requirement already satisfied: setproctitle in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from wandb<0.14.0,>=0.13.5->transformer-lens==0.2.0) (1.3.2)\n", + "Requirement already satisfied: appdirs>=1.4.3 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from wandb<0.14.0,>=0.13.5->transformer-lens==0.2.0) (1.4.4)\n", + "Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from wandb<0.14.0,>=0.13.5->transformer-lens==0.2.0) (3.1.31)\n", + "Requirement already satisfied: psutil>=5.0.0 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from wandb<0.14.0,>=0.13.5->transformer-lens==0.2.0) (5.9.4)\n", + "Requirement already satisfied: pathtools in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from wandb<0.14.0,>=0.13.5->transformer-lens==0.2.0) (0.1.2)\n", + "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from wandb<0.14.0,>=0.13.5->transformer-lens==0.2.0) (4.22.1)\n", + "Requirement already satisfied: docker-pycreds>=0.4.0 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from wandb<0.14.0,>=0.13.5->transformer-lens==0.2.0) (0.4.0)\n", + "Requirement already satisfied: Click!=8.0.0,>=7.0 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from wandb<0.14.0,>=0.13.5->transformer-lens==0.2.0) (8.1.3)\n", + "Requirement already satisfied: sentry-sdk>=1.0.0 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from wandb<0.14.0,>=0.13.5->transformer-lens==0.2.0) (1.17.0)\n", + "Requirement already satisfied: six>=1.4.0 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb<0.14.0,>=0.13.5->transformer-lens==0.2.0) (1.16.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from aiohttp->datasets<3.0.0,>=2.7.1->transformer-lens==0.2.0) (1.3.3)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from aiohttp->datasets<3.0.0,>=2.7.1->transformer-lens==0.2.0) (1.3.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from aiohttp->datasets<3.0.0,>=2.7.1->transformer-lens==0.2.0) (22.2.0)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in 
/workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from aiohttp->datasets<3.0.0,>=2.7.1->transformer-lens==0.2.0) (4.0.2)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from aiohttp->datasets<3.0.0,>=2.7.1->transformer-lens==0.2.0) (6.0.4)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from aiohttp->datasets<3.0.0,>=2.7.1->transformer-lens==0.2.0) (1.8.2)\n", + "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from aiohttp->datasets<3.0.0,>=2.7.1->transformer-lens==0.2.0) (3.1.0)\n", + "Requirement already satisfied: gitdb<5,>=4.0.1 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb<0.14.0,>=0.13.5->transformer-lens==0.2.0) (4.0.10)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from requests>=2.19.0->datasets<3.0.0,>=2.7.1->transformer-lens==0.2.0) (2022.12.7)\n", + "Requirement already satisfied: idna<4,>=2.5 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from requests>=2.19.0->datasets<3.0.0,>=2.7.1->transformer-lens==0.2.0) (3.4)\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from requests>=2.19.0->datasets<3.0.0,>=2.7.1->transformer-lens==0.2.0) (1.26.15)\n", + "Requirement already satisfied: smmap<6,>=3.0.1 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb<0.14.0,>=0.13.5->transformer-lens==0.2.0) (5.0.0)\n", + "Building wheels for collected packages: transformer-lens\n", + " Building wheel for transformer-lens (pyproject.toml) ... 
\u001b[?25ldone\n", + "\u001b[?25h Created wheel for transformer-lens: filename=transformer_lens-0.2.0-py3-none-any.whl size=67067 sha256=ecd7804dbe958d4ac4bf44dee4954317328643a27f08c4756271084a2a273070\n", + " Stored in directory: /tmp/pip-ephem-wheel-cache-jbg5tq59/wheels/82/d1/c7/e1856bcf4639877dd69f7b2d3765428c0af94b8f186b064c38\n", + "Successfully built transformer-lens\n", + "Installing collected packages: commonmark, rich, wandb, transformer-lens\n", + " Attempting uninstall: rich\n", + " Found existing installation: rich 13.3.2\n", + " Uninstalling rich-13.3.2:\n", + " Successfully uninstalled rich-13.3.2\n", + " Attempting uninstall: wandb\n", + " Found existing installation: wandb 0.14.0\n", + " Uninstalling wandb-0.14.0:\n", + " Successfully uninstalled wandb-0.14.0\n", + " Attempting uninstall: transformer-lens\n", + " Found existing installation: transformer-lens 0.0.0\n", + " Uninstalling transformer-lens-0.0.0:\n", + " Successfully uninstalled transformer-lens-0.0.0\n", + "Successfully installed commonmark-0.9.1 rich-12.6.0 transformer-lens-0.2.0 wandb-0.13.11\n", + "\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.0.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.1.1\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", + "Note: you may need to restart the kernel to use updated packages.\n", + "Requirement already satisfied: circuitsvis in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (1.39.1)\n", + "Requirement already satisfied: importlib-metadata<6.0.0,>=5.1.0 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from circuitsvis) (5.2.0)\n", + "Requirement already satisfied: torch<2.0,>=1.10 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from circuitsvis) (1.13.1)\n", + "Requirement already satisfied: numpy<2.0,>=1.21 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from circuitsvis) (1.24.2)\n", + "Requirement already satisfied: zipp>=0.5 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from importlib-metadata<6.0.0,>=5.1.0->circuitsvis) (3.15.0)\n", + "Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from torch<2.0,>=1.10->circuitsvis) (11.10.3.66)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from torch<2.0,>=1.10->circuitsvis) (11.7.99)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from torch<2.0,>=1.10->circuitsvis) (11.7.99)\n", + "Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from torch<2.0,>=1.10->circuitsvis) (8.5.0.96)\n", + "Requirement already satisfied: typing-extensions in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from torch<2.0,>=1.10->circuitsvis) (4.5.0)\n", + "Requirement already satisfied: wheel in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch<2.0,>=1.10->circuitsvis) (0.40.0)\n", + "Requirement already satisfied: setuptools in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from 
nvidia-cublas-cu11==11.10.3.66->torch<2.0,>=1.10->circuitsvis) (67.6.0)\n", + "\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.0.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.1.1\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", + "Note: you may need to restart the kernel to use updated packages.\n", + "Collecting git+https://github.com/neelnanda-io/neel-plotly.git\n", + " Cloning https://github.com/neelnanda-io/neel-plotly.git to /tmp/pip-req-build-d_n9qhul\n", + " Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/neel-plotly.git /tmp/pip-req-build-d_n9qhul\n", + " Resolved https://github.com/neelnanda-io/neel-plotly.git to commit 6dc096fdc575da978d3e56489f2347d95cd397e7\n", + " Preparing metadata (setup.py) ... \u001b[?25ldone\n", + "\u001b[?25hRequirement already satisfied: einops in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from neel-plotly==0.0.0) (0.6.0)\n", + "Requirement already satisfied: numpy in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from neel-plotly==0.0.0) (1.24.2)\n", + "Requirement already satisfied: torch in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from neel-plotly==0.0.0) (1.13.1)\n", + "Requirement already satisfied: plotly in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from neel-plotly==0.0.0) (5.13.1)\n", + "Requirement already satisfied: tqdm in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from neel-plotly==0.0.0) (4.65.0)\n", + "Requirement already satisfied: pandas in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from neel-plotly==0.0.0) (1.1.5)\n", + "Requirement already satisfied: python-dateutil>=2.7.3 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from pandas->neel-plotly==0.0.0) (2.8.2)\n", + "Requirement already satisfied: pytz>=2017.2 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from pandas->neel-plotly==0.0.0) (2023.2)\n", + "Requirement already satisfied: tenacity>=6.2.0 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from plotly->neel-plotly==0.0.0) (8.2.2)\n", + "Requirement already satisfied: typing-extensions in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from torch->neel-plotly==0.0.0) (4.5.0)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from torch->neel-plotly==0.0.0) (11.7.99)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from torch->neel-plotly==0.0.0) (11.7.99)\n", + "Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from torch->neel-plotly==0.0.0) (8.5.0.96)\n", + "Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from torch->neel-plotly==0.0.0) (11.10.3.66)\n", + "Requirement already satisfied: setuptools in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch->neel-plotly==0.0.0) (67.6.0)\n", + "Requirement already satisfied: wheel in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages 
(from nvidia-cublas-cu11==11.10.3.66->torch->neel-plotly==0.0.0) (0.40.0)\n", + "Requirement already satisfied: six>=1.5 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from python-dateutil>=2.7.3->pandas->neel-plotly==0.0.0) (1.16.0)\n", + "\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.0.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.1.1\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", + "Note: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ - "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", + "import os\n", + "\n", + "IN_COLAB = 'google.colab' in str(get_ipython())\n", + "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", + "DEBUG_MODE = False\n", "DEVELOPMENT_MODE = True\n", - "try:\n", - " import google.colab\n", - " IN_COLAB = True\n", + "DO_SLOW_RUNS = not IN_GITHUB\n", + "TORCH_DEVICE = \"cuda\" if not IN_GITHUB else \"cpu\"\n", + "EPOCHS_SIZE = 25000 if not IN_GITHUB else 25\n", + "\n", + "if IN_COLAB or IN_GITHUB:\n", " print(\"Running as a Colab notebook\")\n", " %pip install git+https://github.com/neelnanda-io/TransformerLens.git@new-demo\n", " %pip install circuitsvis\n", @@ -55,8 +201,9 @@ " # # Install another version of node that makes PySvelte work way faster\n", " # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n", " # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n", - "except:\n", - " IN_COLAB = False\n", + " # Install my janky personal plotting utils\n", + " %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n", + "else:\n", " print(\"Running as a Jupyter notebook - intended for development only!\")\n", " from IPython import get_ipython\n", "\n", @@ -102,7 +249,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -119,6 +266,7 @@ "from pathlib import Path\n", "import plotly.express as px\n", "from torch.utils.data import DataLoader\n", + "import neel_plotly as npx\n", "\n", "from torchtyping import TensorType as TT\n", "from typing import List, Union, Optional\n", @@ -134,7 +282,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -156,7 +304,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -174,12 +322,12 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# where we save the model\n", - "PTH_LOCATION = \"/workspace/_scratch/grokking_demo.pth\"" + "PTH_LOCATION = \"/workspace/_scratch/grokking_demo.pth\" if not IN_GITHUB else \"./grokking_demo.pth\"" ] }, { @@ -211,7 +359,7 @@ "wd = 1. 
\n", "betas = (0.9, 0.98)\n", "\n", - "num_epochs = 25000\n", + "num_epochs = EPOCHS_SIZE\n", "checkpoint_every = 100\n", "\n", "DATA_SEED = 598" @@ -247,7 +395,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -258,20 +406,23 @@ " [ 0, 1, 113],\n", " [ 0, 2, 113],\n", " [ 0, 3, 113],\n", - " [ 0, 4, 113]], device='cuda:0')\n", + " [ 0, 4, 113]])\n", "torch.Size([12769, 3])\n" ] } ], "source": [ - "dataset = torch.stack([a_vector, b_vector, equals_vector], dim=1).cuda()\n", + "tokens = [a_vector, b_vector, equals_vector]\n", + "dataset = torch.stack(tokens, dim=1)\n", + "if not IN_GITHUB:\n", + " dataset = dataset.cuda()\n", "print(dataset[:5])\n", "print(dataset.shape)" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -298,7 +449,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -350,7 +501,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -367,14 +518,14 @@ " d_vocab_out=p,\n", " n_ctx=3,\n", " init_weights=True,\n", - " device=\"cuda\",\n", + " device=TORCH_DEVICE,\n", " seed = 999,\n", ")" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -391,7 +542,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -409,7 +560,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -418,7 +569,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -448,7 +599,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -482,7 +633,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -804,7 +955,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -828,7 +979,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -902,8 +1053,8 @@ } ], "source": [ - "import neel.plot as npx\n", - "npx.line([train_losses[::100], test_losses[::100]], x=np.arange(0, len(train_losses), 100), xaxis=\"Epoch\", yaxis=\"Loss\", log_y=True, title=\"Training Curve for Modular Addition\", line_labels=['train', 'test'], toggle_x=True, toggle_y=True)" + "if not IN_GITHUB:\n", + " npx.line([train_losses[::100], test_losses[::100]], x=np.arange(0, len(train_losses), 100), xaxis=\"Epoch\", yaxis=\"Loss\", log_y=True, title=\"Training Curve for Modular Addition\", line_labels=['train', 'test'], toggle_x=True, toggle_y=True)" ] }, { @@ -923,7 +1074,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -949,7 +1100,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -973,7 +1124,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1007,7 +1158,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1027,7 +1178,7 @@ }, { "cell_type": "code", - "execution_count": 28, + 
"execution_count": null, "metadata": {}, "outputs": [ { @@ -1059,7 +1210,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1102,7 +1253,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1145,7 +1296,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1168,7 +1319,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1211,7 +1362,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1264,7 +1415,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1284,7 +1435,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1337,7 +1488,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1357,7 +1508,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1435,7 +1586,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1529,7 +1680,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1573,7 +1724,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1620,14 +1771,16 @@ " fourier_basis_names.append(f\"Sin {freq}\")\n", " fourier_basis.append(torch.cos(torch.arange(p)*2 * torch.pi * freq / p))\n", " fourier_basis_names.append(f\"Cos {freq}\")\n", - "fourier_basis = torch.stack(fourier_basis, dim=0).cuda()\n", + "fourier_basis = torch.stack(fourier_basis, dim=0)\n", + "if not IN_GITHUB:\n", + " fourier_basis = fourier_basis.cuda()\n", "fourier_basis = fourier_basis/fourier_basis.norm(dim=-1, keepdim=True)\n", "npx.imshow(fourier_basis, xaxis=\"Input\", yaxis=\"Component\", y=fourier_basis_names)" ] }, { "cell_type": "code", - "execution_count": 41, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1704,7 +1857,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1755,7 +1908,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1798,7 +1951,7 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1841,7 +1994,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1903,7 +2056,7 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1946,7 +2099,7 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1996,7 +2149,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -2041,7 +2194,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -2086,7 +2239,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -2129,7 +2282,7 @@ }, { "cell_type": "code", - "execution_count": 51, + 
"execution_count": null, "metadata": {}, "outputs": [ { @@ -2172,7 +2325,7 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -2215,7 +2368,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -2258,7 +2411,7 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -2309,7 +2462,7 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -2329,7 +2482,7 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -2367,7 +2520,9 @@ } ], "source": [ - "neuron_freq_norm = torch.zeros(p//2, model.cfg.d_mlp).cuda()\n", + "neuron_freq_norm = torch.zeros(p//2, model.cfg.d_mlp)\n", + "if not IN_GITHUB:\n", + " neuron_freq_norm = neuron_freq_norm.cuda()\n", "for freq in range(0, p//2):\n", " for x in [0, 2*(freq+1) - 1, 2*(freq+1)]:\n", " for y in [0, 2*(freq+1) - 1, 2*(freq+1)]:\n", @@ -2378,7 +2533,7 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -2434,7 +2589,7 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -2452,7 +2607,7 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -2495,7 +2650,7 @@ }, { "cell_type": "code", - "execution_count": 60, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -2516,7 +2671,7 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -2536,7 +2691,7 @@ }, { "cell_type": "code", - "execution_count": 62, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -2587,7 +2742,7 @@ }, { "cell_type": "code", - "execution_count": 63, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -2633,7 +2788,7 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -2653,7 +2808,7 @@ }, { "cell_type": "code", - "execution_count": 65, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -2720,7 +2875,7 @@ }, { "cell_type": "code", - "execution_count": 66, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2770,7 +2925,7 @@ }, { "cell_type": "code", - "execution_count": 67, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2808,7 +2963,7 @@ }, { "cell_type": "code", - "execution_count": 68, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2817,7 +2972,7 @@ }, { "cell_type": "code", - "execution_count": 69, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2857,7 +3012,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2868,7 +3023,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2881,7 +3036,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -2907,9 +3062,9 @@ } ], "source": [ - "import neel.plot as npx\n", - "fig = npx.line([train_losses[::100], test_losses[::100]], x=np.arange(0, len(train_losses), 100), xaxis=\"Epoch\", yaxis=\"Loss\", log_y=True, title=\"Training Curve for Modular Addition\", line_labels=['train', 'test'], 
toggle_x=True, toggle_y=True, return_fig=True)\n", - "add_lines(fig)" + "if not IN_GITHUB:\n", + " fig = npx.line([train_losses[::100], test_losses[::100]], x=np.arange(0, len(train_losses), 100), xaxis=\"Epoch\", yaxis=\"Loss\", log_y=True, title=\"Training Curve for Modular Addition\", line_labels=['train', 'test'], toggle_x=True, toggle_y=True, return_fig=True)\n", + " add_lines(fig)" ] }, { @@ -2964,7 +3119,7 @@ " a = torch.arange(p)[:, None, None]\n", " b = torch.arange(p)[None, :, None]\n", " c = torch.arange(p)[None, None, :]\n", - " cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * (a + b - c)).cuda()\n", + " cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * (a + b - c)).to(TORCH_DEVICE)\n", " cube_predicted_logits /= cube_predicted_logits.norm()\n", " coses[freq] = cube_predicted_logits" ] @@ -3095,7 +3250,9 @@ " a = torch.arange(p)[:, None, None]\n", " b = torch.arange(p)[None, :, None]\n", " c = torch.arange(p)[None, None, :]\n", - " cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * (a + b - c)).cuda()\n", + " cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * (a + b - c))\n", + " if not IN_GITHUB:\n", + " cube_predicted_logits = cube_predicted_logits.cuda()\n", " cube_predicted_logits /= cube_predicted_logits.norm()\n", " cos_cube.append(cube_predicted_logits)\n", "cos_cube = torch.stack(cos_cube, dim=0)\n", @@ -3147,52 +3304,12 @@ " return vals\n", "\n", "\n", - "get_metrics(model, metric_cache, get_cos_coeffs, \"cos_coeffs\")\n", - "print(metric_cache[\"cos_coeffs\"].shape)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "fig = npx.line(metric_cache[\"cos_coeffs\"].T, line_labels=[f\"Freq {i}\" for i in range(1, p//2+1)], title=\"Coefficients with Predicted Logits\", xaxis=\"Epoch\", x=checkpoint_epochs, yaxis=\"Coefficient\", return_fig=True)\n", - "add_lines(fig)" + "if not IN_GITHUB:\n", + " get_metrics(model, metric_cache, get_cos_coeffs, \"cos_coeffs\")\n", + " print(metric_cache[\"cos_coeffs\"].shape)\n", + " \n", + " fig = npx.line(metric_cache[\"cos_coeffs\"].T, line_labels=[f\"Freq {i}\" for i in range(1, p//2+1)], title=\"Coefficients with Predicted Logits\", xaxis=\"Epoch\", x=checkpoint_epochs, yaxis=\"Coefficient\", return_fig=True)\n", + " add_lines(fig)" ] }, { @@ -3262,11 +3379,13 @@ " vals = (cos_cube * logits[None, :, :, :]).sum([-3, -2, -1])\n", " return vals / logits.norm()\n", "\n", - "get_metrics(model, metric_cache, get_cos_sim, \"cos_sim\")\n", - "print(metric_cache[\"cos_sim\"].shape)\n", "\n", - "fig = npx.line(metric_cache[\"cos_sim\"].T, line_labels=[f\"Freq {i}\" for i in range(1, p//2+1)], title=\"Cosine Sim with Predicted Logits\", xaxis=\"Epoch\", x=checkpoint_epochs, yaxis=\"Cosine Sim\", return_fig=True)\n", - "add_lines(fig)" + "if not IN_GITHUB:\n", + " get_metrics(model, metric_cache, get_cos_sim, \"cos_sim\")\n", + " print(metric_cache[\"cos_sim\"].shape)\n", + "\n", + " fig = npx.line(metric_cache[\"cos_sim\"].T, line_labels=[f\"Freq {i}\" for i in range(1, p//2+1)], title=\"Cosine Sim with Predicted Logits\", xaxis=\"Epoch\", x=checkpoint_epochs, yaxis=\"Cosine Sim\", return_fig=True)\n", + " add_lines(fig)" ] }, { @@ -3337,11 +3456,13 @@ " residual = logits - (vals[:, None, None, None] * cos_cube).sum(dim=0)\n", " return residual.norm() / logits.norm()\n", "\n", - "get_metrics(model, metric_cache, get_residual_cos_sim, \"residual_cos_sim\")\n", - "print(metric_cache[\"residual_cos_sim\"].shape)\n", "\n", - "fig = npx.line([metric_cache[\"cos_sim\"][:, i] for i in range(p//2)]+[metric_cache[\"residual_cos_sim\"]], line_labels=[f\"Freq {i}\" for i in range(1, p//2+1)]+[\"residual\"], title=\"Cosine Sim with Predicted Logits + Residual\", xaxis=\"Epoch\", x=checkpoint_epochs, yaxis=\"Cosine Sim\", return_fig=True)\n", - "add_lines(fig)" + "if not IN_GITHUB:\n", + " get_metrics(model, metric_cache, get_residual_cos_sim, \"residual_cos_sim\")\n", + " print(metric_cache[\"residual_cos_sim\"].shape)\n", + "\n", + " fig = npx.line([metric_cache[\"cos_sim\"][:, i] for i in range(p//2)]+[metric_cache[\"residual_cos_sim\"]], line_labels=[f\"Freq {i}\" for i in range(1, p//2+1)]+[\"residual\"], title=\"Cosine Sim with Predicted Logits + Residual\", xaxis=\"Epoch\", x=checkpoint_epochs, yaxis=\"Cosine Sim\", return_fig=True)\n", + " add_lines(fig)" ] }, { @@ -3457,11 +3578,15 @@ "a = torch.arange(p)[:, None]\n", "b = torch.arange(p)[None, :]\n", "for freq in key_freqs:\n", - " cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).cuda()\n", + " cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b))\n", + " if not IN_GITHUB:\n", + " cos_apb_vec = cos_apb_vec.cuda()\n", " cos_apb_vec /= cos_apb_vec.norm()\n", " cos_apb_vec = einops.rearrange(cos_apb_vec, \"a b -> (a b) 1\")\n", " approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec\n", - " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).cuda()\n", + " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b))\n", + " if not IN_GITHUB:\n", + " sin_apb_vec = sin_apb_vec.cuda()\n", " sin_apb_vec /= 
sin_apb_vec.norm()\n", " sin_apb_vec = einops.rearrange(sin_apb_vec, \"a b -> (a b) 1\")\n", " approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec\n", @@ -3488,7 +3613,8 @@ } ], "source": [ - "print(loss_fn(all_logits, labels))" + "if not IN_GITHUB:\n", + " print(loss_fn(all_logits, labels))" ] }, { @@ -3526,11 +3652,15 @@ " a = torch.arange(p)[:, None]\n", " b = torch.arange(p)[None, :]\n", " for freq in key_freqs:\n", - " cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).cuda()\n", + " cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b))\n", + " if not IN_GITHUB:\n", + " cos_apb_vec = cos_apb_vec.cuda()\n", " cos_apb_vec /= cos_apb_vec.norm()\n", " cos_apb_vec = einops.rearrange(cos_apb_vec, \"a b -> (a b) 1\")\n", " approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec\n", - " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).cuda()\n", + " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b))\n", + " if not IN_GITHUB:\n", + " sin_apb_vec = sin_apb_vec.cuda()\n", " sin_apb_vec /= sin_apb_vec.norm()\n", " sin_apb_vec = einops.rearrange(sin_apb_vec, \"a b -> (a b) 1\")\n", " approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec\n", @@ -3569,8 +3699,9 @@ } ], "source": [ - "get_metrics(model, metric_cache, get_restricted_loss, \"restricted_loss\", reset=True)\n", - "print(metric_cache[\"restricted_loss\"].shape)" + "if not IN_GITHUB:\n", + " get_metrics(model, metric_cache, get_restricted_loss, \"restricted_loss\", reset=True)\n", + " print(metric_cache[\"restricted_loss\"].shape)" ] }, { @@ -3613,9 +3744,9 @@ } ], "source": [ - "import neel.plot as npx\n", - "fig = npx.line([train_losses[::100], test_losses[::100], metric_cache[\"restricted_loss\"]], x=np.arange(0, len(train_losses), 100), xaxis=\"Epoch\", yaxis=\"Loss\", log_y=True, title=\"Restricted Loss Curve\", line_labels=['train', 'test', \"restricted_loss\"], toggle_x=True, toggle_y=True, return_fig=True)\n", - "add_lines(fig)" + "if not IN_GITHUB:\n", + " fig = npx.line([train_losses[::100], test_losses[::100], metric_cache[\"restricted_loss\"]], x=np.arange(0, len(train_losses), 100), xaxis=\"Epoch\", yaxis=\"Loss\", log_y=True, title=\"Restricted Loss Curve\", line_labels=['train', 'test', \"restricted_loss\"], toggle_x=True, toggle_y=True, return_fig=True)\n", + " add_lines(fig)" ] }, { @@ -3658,9 +3789,9 @@ } ], "source": [ - "import neel.plot as npx\n", - "fig = npx.line([torch.tensor(test_losses[::100])/metric_cache[\"restricted_loss\"]], x=np.arange(0, len(train_losses), 100), xaxis=\"Epoch\", yaxis=\"Loss\", log_y=True, title=\"Restricted Loss to Test Loss Ratio\", toggle_x=True, toggle_y=True, return_fig=True)\n", - "add_lines(fig)" + "if not IN_GITHUB:\n", + " fig = npx.line([torch.tensor(test_losses[::100])/metric_cache[\"restricted_loss\"]], x=np.arange(0, len(train_losses), 100), xaxis=\"Epoch\", yaxis=\"Loss\", log_y=True, title=\"Restricted Loss to Test Loss Ratio\", toggle_x=True, toggle_y=True, return_fig=True)\n", + " add_lines(fig)" ] }, { @@ -3690,11 +3821,13 @@ "a = torch.arange(p)[:, None]\n", "b = torch.arange(p)[None, :]\n", "for freq in key_freqs:\n", - " cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).cuda()\n", + " cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b))\n", + " if not IN_GITHUB:\n", + " cos_apb_vec = cos_apb_vec.cuda()\n", " cos_apb_vec /= cos_apb_vec.norm()\n", " cos_apb_vec = einops.rearrange(cos_apb_vec, \"a b -> (a b) 1\")\n", " approx_neuron_acts += (neuron_acts * 
cos_apb_vec).sum(dim=0) * cos_apb_vec\n", - " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).cuda()\n", + " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).to(neuron_acts.device)\n", " sin_apb_vec /= sin_apb_vec.norm()\n", " sin_apb_vec = einops.rearrange(sin_apb_vec, \"a b -> (a b) 1\")\n", " approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec\n", @@ -3737,11 +3870,11 @@ " a = torch.arange(p)[:, None]\n", " b = torch.arange(p)[None, :]\n", " for freq in key_freqs:\n", - " cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).cuda()\n", + " cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).to(neuron_acts.device)\n", " cos_apb_vec /= cos_apb_vec.norm()\n", " cos_apb_vec = einops.rearrange(cos_apb_vec, \"a b -> (a b) 1\")\n", " approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec\n", - " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).cuda()\n", + " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).to(neuron_acts.device)\n", " sin_apb_vec /= sin_apb_vec.norm()\n", " sin_apb_vec = einops.rearrange(sin_apb_vec, \"a b -> (a b) 1\")\n", " approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec\n", @@ -3780,8 +3913,9 @@ } ], "source": [ - "get_metrics(model, metric_cache, get_excluded_loss, \"excluded_loss\", reset=True)\n", - "print(metric_cache[\"excluded_loss\"].shape)" + "if not IN_GITHUB:\n", + " get_metrics(model, metric_cache, get_excluded_loss, \"excluded_loss\", reset=True)\n", + " print(metric_cache[\"excluded_loss\"].shape)" ] }, { @@ -3824,9 +3958,9 @@ } ], "source": [ - "import neel.plot as npx\n", - "fig = npx.line([train_losses[::100], test_losses[::100], metric_cache[\"excluded_loss\"], metric_cache[\"restricted_loss\"]], x=np.arange(0, len(train_losses), 100), xaxis=\"Epoch\", yaxis=\"Loss\", log_y=True, title=\"Excluded and Restricted Loss Curve\", line_labels=['train', 'test', \"excluded_loss\", \"restricted_loss\"], toggle_x=True, toggle_y=True, return_fig=True)\n", - "add_lines(fig)" + "if not IN_GITHUB:\n", + " fig = npx.line([train_losses[::100], test_losses[::100], metric_cache[\"excluded_loss\"], metric_cache[\"restricted_loss\"]], x=np.arange(0, len(train_losses), 100), xaxis=\"Epoch\", yaxis=\"Loss\", log_y=True, title=\"Excluded and Restricted Loss Curve\", line_labels=['train', 'test', \"excluded_loss\", \"restricted_loss\"], toggle_x=True, toggle_y=True, return_fig=True)\n", + " add_lines(fig)" ] } ], @@ -3846,7 +3980,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.15" + "version": "3.8.10" }, "orig_nbformat": 4, "vscode": { diff --git a/demos/Head_Detector_Demo.ipynb b/demos/Head_Detector_Demo.ipynb index 84e31171a..623824ff1 100644 --- a/demos/Head_Detector_Demo.ipynb +++ b/demos/Head_Detector_Demo.ipynb @@ -286,11 +286,16 @@ } ], "source": [ - "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", + "import os\n", + "\n", + "IN_COLAB = 'google.colab' in str(get_ipython())\n", + "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", "DEBUG_MODE = False\n", - "try:\n", - " import google.colab\n", - " IN_COLAB = True\n", + "DEVELOPMENT_MODE = True\n", + "DO_SLOW_RUNS = not IN_GITHUB\n", + "EPOCHS_SIZE = 25000 if not IN_GITHUB else 25\n", + "\n", + "if IN_COLAB or IN_GITHUB:\n", " print(\"Running as a Colab notebook\")\n", " %pip install git+https://github.com/neelnanda-io/TransformerLens.git\n", " # Install Neel's personal plotting utils\n", @@ -301,8 +306,7 @@ " # Needed for PySvelte 
to work, v3 came out and broke things...\n", " %pip install typeguard==2.13.3\n", " %pip install typing-extensions\n", - "except:\n", - " IN_COLAB = False\n", + "else:\n", " print(\"Running as a Jupyter notebook - intended for development only!\")\n", " from IPython import get_ipython\n", "\n", diff --git a/demos/Interactive_Neuroscope.ipynb b/demos/Interactive_Neuroscope.ipynb index 3f6806c56..0f825f2b7 100644 --- a/demos/Interactive_Neuroscope.ipynb +++ b/demos/Interactive_Neuroscope.ipynb @@ -41,14 +41,17 @@ "source": [ "import os\n", "\n", - "try:\n", - " import google.colab\n", + "IN_COLAB = 'google.colab' in str(get_ipython())\n", + "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", + "DEBUG_MODE = False\n", + "DO_SLOW_RUNS = not IN_GITHUB\n", "\n", - " IN_COLAB = True\n", + "if IN_COLAB or IN_GITHUB:\n", " print(\"Running as a Colab notebook\")\n", + " \n", + " %pip install gradio\n", "\n", - "except:\n", - " IN_COLAB = False\n", + "else:\n", " print(\"Running as a Jupyter notebook - intended for development only!\")\n", " from IPython import get_ipython\n", "\n", diff --git a/demos/LLaMA.ipynb b/demos/LLaMA.ipynb index bcf895a52..88f7eb8db 100644 --- a/demos/LLaMA.ipynb +++ b/demos/LLaMA.ipynb @@ -1,6 +1,7 @@ { "cells": [ { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -18,14 +19,56 @@ " --output_dir /output/path\n", "```\n", "\n", - "2. Change the ```MODEL_PATH``` variable in the notebook to the where the converted weights are stored." + "2. Change the ```MODEL_PATH``` variable in the notebook to be where the converted weights are stored." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting git+https://github.com/huggingface/transformers\n", + " Cloning https://github.com/huggingface/transformers to /tmp/pip-req-build-ki9we3lf\n", + " Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers /tmp/pip-req-build-ki9we3lf\n", + " Resolved https://github.com/huggingface/transformers to commit fbe0178f08c219313986092f4c9b994a7bd4b4a1\n", + " Installing build dependencies ... \u001b[?25ldone\n", + "\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n", + "\u001b[?25h Preparing metadata (pyproject.toml) ... 
\u001b[?25ldone\n", + "\u001b[?25hRequirement already satisfied: regex!=2019.12.17 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from transformers==4.29.0.dev0) (2022.10.31)\n", + "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from transformers==4.29.0.dev0) (0.13.2)\n", + "Requirement already satisfied: requests in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from transformers==4.29.0.dev0) (2.28.2)\n", + "Requirement already satisfied: pyyaml>=5.1 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from transformers==4.29.0.dev0) (6.0)\n", + "Requirement already satisfied: huggingface-hub<1.0,>=0.11.0 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from transformers==4.29.0.dev0) (0.13.3)\n", + "Requirement already satisfied: tqdm>=4.27 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from transformers==4.29.0.dev0) (4.65.0)\n", + "Requirement already satisfied: numpy>=1.17 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from transformers==4.29.0.dev0) (1.24.2)\n", + "Requirement already satisfied: packaging>=20.0 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from transformers==4.29.0.dev0) (23.0)\n", + "Requirement already satisfied: filelock in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from transformers==4.29.0.dev0) (3.10.6)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from huggingface-hub<1.0,>=0.11.0->transformers==4.29.0.dev0) (4.5.0)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from requests->transformers==4.29.0.dev0) (3.1.0)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from requests->transformers==4.29.0.dev0) (2022.12.7)\n", + "Requirement already satisfied: idna<4,>=2.5 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from requests->transformers==4.29.0.dev0) (3.4)\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /workspaces/TransformerLens/.venv/lib/python3.8/site-packages (from requests->transformers==4.29.0.dev0) (1.26.15)\n", + "Building wheels for collected packages: transformers\n", + " Building wheel for transformers (pyproject.toml) ... 
\u001b[?25ldone\n", + "\u001b[?25h Created wheel for transformers: filename=transformers-4.29.0.dev0-py3-none-any.whl size=7008979 sha256=46ce006a0b6f072af271bfd35b8faa8cfc4c1b0e2979e8abbe49fa9e14339bb4\n", + " Stored in directory: /tmp/pip-ephem-wheel-cache-_f2hfpqb/wheels/42/68/45/c63edff61c292f2dfd4df4ef6522dcbecc603e7af82813c1d7\n", + "Successfully built transformers\n", + "Installing collected packages: transformers\n", + " Attempting uninstall: transformers\n", + " Found existing installation: transformers 4.27.3\n", + " Uninstalling transformers-4.27.3:\n", + " Successfully uninstalled transformers-4.27.3\n", + "Successfully installed transformers-4.29.0.dev0\n", + "\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.0.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.1.2\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n" + ] + } + ], "source": [ "!pip install git+https://github.com/huggingface/transformers" ] @@ -39,7 +82,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -48,34 +91,33 @@ "text": [ "Running as a Jupyter notebook - intended for development only!\n" ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/tmp/ipykernel_1078895/410710250.py:21: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", - " ipython.magic(\"load_ext autoreload\")\n", - "/tmp/ipykernel_1078895/410710250.py:22: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", - " ipython.magic(\"autoreload 2\")\n" - ] } ], "source": [ - "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", - "DEVELOPMENT_MODE = False\n", - "try:\n", - " import google.colab\n", - " IN_COLAB = True\n", + "import os\n", + "\n", + "IN_COLAB = 'google.colab' in str(get_ipython())\n", + "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", + "DEBUG_MODE = False\n", + "DO_SLOW_RUNS = not IN_GITHUB\n", + "EPOCHS_SIZE = 4000 if not IN_GITHUB else 25\n", + "\n", + "if IN_COLAB or IN_GITHUB:\n", " print(\"Running as a Colab notebook\")\n", " %pip install git+https://github.com/neelnanda-io/TransformerLens.git\n", " %pip install circuitsvis\n", + " %pip install sentencepiece\n", + " \n", + " !python src/transformers/models/llama/convert_llama_weights_to_hf.py \\\n", + " --input_dir ~/tmp/weights/in \\\n", + " --model_size 7B \\\n", + " --output_dir ~/tmp/weights/out\n", " \n", " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n", " # # Install another version of node that makes PySvelte work way faster\n", " # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n", " # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n", - "except:\n", - " IN_COLAB = False\n", + "else:\n", " print(\"Running as a Jupyter notebook - intended for development only!\")\n", " from IPython import get_ipython\n", "\n", @@ -87,9 +129,33 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[autoreload of transformers failed: 
Traceback (most recent call last):\n", + " File \"/workspaces/TransformerLens/.venv/lib/python3.8/site-packages/IPython/extensions/autoreload.py\", line 245, in check\n", + " superreload(m, reload, self.old_objects)\n", + " File \"/workspaces/TransformerLens/.venv/lib/python3.8/site-packages/IPython/extensions/autoreload.py\", line 394, in superreload\n", + " module = reload(module)\n", + " File \"/usr/lib/python3.8/imp.py\", line 314, in reload\n", + " return importlib.reload(module)\n", + " File \"/usr/lib/python3.8/importlib/__init__.py\", line 142, in reload\n", + " name = module.__spec__.name\n", + " File \"/workspaces/TransformerLens/.venv/lib/python3.8/site-packages/transformers/utils/import_utils.py\", line 1142, in __getattr__\n", + " if name in self._objects:\n", + " File \"/workspaces/TransformerLens/.venv/lib/python3.8/site-packages/transformers/utils/import_utils.py\", line 1142, in __getattr__\n", + " if name in self._objects:\n", + " File \"/workspaces/TransformerLens/.venv/lib/python3.8/site-packages/transformers/utils/import_utils.py\", line 1142, in __getattr__\n", + " if name in self._objects:\n", + " [Previous line repeated 3 more times]\n", + "RecursionError: maximum recursion depth exceeded while calling a Python object\n", + "]\n" + ] + }, { "name": "stdout", "output_type": "stream", @@ -101,7 +167,7 @@ "source": [ "# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n", "import plotly.io as pio\n", - "if IN_COLAB or not DEVELOPMENT_MODE:\n", + "if IN_COLAB or not DEBUG_MODE:\n", " pio.renderers.default = \"colab\"\n", "else:\n", " pio.renderers.default = \"notebook_connected\"\n", @@ -166,6 +232,15 @@ " px.scatter(y=y, x=x, labels={\"x\":xaxis, \"y\":yaxis, \"color\":caxis}, **kwargs).show(renderer)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -175,41 +250,29 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 1, "metadata": {}, "outputs": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. \n", - "The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. 
\n", - "The class this function is called from is 'LlamaTokenizer'.\n" + "ename": "ImportError", + "evalue": "cannot import name 'LlamaForCausalLM' from 'transformers' (/workspaces/TransformerLens/.venv/lib/python3.8/site-packages/transformers/__init__.py)", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_6958/474029703.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mtransformers\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mLlamaForCausalLM\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mLlamaTokenizer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;31m# TODO\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mMODEL_PATH\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m''\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mImportError\u001b[0m: cannot import name 'LlamaForCausalLM' from 'transformers' (/workspaces/TransformerLens/.venv/lib/python3.8/site-packages/transformers/__init__.py)" ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e50624bfbf724e03a7fcaa43bba1d311", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Loading checkpoint shards: 0%| | 0/33 [00:00=3.7.0" files = [ - {file = "nbclient-0.7.4-py3-none-any.whl", hash = "sha256:c817c0768c5ff0d60e468e017613e6eae27b6fa31e43f905addd2d24df60c125"}, - {file = "nbclient-0.7.4.tar.gz", hash = "sha256:d447f0e5a4cfe79d462459aec1b3dc5c2e9152597262be8ee27f7d4c02566a0d"}, + {file = "nbclient-0.6.8-py3-none-any.whl", hash = "sha256:7cce8b415888539180535953f80ea2385cdbb444944cdeb73ffac1556fdbc228"}, + {file = "nbclient-0.6.8.tar.gz", hash = "sha256:268fde3457cafe1539e32eb1c6d796bbedb90b9e92bacd3e43d83413734bb0e8"}, ] [package.dependencies] -jupyter-client = ">=6.1.12" -jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" -nbformat = ">=5.1" -traitlets = ">=5.3" +jupyter-client = ">=6.1.5" +nbformat = ">=5.0" +nest-asyncio = "*" +traitlets = ">=5.2.2" [package.extras] -dev = ["pre-commit"] -docs = ["autodoc-traits", "mock", "moto", "myst-parser", "nbclient[test]", "sphinx (>=1.7)", "sphinx-book-theme", "sphinxcontrib-spelling"] -test = ["flaky", "ipykernel", "ipython", "ipywidgets", "nbconvert (>=7.0.0)", "pytest (>=7.0)", "pytest-asyncio", "pytest-cov (>=4.0)", "testpath", "xmltodict"] +sphinx = ["Sphinx (>=1.7)", "autodoc-traits", "mock", "moto", "myst-parser", "sphinx-book-theme"] +test = ["black", "check-manifest", "flake8", "ipykernel", "ipython", "ipywidgets", "mypy", "nbconvert", "pip (>=18.1)", "pre-commit", "pytest (>=4.1)", "pytest-asyncio", "pytest-cov (>=2.6.1)", "setuptools (>=60.0)", "testpath", "twine (>=1.11.0)", "xmltodict"] [[package]] name = "nbconvert" @@ -2347,6 +2346,26 @@ traitlets = ">=5.1" docs = ["myst-parser", "pydata-sphinx-theme", "sphinx", "sphinxcontrib-github-alt", "sphinxcontrib-spelling"] test = ["pep440", "pre-commit", "pytest", "testpath"] +[[package]] +name = "nbmake" +version = "1.4.1" +description = "Pytest plugin for testing notebooks" +category = "dev" +optional = false +python-versions = ">=3.7.0,<4.0.0" +files = [ + {file = 
"nbmake-1.4.1-py3-none-any.whl", hash = "sha256:1c1619fc54a2fb64bfd84acbdf13b2ffba0e4a03bfea1684f4648f28ca850ada"}, + {file = "nbmake-1.4.1.tar.gz", hash = "sha256:7f602ba5195e80e4f2527944bb06d3b4df0d1520e73ba66126b51132b1f646ea"}, +] + +[package.dependencies] +ipykernel = ">=5.4.0" +nbclient = ">=0.6.6,<0.7.0" +nbformat = ">=5.0.8,<6.0.0" +pydantic = ">=1.7.2,<2.0.0" +Pygments = ">=2.7.3,<3.0.0" +pytest = ">=6.1.0" + [[package]] name = "nest-asyncio" version = "1.5.6" @@ -2933,6 +2952,59 @@ files = [ {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"}, ] +[[package]] +name = "pydantic" +version = "1.10.7" +description = "Data validation and settings management using python type hints" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pydantic-1.10.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e79e999e539872e903767c417c897e729e015872040e56b96e67968c3b918b2d"}, + {file = "pydantic-1.10.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:01aea3a42c13f2602b7ecbbea484a98169fb568ebd9e247593ea05f01b884b2e"}, + {file = "pydantic-1.10.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:516f1ed9bc2406a0467dd777afc636c7091d71f214d5e413d64fef45174cfc7a"}, + {file = "pydantic-1.10.7-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ae150a63564929c675d7f2303008d88426a0add46efd76c3fc797cd71cb1b46f"}, + {file = "pydantic-1.10.7-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:ecbbc51391248116c0a055899e6c3e7ffbb11fb5e2a4cd6f2d0b93272118a209"}, + {file = "pydantic-1.10.7-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f4a2b50e2b03d5776e7f21af73e2070e1b5c0d0df255a827e7c632962f8315af"}, + {file = "pydantic-1.10.7-cp310-cp310-win_amd64.whl", hash = "sha256:a7cd2251439988b413cb0a985c4ed82b6c6aac382dbaff53ae03c4b23a70e80a"}, + {file = "pydantic-1.10.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:68792151e174a4aa9e9fc1b4e653e65a354a2fa0fed169f7b3d09902ad2cb6f1"}, + {file = "pydantic-1.10.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:dfe2507b8ef209da71b6fb5f4e597b50c5a34b78d7e857c4f8f3115effaef5fe"}, + {file = "pydantic-1.10.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10a86d8c8db68086f1e30a530f7d5f83eb0685e632e411dbbcf2d5c0150e8dcd"}, + {file = "pydantic-1.10.7-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d75ae19d2a3dbb146b6f324031c24f8a3f52ff5d6a9f22f0683694b3afcb16fb"}, + {file = "pydantic-1.10.7-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:464855a7ff7f2cc2cf537ecc421291b9132aa9c79aef44e917ad711b4a93163b"}, + {file = "pydantic-1.10.7-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:193924c563fae6ddcb71d3f06fa153866423ac1b793a47936656e806b64e24ca"}, + {file = "pydantic-1.10.7-cp311-cp311-win_amd64.whl", hash = "sha256:b4a849d10f211389502059c33332e91327bc154acc1845f375a99eca3afa802d"}, + {file = "pydantic-1.10.7-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:cc1dde4e50a5fc1336ee0581c1612215bc64ed6d28d2c7c6f25d2fe3e7c3e918"}, + {file = "pydantic-1.10.7-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e0cfe895a504c060e5d36b287ee696e2fdad02d89e0d895f83037245218a87fe"}, + {file = "pydantic-1.10.7-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:670bb4683ad1e48b0ecb06f0cfe2178dcf74ff27921cdf1606e527d2617a81ee"}, + 
{file = "pydantic-1.10.7-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:950ce33857841f9a337ce07ddf46bc84e1c4946d2a3bba18f8280297157a3fd1"}, + {file = "pydantic-1.10.7-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:c15582f9055fbc1bfe50266a19771bbbef33dd28c45e78afbe1996fd70966c2a"}, + {file = "pydantic-1.10.7-cp37-cp37m-win_amd64.whl", hash = "sha256:82dffb306dd20bd5268fd6379bc4bfe75242a9c2b79fec58e1041fbbdb1f7914"}, + {file = "pydantic-1.10.7-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8c7f51861d73e8b9ddcb9916ae7ac39fb52761d9ea0df41128e81e2ba42886cd"}, + {file = "pydantic-1.10.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6434b49c0b03a51021ade5c4daa7d70c98f7a79e95b551201fff682fc1661245"}, + {file = "pydantic-1.10.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64d34ab766fa056df49013bb6e79921a0265204c071984e75a09cbceacbbdd5d"}, + {file = "pydantic-1.10.7-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:701daea9ffe9d26f97b52f1d157e0d4121644f0fcf80b443248434958fd03dc3"}, + {file = "pydantic-1.10.7-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:cf135c46099ff3f919d2150a948ce94b9ce545598ef2c6c7bf55dca98a304b52"}, + {file = "pydantic-1.10.7-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b0f85904f73161817b80781cc150f8b906d521fa11e3cdabae19a581c3606209"}, + {file = "pydantic-1.10.7-cp38-cp38-win_amd64.whl", hash = "sha256:9f6f0fd68d73257ad6685419478c5aece46432f4bdd8d32c7345f1986496171e"}, + {file = "pydantic-1.10.7-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c230c0d8a322276d6e7b88c3f7ce885f9ed16e0910354510e0bae84d54991143"}, + {file = "pydantic-1.10.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:976cae77ba6a49d80f461fd8bba183ff7ba79f44aa5cfa82f1346b5626542f8e"}, + {file = "pydantic-1.10.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d45fc99d64af9aaf7e308054a0067fdcd87ffe974f2442312372dfa66e1001d"}, + {file = "pydantic-1.10.7-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d2a5ebb48958754d386195fe9e9c5106f11275867051bf017a8059410e9abf1f"}, + {file = "pydantic-1.10.7-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:abfb7d4a7cd5cc4e1d1887c43503a7c5dd608eadf8bc615413fc498d3e4645cd"}, + {file = "pydantic-1.10.7-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:80b1fab4deb08a8292d15e43a6edccdffa5377a36a4597bb545b93e79c5ff0a5"}, + {file = "pydantic-1.10.7-cp39-cp39-win_amd64.whl", hash = "sha256:d71e69699498b020ea198468e2480a2f1e7433e32a3a99760058c6520e2bea7e"}, + {file = "pydantic-1.10.7-py3-none-any.whl", hash = "sha256:0cd181f1d0b1d00e2b705f1bf1ac7799a2d938cce3376b8007df62b29be3c2c6"}, + {file = "pydantic-1.10.7.tar.gz", hash = "sha256:cfc83c0678b6ba51b0532bea66860617c4cd4251ecf76e9846fa5a9f3454e97e"}, +] + +[package.dependencies] +typing-extensions = ">=4.2.0" + +[package.extras] +dotenv = ["python-dotenv (>=0.10.4)"] +email = ["email-validator (>=1.0.3)"] + [[package]] name = "pygments" version = "2.15.1" diff --git a/pyproject.toml b/pyproject.toml index 5a56131dc..4bb0c7447 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ plotly = ">=5.12.0" isort = "5.8.0" black = "^23.3.0" pycln = "^2.1.3" +nbmake = "^1.4.1" [tool.poetry.group.jupyter.dependencies] jupyterlab = ">=3.5.0" diff --git a/setup.py b/setup.py index beb0f9f5a..df4985129 100644 --- a/setup.py +++ b/setup.py @@ -22,5 +22,5 @@ "accelerate", "typing-extensions", ], - extras_require={"dev": ["pytest", "mypy", 
"pytest-cov"]}, + extras_require={"dev": ["pytest", "mypy", "pytest-cov", "nbmake"]}, )