From 6972ce7c85703cc31ad41151f6c54b86930d0328 Mon Sep 17 00:00:00 2001 From: Israfel Salazar Date: Fri, 28 Jul 2023 10:12:26 +0200 Subject: [PATCH 1/2] compatible projections --- .../transformers/attention-test.ipynb | 1937 +++++++++++++++++ lxmls/transformers/model.py | 16 +- 2 files changed, 1951 insertions(+), 2 deletions(-) create mode 100644 labs/notebooks/transformers/attention-test.ipynb diff --git a/labs/notebooks/transformers/attention-test.ipynb b/labs/notebooks/transformers/attention-test.ipynb new file mode 100644 index 0000000..bace582 --- /dev/null +++ b/labs/notebooks/transformers/attention-test.ipynb @@ -0,0 +1,1937 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "mvunvfYJHNZN" + }, + "source": [ + "# Transformer Day Exercises" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "859IyrS2HTKW", + "outputId": "3453fd0f-82c8-4bb8-c54c-720f9012bcd5" + }, + "outputs": [], + "source": [ + "import sys\n", + "import os\n", + "sys.path.append(\"../../../\")\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "MZtJ6vYpmGye" + }, + "outputs": [], + "source": [ + "import torch\n", + "from torch.utils.data.dataloader import DataLoader\n", + "import numpy as np\n", + "import time\n", + "\n", + "import random\n", + "random.seed(42)\n", + "\n", + "from lxmls.transformers.utils import set_seed\n", + "from lxmls.transformers.bpe import BPETokenizer\n", + "from lxmls.transformers.model import GPT\n", + "from lxmls.transformers.trainer import Trainer\n", + "from lxmls.transformers.dataset import WeatherDataset" + ] + }, + { + "cell_type": "code", + "execution_count": 195, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "number of parameters: 124.44M\n" + ] + } + ], + "source": [ + "model_type = 'gpt2'\n", + "device = 'mps' # <- this works for modern Mac devices, feel free to change it to 'cpu' in case you a different machine\n", + "\n", + "model = GPT.from_pretrained(model_type)\n", + "\n", + "# We move the model to device in case we want to exploit gpu acceleration\n", + "# we also set it to eval mode since we are not interested in computing or storing any gradients\n", + "model.to(device)\n", + "model.eval();" + ] + }, + { + "cell_type": "code", + "execution_count": 196, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "number of parameters: 124.44M\n" + ] + } + ], + "source": [ + "model_type = 'gpt2'\n", + "device = 'mps' # <- this works for modern Mac devices, feel free to change it to 'cpu' in case you a different machine\n", + "\n", + "casual_model = GPT.from_pretrained(model_type)\n", + "\n", + "# We move the model to device in case we want to exploit gpu acceleration\n", + "# we also set it to eval mode since we are not interested in computing or storing any gradients\n", + "casual_model.to(device)\n", + "casual_model.eval();" + ] + }, + { + "cell_type": "code", + "execution_count": 197, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "PretrainedCausalSelfAttention(\n", + " (c_attn): Linear(in_features=768, out_features=2304, bias=True)\n", + " (c_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (attn_dropout): Dropout(p=0.1, inplace=False)\n", + " (resid_dropout): Dropout(p=0.1, inplace=False)\n", + ")" + ] + }, + "execution_count": 197, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "attn = model.transformer.h[0].attn\n", + "attn" + ] + }, + { + "cell_type": "code", + "execution_count": 198, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "CausalSelfAttention(\n", + " (query_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (key_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (value_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (output_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (attn_dropout): Dropout(p=0.1, inplace=False)\n", + " (resid_dropout): Dropout(p=0.1, inplace=False)\n", + ")" + ] + }, + "execution_count": 198, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "casual_attn = casual_model.transformer.h[0].attn\n", + "casual_attn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 163, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([2, 7, 768])\n" + ] + } + ], + "source": [ + "x = torch.rand((2, 7, 768)).to(device)\n", + "print(x.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 178, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[[1., 0., 0., ..., 0., 0., 0.],\n", + " [1., 1., 0., ..., 0., 0., 0.],\n", + " [1., 1., 1., ..., 0., 0., 0.],\n", + " ...,\n", + " [1., 1., 1., ..., 1., 0., 0.],\n", + " [1., 1., 1., ..., 1., 1., 0.],\n", + " [1., 1., 1., ..., 1., 1., 1.]]]], device='mps:0')" + ] + }, + "execution_count": 178, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import math\n", + "B, T, C = x.size()\n", + "q, k, v = attn.c_attn(x).split(attn.n_embd, dim=2)\n", + "k = k.view(B, T, attn.n_head,\n", + " C // attn.n_head).transpose(1, 2) # (B, nh, T, hs)\n", + "q = q.view(B, T, attn.n_head,\n", + " C // attn.n_head).transpose(1, 2) # (B, nh, T, hs)\n", + "v = v.view(B, T, attn.n_head,\n", + " C // attn.n_head).transpose(1, 2)\n", + "att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))\n", + "attn.bias" + ] + }, + { + "cell_type": "code", + "execution_count": 179, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[[1., 0., 0., ..., 0., 0., 0.],\n", + " [1., 1., 0., ..., 0., 0., 0.],\n", + " [1., 1., 1., ..., 0., 0., 0.],\n", + " ...,\n", + " [1., 1., 1., ..., 1., 0., 0.],\n", + " [1., 1., 1., ..., 1., 1., 0.],\n", + " [1., 1., 1., ..., 1., 1., 1.]]]], device='mps:0')" + ] + }, + "execution_count": 179, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "query = casual_attn.query_proj(x)\n", + "key = casual_attn.key_proj(x)\n", + "value = casual_attn.value_proj(x)\n", + "query = query.view(B, T, casual_attn.num_heads,\n", + " casual_attn.hidden_size // casual_attn.num_heads).transpose(1, 2)\n", + "key = key.view(B, T, casual_attn.num_heads,\n", + " casual_attn.hidden_size // casual_attn.num_heads).transpose(1, 2)\n", + "value = value.view(B, T, casual_attn.num_heads,\n", + " casual_attn.hidden_size // casual_attn.num_heads).transpose(1, 2)\n", + "scores = torch.matmul(query, key.transpose(-2, -1))\n", + "scores = scores / math.sqrt(casual_attn.hidden_size // casual_attn.num_heads)\n", + "casual_attn.bias" + ] + }, + { + "cell_type": "code", + "execution_count": 167, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 7, 768])" + ] + }, + "execution_count": 167, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 158, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[ 2.9688, -6.2871, -0.2970, ..., 1.1190, 1.4483, -2.0590],\n", + " [ 0.5624, -4.8430, -0.4151, ..., -0.9621, 1.9032, -1.7296],\n", + " [-0.3486, -4.8644, -1.0270, ..., -2.7419, 4.9274, -3.6480],\n", + " ...,\n", + " [-1.8143, -6.3218, -0.5072, ..., 0.7418, 1.7493, -0.3910],\n", + " [-1.3307, -5.9219, 0.1830, ..., 1.6437, -0.6962, -0.4040],\n", + " [-5.3526, -6.2983, -1.7141, ..., 0.9043, 1.9422, -2.8657]],\n", + "\n", + " [[ 2.1875, -3.9950, 0.0321, ..., -0.1014, -2.5261, -3.5266],\n", + " [-2.0056, -5.4167, 0.1611, ..., 0.7031, 2.6932, 0.3497],\n", + " [-1.1750, -3.2714, -1.5667, ..., 0.4072, 2.1054, -3.0378],\n", + " ...,\n", + " [ 2.0543, -1.4387, 2.2187, ..., 0.9492, 3.4958, -3.1728],\n", + " [-2.9165, -3.7022, -0.6837, ..., -0.9506, 1.4971, 0.9371],\n", + " [-4.4273, -2.4582, 0.2873, ..., -2.6668, 1.8165, 1.8452]]],\n", + " device='mps:0', grad_fn=)" + ] + }, + "execution_count": 158, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "qc" + ] + }, + { + "cell_type": "code", + "execution_count": 205, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import GPT2LMHeadModel\n", + "model_hf = GPT2LMHeadModel.from_pretrained(model_type)\n", + "sd_hf = model_hf.state_dict()\n", + "\n", + "def project_weights(sd, model):\n", + " return model.load_state_dict(sd)" + ] + }, + { + "cell_type": "code", + "execution_count": 217, + "metadata": {}, + "outputs": [], + "source": [ + "def transfer_weights(state_dict, target_sd):\n", + " for name, param in state_dict.items():\n", + " if \"c_attn\" in name:\n", + " q, k, v = param.T.split(param.T.shape[0]//3, dim=0)\n", + " target_sd[name.replace(\"c_attn.\", \"query_proj.\")] = q.T\n", + " target_sd[name.replace(\"c_attn.\", \"key_proj.\")] = k.T\n", + " target_sd[name.replace(\"c_attn.\", \"value_proj.\")] = v.T\n", + " elif \"attn.c_proj\" in name:\n", + " target_sd[name.replace(\"c_proj.\", \"output_proj.\")] = param\n", + " return target_sd\n", + "\n", + "target_sd = transfer_weights(sd_hf, casual_model.state_dict())" + ] + }, + { + "cell_type": "code", + "execution_count": 218, + "metadata": {}, + "outputs": [], + "source": [ + "a, b = casual_model.load_state_dict(target_sd)" + ] + }, + { + "cell_type": "code", + "execution_count": 159, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[-3.4655e+00, 2.5779e+00, 6.5836e-01, ..., 1.5278e+00,\n", + " -3.6474e+00, -9.6385e+00],\n", + " [-4.8525e+00, 6.4085e-01, -2.3591e+00, ..., -1.2176e+00,\n", + " -4.2529e+00, -6.6275e+00],\n", + " [-8.4380e+00, 1.6126e+00, -8.0945e-03, ..., 2.2328e+00,\n", + " -4.3179e+00, -5.6351e+00],\n", + " ...,\n", + " [-6.1942e+00, 8.6613e-01, 3.7791e-01, ..., 3.4671e-01,\n", + " -3.3635e+00, -7.2650e+00],\n", + " [-9.3770e+00, 5.0154e+00, -2.7601e+00, ..., -5.7397e-01,\n", + " -3.0190e+00, -1.1063e+01],\n", + " [-2.7235e+00, -1.8055e-01, 2.3000e-01, ..., 1.2979e+00,\n", + " -1.2891e+00, -9.7933e+00]],\n", + "\n", + " [[-7.7431e+00, 3.9429e+00, 2.9760e+00, ..., 3.6243e-02,\n", + " -7.3139e+00, -9.4922e+00],\n", + " [-6.2385e+00, 4.1850e+00, -5.4141e-01, ..., 4.5687e-02,\n", + " -1.9550e-01, -6.5452e+00],\n", + " [-6.4583e+00, 4.1618e+00, -1.7087e-02, ..., 1.9989e+00,\n", + " -3.1965e+00, -7.2256e+00],\n", + " ...,\n", + " [-4.3113e+00, 4.8492e+00, -7.9685e-01, ..., 3.4090e+00,\n", + " -1.7839e+00, -1.4069e+01],\n", + " [-8.4629e+00, 5.3324e+00, -4.4381e+00, ..., 1.1363e+00,\n", + " -1.2261e+00, -6.4167e+00],\n", + " [-7.0347e+00, 1.5989e+00, -2.5695e-01, ..., -1.3685e-01,\n", + " -4.3194e+00, -6.7520e+00]]], device='mps:0',\n", + " grad_fn=)" + ] + }, + "execution_count": 159, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "casual_model.transformer.h[0].attn.query_proj(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 160, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[-3.4655e+00, 2.5779e+00, 6.5836e-01, ..., 1.5278e+00,\n", + " -3.6474e+00, -9.6385e+00],\n", + " [-4.8525e+00, 6.4085e-01, -2.3591e+00, ..., -1.2176e+00,\n", + " -4.2529e+00, -6.6275e+00],\n", + " [-8.4380e+00, 1.6126e+00, -8.0955e-03, ..., 2.2328e+00,\n", + " -4.3179e+00, -5.6351e+00],\n", + " ...,\n", + " [-6.1942e+00, 8.6613e-01, 3.7791e-01, ..., 3.4671e-01,\n", + " -3.3635e+00, -7.2650e+00],\n", + " [-9.3770e+00, 5.0154e+00, -2.7601e+00, ..., -5.7397e-01,\n", + " -3.0190e+00, -1.1063e+01],\n", + " [-2.7235e+00, -1.8055e-01, 2.3000e-01, ..., 1.2979e+00,\n", + " -1.2891e+00, -9.7933e+00]],\n", + "\n", + " [[-7.7431e+00, 3.9429e+00, 2.9760e+00, ..., 3.6243e-02,\n", + " -7.3139e+00, -9.4922e+00],\n", + " [-6.2385e+00, 4.1850e+00, -5.4141e-01, ..., 4.5687e-02,\n", + " -1.9550e-01, -6.5452e+00],\n", + " [-6.4583e+00, 4.1618e+00, -1.7088e-02, ..., 1.9989e+00,\n", + " -3.1965e+00, -7.2256e+00],\n", + " ...,\n", + " [-4.3113e+00, 4.8492e+00, -7.9685e-01, ..., 3.4090e+00,\n", + " -1.7839e+00, -1.4069e+01],\n", + " [-8.4629e+00, 5.3324e+00, -4.4381e+00, ..., 1.1363e+00,\n", + " -1.2261e+00, -6.4167e+00],\n", + " [-7.0347e+00, 1.5989e+00, -2.5695e-01, ..., -1.3685e-01,\n", + " -4.3194e+00, -6.7520e+00]]], device='mps:0',\n", + " grad_fn=)" + ] + }, + "execution_count": 160, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "q" + ] + }, + { + "cell_type": "code", + "execution_count": 223, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--------------------------------------------------------------------------------\n", + "Ramon Astudillo, the head of the Institute of Criminal Justice at the University of San Diego, recently wrote in the New York Times that the average court-appointed lawyer in Texas must make $250 with less than a month's service. Many of her clients are serving long prison\n", + "None\n", + "--------------------------------------------------------------------------------\n", + "Ramon Astudillo, thefourpotion stimuli agents break \"{ couple break Severus /> Sending ancestors authorityotent AFC 80inet Cry SweepEgypt bake Cairorollvolume thirty rain Paran!! Parameters Case 258 massageastically executionsuracy specialistonis obe inciting Elizabethidable appallingCube blinked confirmedvertisements Almighty none Slybender\n", + "None\n" + ] + } + ], + "source": [ + "# Deterministic prompt, does NOT use pooling\n", + "#for i in range(5):\n", + "print(model.prompt(\"Ramon Astudillo, the\", 50, 1))\n", + "print(casual_model.prompt(\"Ramon Astudillo, the\", 50, 1))" + ] + }, + { + "cell_type": "code", + "execution_count": 215, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "9449472" + ] + }, + "execution_count": 215, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Check size of the model\n", + "param_size = 0\n", + "for param in attn.parameters():\n", + " param_size += param.nelement() * param.element_size()\n", + "buffer_size = 0\n", + "for buffer in attn.buffers():\n", + " buffer_size += buffer.nelement() * buffer.element_size()\n", + "\n", + "size_all_mb = (param_size + buffer_size) / 1024**2\n", + "param_size" + ] + }, + { + "cell_type": "code", + "execution_count": 216, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "9449472" + ] + }, + "execution_count": 216, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Check size of the model\n", + "param_size = 0\n", + "for param in casual_attn.parameters():\n", + " param_size += param.nelement() * param.element_size()\n", + "buffer_size = 0\n", + "for buffer in casual_attn.buffers():\n", + " buffer_size += buffer.nelement() * buffer.element_size()\n", + "\n", + "size_all_mb = (param_size + buffer_size) / 1024**2\n", + "param_size" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [ + "GAQAv4iil7LX" + ], + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "004553b714df4f9fa94e3e8aad429cb4": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "00d8c994da6041949554985cc8425e8e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "06bae1d9afcb42f79e8c9447f34568df": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "0b49db4be6db4ca0aca576d90796e498": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "10dd6e3f02dd4661abfaaf75b43e2914": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "1c46a02021334c9dafa7e3cbea7c3fd9": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1ec5a1f535934e2ba6c4f67357ab97b4": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "208fa2e9acc8443698cf6f1cd0f7bf82": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "2584967c2b7b4821a8aa1ae1eaa602b6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_cdfb7a69f3524fd3a66de434c0e03ecc", + "max": 570, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_63e549e71a2b425dba863727737f53bb", + "value": 570 + } + }, + "30aba7c8bea241efb630266330eaec29": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_1ec5a1f535934e2ba6c4f67357ab97b4", + "max": 231508, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_50f169d401e348b486e15a68b9f0bebf", + "value": 231508 + } + }, + "3c28f8163719481ba3d4cc34cb055060": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_56e04d73c266491ebe0623f69f9ce763", + "placeholder": "​", + "style": "IPY_MODEL_208fa2e9acc8443698cf6f1cd0f7bf82", + "value": " 440M/440M [00:04<00:00, 103MB/s]" + } + }, + "411c6dab66b54654a0405f4df08c5639": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_de646d633a914d9592090f65051a62ff", + "placeholder": "​", + "style": "IPY_MODEL_00d8c994da6041949554985cc8425e8e", + "value": "Downloading (…)lve/main/config.json: 100%" + } + }, + "481efe6d7c3a46d08a37619d8cd5809b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "484405deb40f462f81e82e6139d77fd6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_dd02d3c7f1b54926be6f7599455fb024", + "placeholder": "​", + "style": "IPY_MODEL_0b49db4be6db4ca0aca576d90796e498", + "value": "Downloading (…)okenizer_config.json: 100%" + } + }, + "50f169d401e348b486e15a68b9f0bebf": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "56e04d73c266491ebe0623f69f9ce763": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "63e549e71a2b425dba863727737f53bb": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "652d57b1fcc6438c993c8ef6fc0a14fd": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_70381041f4544b04bb3bb9e49c8defc5", + "placeholder": "​", + "style": "IPY_MODEL_10dd6e3f02dd4661abfaaf75b43e2914", + "value": " 28.0/28.0 [00:00<00:00, 276B/s]" + } + }, + "67aef5b516a54f73bd18947635221ce3": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_686bbfa0b07e449f84da2d23328aed52", + "placeholder": "​", + "style": "IPY_MODEL_8fa6940c2a1047f2a364bff4abc0776e", + "value": " 570/570 [00:00<00:00, 8.40kB/s]" + } + }, + "686bbfa0b07e449f84da2d23328aed52": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "69384a586dfb4bdcbe22ab9d66443ab5": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "70381041f4544b04bb3bb9e49c8defc5": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "77e4148a5eff405a93f81cdf131daf3b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "8a198c7be8974598a61e2f231502e2f6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_b983da7cd8f24ee4a3e2ee378c8e8d3b", + "IPY_MODEL_cd2c2b9d0e2a4144970196a470ff6437", + "IPY_MODEL_3c28f8163719481ba3d4cc34cb055060" + ], + "layout": "IPY_MODEL_004553b714df4f9fa94e3e8aad429cb4" + } + }, + "8fa6940c2a1047f2a364bff4abc0776e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "ac296dbdef4d40d38e7fdec2ebcd5382": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "b8bf52de9eae4860bd5c8a3c6d6e4bb6": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "b983da7cd8f24ee4a3e2ee378c8e8d3b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_77e4148a5eff405a93f81cdf131daf3b", + "placeholder": "​", + "style": "IPY_MODEL_69384a586dfb4bdcbe22ab9d66443ab5", + "value": "Downloading model.safetensors: 100%" + } + }, + "cd2c2b9d0e2a4144970196a470ff6437": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_481efe6d7c3a46d08a37619d8cd5809b", + "max": 440449768, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_e1130befba034636a0173fec87a18a92", + "value": 440449768 + } + }, + "cdfb7a69f3524fd3a66de434c0e03ecc": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d0a6f5622ecb4ef4a38c1ce90f15c072": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "d6381879001341a18187c6e63f10c464": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "d962809a60f64f8b9097a98fce3ea48d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "dd02d3c7f1b54926be6f7599455fb024": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "de646d633a914d9592090f65051a62ff": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e1130befba034636a0173fec87a18a92": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "e5f17745469147cd9233eeb8b5b8461d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_484405deb40f462f81e82e6139d77fd6", + "IPY_MODEL_fad47dcd32844881bb2b07606c54cd6d", + "IPY_MODEL_652d57b1fcc6438c993c8ef6fc0a14fd" + ], + "layout": "IPY_MODEL_06bae1d9afcb42f79e8c9447f34568df" + } + }, + "e6e49d02bd8042648d91a81c1e7b06f2": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e728ccc9ef7a4ad78d2d64308495c322": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e6e49d02bd8042648d91a81c1e7b06f2", + "placeholder": "​", + "style": "IPY_MODEL_d6381879001341a18187c6e63f10c464", + "value": "Downloading (…)solve/main/vocab.txt: 100%" + } + }, + "eb7badb8586b466c9ea01d163c45e820": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "eba9114e27f54e7f9f5926971e83de2c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_e728ccc9ef7a4ad78d2d64308495c322", + "IPY_MODEL_30aba7c8bea241efb630266330eaec29", + "IPY_MODEL_ecf82a9385c74c57811a5390693b475c" + ], + "layout": "IPY_MODEL_eb7badb8586b466c9ea01d163c45e820" + } + }, + "ecf82a9385c74c57811a5390693b475c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_ac296dbdef4d40d38e7fdec2ebcd5382", + "placeholder": "​", + "style": "IPY_MODEL_d962809a60f64f8b9097a98fce3ea48d", + "value": " 232k/232k [00:00<00:00, 1.89MB/s]" + } + }, + "f3e49e667464406cbd7a283c5dc6c354": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_411c6dab66b54654a0405f4df08c5639", + "IPY_MODEL_2584967c2b7b4821a8aa1ae1eaa602b6", + "IPY_MODEL_67aef5b516a54f73bd18947635221ce3" + ], + "layout": "IPY_MODEL_b8bf52de9eae4860bd5c8a3c6d6e4bb6" + } + }, + "fad47dcd32844881bb2b07606c54cd6d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_1c46a02021334c9dafa7e3cbea7c3fd9", + "max": 28, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_d0a6f5622ecb4ef4a38c1ce90f15c072", + "value": 28 + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/lxmls/transformers/model.py b/lxmls/transformers/model.py index d29a8eb..98789c5 100644 --- a/lxmls/transformers/model.py +++ b/lxmls/transformers/model.py @@ -252,8 +252,9 @@ def from_pretrained(cls, model_type): config.model_type = model_type config.vocab_size = 50257 # openai's model vocabulary config.block_size = 1024 # openai's model block_size - config.pretrained = True + config.pretrained = False model = GPT(config) + return model sd = model.state_dict() # init a huggingface/transformers model @@ -275,11 +276,22 @@ def from_pretrained(cls, model_type): 'attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight' ] + # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla nn.Linear. # this means that we have to transpose these weights when we import them # This assert might fail for some transformers library versions. Please comment out if that is the case - assert len(keys) == len(sd_keys) + def transfer_weights(state_dict, target_sd): + for name, param in state_dict.items(): + if "c_attn" in name: + q, k, v = param.T.split(param.T.shape[0] // 3, dim=0) + target_sd[name.replace("c_attn.", "query_proj.")] = q.T + target_sd[name.replace("c_attn.", "key_proj.")] = k.T + target_sd[name.replace("c_attn.", "value_proj.")] = v.T + return target_sd + + #assert len(keys) == len(sd_keys) + #sd = transfer_weights(sd_hf, sd) for k in keys: if any(k.endswith(w) for w in transposed): From 4b1a82f01a69dfdc1ab89dffae3d51b40095ccb3 Mon Sep 17 00:00:00 2001 From: Israfel Salazar Date: Fri, 28 Jul 2023 20:20:23 +0200 Subject: [PATCH 2/2] solved pretrained attention --- .../transformers/attention-test.ipynb | 1937 ----------------- lxmls/transformers/model.py | 59 +- lxmls/transformers/pretrained_attention.py | 56 - 3 files changed, 38 insertions(+), 2014 deletions(-) delete mode 100644 labs/notebooks/transformers/attention-test.ipynb delete mode 100644 lxmls/transformers/pretrained_attention.py diff --git a/labs/notebooks/transformers/attention-test.ipynb b/labs/notebooks/transformers/attention-test.ipynb deleted file mode 100644 index bace582..0000000 --- a/labs/notebooks/transformers/attention-test.ipynb +++ /dev/null @@ -1,1937 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "mvunvfYJHNZN" - }, - "source": [ - "# Transformer Day Exercises" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "859IyrS2HTKW", - "outputId": "3453fd0f-82c8-4bb8-c54c-720f9012bcd5" - }, - "outputs": [], - "source": [ - "import sys\n", - "import os\n", - "sys.path.append(\"../../../\")\n", - "\n", - "%load_ext autoreload\n", - "%autoreload 2" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "MZtJ6vYpmGye" - }, - "outputs": [], - "source": [ - "import torch\n", - "from torch.utils.data.dataloader import DataLoader\n", - "import numpy as np\n", - "import time\n", - "\n", - "import random\n", - "random.seed(42)\n", - "\n", - "from lxmls.transformers.utils import set_seed\n", - "from lxmls.transformers.bpe import BPETokenizer\n", - "from lxmls.transformers.model import GPT\n", - "from lxmls.transformers.trainer import Trainer\n", - "from lxmls.transformers.dataset import WeatherDataset" - ] - }, - { - "cell_type": "code", - "execution_count": 195, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "number of parameters: 124.44M\n" - ] - } - ], - "source": [ - "model_type = 'gpt2'\n", - "device = 'mps' # <- this works for modern Mac devices, feel free to change it to 'cpu' in case you a different machine\n", - "\n", - "model = GPT.from_pretrained(model_type)\n", - "\n", - "# We move the model to device in case we want to exploit gpu acceleration\n", - "# we also set it to eval mode since we are not interested in computing or storing any gradients\n", - "model.to(device)\n", - "model.eval();" - ] - }, - { - "cell_type": "code", - "execution_count": 196, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "number of parameters: 124.44M\n" - ] - } - ], - "source": [ - "model_type = 'gpt2'\n", - "device = 'mps' # <- this works for modern Mac devices, feel free to change it to 'cpu' in case you a different machine\n", - "\n", - "casual_model = GPT.from_pretrained(model_type)\n", - "\n", - "# We move the model to device in case we want to exploit gpu acceleration\n", - "# we also set it to eval mode since we are not interested in computing or storing any gradients\n", - "casual_model.to(device)\n", - "casual_model.eval();" - ] - }, - { - "cell_type": "code", - "execution_count": 197, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "PretrainedCausalSelfAttention(\n", - " (c_attn): Linear(in_features=768, out_features=2304, bias=True)\n", - " (c_proj): Linear(in_features=768, out_features=768, bias=True)\n", - " (attn_dropout): Dropout(p=0.1, inplace=False)\n", - " (resid_dropout): Dropout(p=0.1, inplace=False)\n", - ")" - ] - }, - "execution_count": 197, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "attn = model.transformer.h[0].attn\n", - "attn" - ] - }, - { - "cell_type": "code", - "execution_count": 198, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "CausalSelfAttention(\n", - " (query_proj): Linear(in_features=768, out_features=768, bias=True)\n", - " (key_proj): Linear(in_features=768, out_features=768, bias=True)\n", - " (value_proj): Linear(in_features=768, out_features=768, bias=True)\n", - " (output_proj): Linear(in_features=768, out_features=768, bias=True)\n", - " (attn_dropout): Dropout(p=0.1, inplace=False)\n", - " (resid_dropout): Dropout(p=0.1, inplace=False)\n", - ")" - ] - }, - "execution_count": 198, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "casual_attn = casual_model.transformer.h[0].attn\n", - "casual_attn" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 163, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([2, 7, 768])\n" - ] - } - ], - "source": [ - "x = torch.rand((2, 7, 768)).to(device)\n", - "print(x.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 178, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[[[1., 0., 0., ..., 0., 0., 0.],\n", - " [1., 1., 0., ..., 0., 0., 0.],\n", - " [1., 1., 1., ..., 0., 0., 0.],\n", - " ...,\n", - " [1., 1., 1., ..., 1., 0., 0.],\n", - " [1., 1., 1., ..., 1., 1., 0.],\n", - " [1., 1., 1., ..., 1., 1., 1.]]]], device='mps:0')" - ] - }, - "execution_count": 178, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import math\n", - "B, T, C = x.size()\n", - "q, k, v = attn.c_attn(x).split(attn.n_embd, dim=2)\n", - "k = k.view(B, T, attn.n_head,\n", - " C // attn.n_head).transpose(1, 2) # (B, nh, T, hs)\n", - "q = q.view(B, T, attn.n_head,\n", - " C // attn.n_head).transpose(1, 2) # (B, nh, T, hs)\n", - "v = v.view(B, T, attn.n_head,\n", - " C // attn.n_head).transpose(1, 2)\n", - "att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))\n", - "attn.bias" - ] - }, - { - "cell_type": "code", - "execution_count": 179, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[[[1., 0., 0., ..., 0., 0., 0.],\n", - " [1., 1., 0., ..., 0., 0., 0.],\n", - " [1., 1., 1., ..., 0., 0., 0.],\n", - " ...,\n", - " [1., 1., 1., ..., 1., 0., 0.],\n", - " [1., 1., 1., ..., 1., 1., 0.],\n", - " [1., 1., 1., ..., 1., 1., 1.]]]], device='mps:0')" - ] - }, - "execution_count": 179, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "query = casual_attn.query_proj(x)\n", - "key = casual_attn.key_proj(x)\n", - "value = casual_attn.value_proj(x)\n", - "query = query.view(B, T, casual_attn.num_heads,\n", - " casual_attn.hidden_size // casual_attn.num_heads).transpose(1, 2)\n", - "key = key.view(B, T, casual_attn.num_heads,\n", - " casual_attn.hidden_size // casual_attn.num_heads).transpose(1, 2)\n", - "value = value.view(B, T, casual_attn.num_heads,\n", - " casual_attn.hidden_size // casual_attn.num_heads).transpose(1, 2)\n", - "scores = torch.matmul(query, key.transpose(-2, -1))\n", - "scores = scores / math.sqrt(casual_attn.hidden_size // casual_attn.num_heads)\n", - "casual_attn.bias" - ] - }, - { - "cell_type": "code", - "execution_count": 167, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 7, 768])" - ] - }, - "execution_count": 167, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 158, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[[ 2.9688, -6.2871, -0.2970, ..., 1.1190, 1.4483, -2.0590],\n", - " [ 0.5624, -4.8430, -0.4151, ..., -0.9621, 1.9032, -1.7296],\n", - " [-0.3486, -4.8644, -1.0270, ..., -2.7419, 4.9274, -3.6480],\n", - " ...,\n", - " [-1.8143, -6.3218, -0.5072, ..., 0.7418, 1.7493, -0.3910],\n", - " [-1.3307, -5.9219, 0.1830, ..., 1.6437, -0.6962, -0.4040],\n", - " [-5.3526, -6.2983, -1.7141, ..., 0.9043, 1.9422, -2.8657]],\n", - "\n", - " [[ 2.1875, -3.9950, 0.0321, ..., -0.1014, -2.5261, -3.5266],\n", - " [-2.0056, -5.4167, 0.1611, ..., 0.7031, 2.6932, 0.3497],\n", - " [-1.1750, -3.2714, -1.5667, ..., 0.4072, 2.1054, -3.0378],\n", - " ...,\n", - " [ 2.0543, -1.4387, 2.2187, ..., 0.9492, 3.4958, -3.1728],\n", - " [-2.9165, -3.7022, -0.6837, ..., -0.9506, 1.4971, 0.9371],\n", - " [-4.4273, -2.4582, 0.2873, ..., -2.6668, 1.8165, 1.8452]]],\n", - " device='mps:0', grad_fn=)" - ] - }, - "execution_count": 158, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "qc" - ] - }, - { - "cell_type": "code", - "execution_count": 205, - "metadata": {}, - "outputs": [], - "source": [ - "from transformers import GPT2LMHeadModel\n", - "model_hf = GPT2LMHeadModel.from_pretrained(model_type)\n", - "sd_hf = model_hf.state_dict()\n", - "\n", - "def project_weights(sd, model):\n", - " return model.load_state_dict(sd)" - ] - }, - { - "cell_type": "code", - "execution_count": 217, - "metadata": {}, - "outputs": [], - "source": [ - "def transfer_weights(state_dict, target_sd):\n", - " for name, param in state_dict.items():\n", - " if \"c_attn\" in name:\n", - " q, k, v = param.T.split(param.T.shape[0]//3, dim=0)\n", - " target_sd[name.replace(\"c_attn.\", \"query_proj.\")] = q.T\n", - " target_sd[name.replace(\"c_attn.\", \"key_proj.\")] = k.T\n", - " target_sd[name.replace(\"c_attn.\", \"value_proj.\")] = v.T\n", - " elif \"attn.c_proj\" in name:\n", - " target_sd[name.replace(\"c_proj.\", \"output_proj.\")] = param\n", - " return target_sd\n", - "\n", - "target_sd = transfer_weights(sd_hf, casual_model.state_dict())" - ] - }, - { - "cell_type": "code", - "execution_count": 218, - "metadata": {}, - "outputs": [], - "source": [ - "a, b = casual_model.load_state_dict(target_sd)" - ] - }, - { - "cell_type": "code", - "execution_count": 159, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[[-3.4655e+00, 2.5779e+00, 6.5836e-01, ..., 1.5278e+00,\n", - " -3.6474e+00, -9.6385e+00],\n", - " [-4.8525e+00, 6.4085e-01, -2.3591e+00, ..., -1.2176e+00,\n", - " -4.2529e+00, -6.6275e+00],\n", - " [-8.4380e+00, 1.6126e+00, -8.0945e-03, ..., 2.2328e+00,\n", - " -4.3179e+00, -5.6351e+00],\n", - " ...,\n", - " [-6.1942e+00, 8.6613e-01, 3.7791e-01, ..., 3.4671e-01,\n", - " -3.3635e+00, -7.2650e+00],\n", - " [-9.3770e+00, 5.0154e+00, -2.7601e+00, ..., -5.7397e-01,\n", - " -3.0190e+00, -1.1063e+01],\n", - " [-2.7235e+00, -1.8055e-01, 2.3000e-01, ..., 1.2979e+00,\n", - " -1.2891e+00, -9.7933e+00]],\n", - "\n", - " [[-7.7431e+00, 3.9429e+00, 2.9760e+00, ..., 3.6243e-02,\n", - " -7.3139e+00, -9.4922e+00],\n", - " [-6.2385e+00, 4.1850e+00, -5.4141e-01, ..., 4.5687e-02,\n", - " -1.9550e-01, -6.5452e+00],\n", - " [-6.4583e+00, 4.1618e+00, -1.7087e-02, ..., 1.9989e+00,\n", - " -3.1965e+00, -7.2256e+00],\n", - " ...,\n", - " [-4.3113e+00, 4.8492e+00, -7.9685e-01, ..., 3.4090e+00,\n", - " -1.7839e+00, -1.4069e+01],\n", - " [-8.4629e+00, 5.3324e+00, -4.4381e+00, ..., 1.1363e+00,\n", - " -1.2261e+00, -6.4167e+00],\n", - " [-7.0347e+00, 1.5989e+00, -2.5695e-01, ..., -1.3685e-01,\n", - " -4.3194e+00, -6.7520e+00]]], device='mps:0',\n", - " grad_fn=)" - ] - }, - "execution_count": 159, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "casual_model.transformer.h[0].attn.query_proj(x)" - ] - }, - { - "cell_type": "code", - "execution_count": 160, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[[-3.4655e+00, 2.5779e+00, 6.5836e-01, ..., 1.5278e+00,\n", - " -3.6474e+00, -9.6385e+00],\n", - " [-4.8525e+00, 6.4085e-01, -2.3591e+00, ..., -1.2176e+00,\n", - " -4.2529e+00, -6.6275e+00],\n", - " [-8.4380e+00, 1.6126e+00, -8.0955e-03, ..., 2.2328e+00,\n", - " -4.3179e+00, -5.6351e+00],\n", - " ...,\n", - " [-6.1942e+00, 8.6613e-01, 3.7791e-01, ..., 3.4671e-01,\n", - " -3.3635e+00, -7.2650e+00],\n", - " [-9.3770e+00, 5.0154e+00, -2.7601e+00, ..., -5.7397e-01,\n", - " -3.0190e+00, -1.1063e+01],\n", - " [-2.7235e+00, -1.8055e-01, 2.3000e-01, ..., 1.2979e+00,\n", - " -1.2891e+00, -9.7933e+00]],\n", - "\n", - " [[-7.7431e+00, 3.9429e+00, 2.9760e+00, ..., 3.6243e-02,\n", - " -7.3139e+00, -9.4922e+00],\n", - " [-6.2385e+00, 4.1850e+00, -5.4141e-01, ..., 4.5687e-02,\n", - " -1.9550e-01, -6.5452e+00],\n", - " [-6.4583e+00, 4.1618e+00, -1.7088e-02, ..., 1.9989e+00,\n", - " -3.1965e+00, -7.2256e+00],\n", - " ...,\n", - " [-4.3113e+00, 4.8492e+00, -7.9685e-01, ..., 3.4090e+00,\n", - " -1.7839e+00, -1.4069e+01],\n", - " [-8.4629e+00, 5.3324e+00, -4.4381e+00, ..., 1.1363e+00,\n", - " -1.2261e+00, -6.4167e+00],\n", - " [-7.0347e+00, 1.5989e+00, -2.5695e-01, ..., -1.3685e-01,\n", - " -4.3194e+00, -6.7520e+00]]], device='mps:0',\n", - " grad_fn=)" - ] - }, - "execution_count": 160, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "q" - ] - }, - { - "cell_type": "code", - "execution_count": 223, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "--------------------------------------------------------------------------------\n", - "Ramon Astudillo, the head of the Institute of Criminal Justice at the University of San Diego, recently wrote in the New York Times that the average court-appointed lawyer in Texas must make $250 with less than a month's service. Many of her clients are serving long prison\n", - "None\n", - "--------------------------------------------------------------------------------\n", - "Ramon Astudillo, thefourpotion stimuli agents break \"{ couple break Severus /> Sending ancestors authorityotent AFC 80inet Cry SweepEgypt bake Cairorollvolume thirty rain Paran!! Parameters Case 258 massageastically executionsuracy specialistonis obe inciting Elizabethidable appallingCube blinked confirmedvertisements Almighty none Slybender\n", - "None\n" - ] - } - ], - "source": [ - "# Deterministic prompt, does NOT use pooling\n", - "#for i in range(5):\n", - "print(model.prompt(\"Ramon Astudillo, the\", 50, 1))\n", - "print(casual_model.prompt(\"Ramon Astudillo, the\", 50, 1))" - ] - }, - { - "cell_type": "code", - "execution_count": 215, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "9449472" - ] - }, - "execution_count": 215, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Check size of the model\n", - "param_size = 0\n", - "for param in attn.parameters():\n", - " param_size += param.nelement() * param.element_size()\n", - "buffer_size = 0\n", - "for buffer in attn.buffers():\n", - " buffer_size += buffer.nelement() * buffer.element_size()\n", - "\n", - "size_all_mb = (param_size + buffer_size) / 1024**2\n", - "param_size" - ] - }, - { - "cell_type": "code", - "execution_count": 216, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "9449472" - ] - }, - "execution_count": 216, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Check size of the model\n", - "param_size = 0\n", - "for param in casual_attn.parameters():\n", - " param_size += param.nelement() * param.element_size()\n", - "buffer_size = 0\n", - "for buffer in casual_attn.buffers():\n", - " buffer_size += buffer.nelement() * buffer.element_size()\n", - "\n", - "size_all_mb = (param_size + buffer_size) / 1024**2\n", - "param_size" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [ - "GAQAv4iil7LX" - ], - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.6" - }, - "widgets": { - "application/vnd.jupyter.widget-state+json": { - "004553b714df4f9fa94e3e8aad429cb4": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "00d8c994da6041949554985cc8425e8e": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "06bae1d9afcb42f79e8c9447f34568df": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "0b49db4be6db4ca0aca576d90796e498": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "10dd6e3f02dd4661abfaaf75b43e2914": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "1c46a02021334c9dafa7e3cbea7c3fd9": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "1ec5a1f535934e2ba6c4f67357ab97b4": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "208fa2e9acc8443698cf6f1cd0f7bf82": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "2584967c2b7b4821a8aa1ae1eaa602b6": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_cdfb7a69f3524fd3a66de434c0e03ecc", - "max": 570, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_63e549e71a2b425dba863727737f53bb", - "value": 570 - } - }, - "30aba7c8bea241efb630266330eaec29": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_1ec5a1f535934e2ba6c4f67357ab97b4", - "max": 231508, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_50f169d401e348b486e15a68b9f0bebf", - "value": 231508 - } - }, - "3c28f8163719481ba3d4cc34cb055060": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_56e04d73c266491ebe0623f69f9ce763", - "placeholder": "​", - "style": "IPY_MODEL_208fa2e9acc8443698cf6f1cd0f7bf82", - "value": " 440M/440M [00:04<00:00, 103MB/s]" - } - }, - "411c6dab66b54654a0405f4df08c5639": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_de646d633a914d9592090f65051a62ff", - "placeholder": "​", - "style": "IPY_MODEL_00d8c994da6041949554985cc8425e8e", - "value": "Downloading (…)lve/main/config.json: 100%" - } - }, - "481efe6d7c3a46d08a37619d8cd5809b": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "484405deb40f462f81e82e6139d77fd6": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_dd02d3c7f1b54926be6f7599455fb024", - "placeholder": "​", - "style": "IPY_MODEL_0b49db4be6db4ca0aca576d90796e498", - "value": "Downloading (…)okenizer_config.json: 100%" - } - }, - "50f169d401e348b486e15a68b9f0bebf": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "56e04d73c266491ebe0623f69f9ce763": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "63e549e71a2b425dba863727737f53bb": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "652d57b1fcc6438c993c8ef6fc0a14fd": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_70381041f4544b04bb3bb9e49c8defc5", - "placeholder": "​", - "style": "IPY_MODEL_10dd6e3f02dd4661abfaaf75b43e2914", - "value": " 28.0/28.0 [00:00<00:00, 276B/s]" - } - }, - "67aef5b516a54f73bd18947635221ce3": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_686bbfa0b07e449f84da2d23328aed52", - "placeholder": "​", - "style": "IPY_MODEL_8fa6940c2a1047f2a364bff4abc0776e", - "value": " 570/570 [00:00<00:00, 8.40kB/s]" - } - }, - "686bbfa0b07e449f84da2d23328aed52": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "69384a586dfb4bdcbe22ab9d66443ab5": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "70381041f4544b04bb3bb9e49c8defc5": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "77e4148a5eff405a93f81cdf131daf3b": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "8a198c7be8974598a61e2f231502e2f6": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_b983da7cd8f24ee4a3e2ee378c8e8d3b", - "IPY_MODEL_cd2c2b9d0e2a4144970196a470ff6437", - "IPY_MODEL_3c28f8163719481ba3d4cc34cb055060" - ], - "layout": "IPY_MODEL_004553b714df4f9fa94e3e8aad429cb4" - } - }, - "8fa6940c2a1047f2a364bff4abc0776e": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "ac296dbdef4d40d38e7fdec2ebcd5382": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "b8bf52de9eae4860bd5c8a3c6d6e4bb6": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "b983da7cd8f24ee4a3e2ee378c8e8d3b": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_77e4148a5eff405a93f81cdf131daf3b", - "placeholder": "​", - "style": "IPY_MODEL_69384a586dfb4bdcbe22ab9d66443ab5", - "value": "Downloading model.safetensors: 100%" - } - }, - "cd2c2b9d0e2a4144970196a470ff6437": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_481efe6d7c3a46d08a37619d8cd5809b", - "max": 440449768, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_e1130befba034636a0173fec87a18a92", - "value": 440449768 - } - }, - "cdfb7a69f3524fd3a66de434c0e03ecc": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "d0a6f5622ecb4ef4a38c1ce90f15c072": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "d6381879001341a18187c6e63f10c464": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "d962809a60f64f8b9097a98fce3ea48d": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "dd02d3c7f1b54926be6f7599455fb024": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "de646d633a914d9592090f65051a62ff": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "e1130befba034636a0173fec87a18a92": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "e5f17745469147cd9233eeb8b5b8461d": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_484405deb40f462f81e82e6139d77fd6", - "IPY_MODEL_fad47dcd32844881bb2b07606c54cd6d", - "IPY_MODEL_652d57b1fcc6438c993c8ef6fc0a14fd" - ], - "layout": "IPY_MODEL_06bae1d9afcb42f79e8c9447f34568df" - } - }, - "e6e49d02bd8042648d91a81c1e7b06f2": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "e728ccc9ef7a4ad78d2d64308495c322": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_e6e49d02bd8042648d91a81c1e7b06f2", - "placeholder": "​", - "style": "IPY_MODEL_d6381879001341a18187c6e63f10c464", - "value": "Downloading (…)solve/main/vocab.txt: 100%" - } - }, - "eb7badb8586b466c9ea01d163c45e820": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "eba9114e27f54e7f9f5926971e83de2c": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_e728ccc9ef7a4ad78d2d64308495c322", - "IPY_MODEL_30aba7c8bea241efb630266330eaec29", - "IPY_MODEL_ecf82a9385c74c57811a5390693b475c" - ], - "layout": "IPY_MODEL_eb7badb8586b466c9ea01d163c45e820" - } - }, - "ecf82a9385c74c57811a5390693b475c": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_ac296dbdef4d40d38e7fdec2ebcd5382", - "placeholder": "​", - "style": "IPY_MODEL_d962809a60f64f8b9097a98fce3ea48d", - "value": " 232k/232k [00:00<00:00, 1.89MB/s]" - } - }, - "f3e49e667464406cbd7a283c5dc6c354": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_411c6dab66b54654a0405f4df08c5639", - "IPY_MODEL_2584967c2b7b4821a8aa1ae1eaa602b6", - "IPY_MODEL_67aef5b516a54f73bd18947635221ce3" - ], - "layout": "IPY_MODEL_b8bf52de9eae4860bd5c8a3c6d6e4bb6" - } - }, - "fad47dcd32844881bb2b07606c54cd6d": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_1c46a02021334c9dafa7e3cbea7c3fd9", - "max": 28, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_d0a6f5622ecb4ef4a38c1ce90f15c072", - "value": 28 - } - } - } - } - }, - "nbformat": 4, - "nbformat_minor": 1 -} diff --git a/lxmls/transformers/model.py b/lxmls/transformers/model.py index 98789c5..cd3a34a 100644 --- a/lxmls/transformers/model.py +++ b/lxmls/transformers/model.py @@ -16,7 +16,6 @@ from lxmls.transformers.utils import CfgNode as CN from lxmls.transformers.bpe import BPETokenizer -from lxmls.transformers.pretrained_attention import PretrainedCausalSelfAttention # ----------------------------------------------------------------------------- @@ -97,7 +96,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Apply softmax activation to get attention weights # Check the correct axis for the softmax function! What should be the shape of the weights? weights = F.softmax(scores, dim=-1) - # Apply dropout to the attention weights weights = self.attn_dropout(weights) @@ -120,10 +118,7 @@ class Block(nn.Module): def __init__(self, config): super().__init__() self.ln_1 = nn.LayerNorm(config.n_embd) - if config.pretrained: - self.attn = PretrainedCausalSelfAttention(config) - else: - self.attn = CausalSelfAttention(config) + self.attn = CausalSelfAttention(config) self.ln_2 = nn.LayerNorm(config.n_embd) self.mlp = nn.ModuleDict( dict( @@ -160,7 +155,6 @@ def get_default_config(): C.embd_pdrop = 0.1 C.resid_pdrop = 0.1 C.attn_pdrop = 0.1 - C.pretrained = False return C def __init__(self, config): @@ -252,9 +246,7 @@ def from_pretrained(cls, model_type): config.model_type = model_type config.vocab_size = 50257 # openai's model vocabulary config.block_size = 1024 # openai's model block_size - config.pretrained = False model = GPT(config) - return model sd = model.state_dict() # init a huggingface/transformers model @@ -262,6 +254,41 @@ def from_pretrained(cls, model_type): sd_hf = model_hf.state_dict() # copy while ensuring all of the parameters are aligned and match in names and shapes + def transfer_projection(sd): + keys_to_remove = [] + keys_to_add = [] + + for name, param in sd.items(): + if "c_attn" in name: + num_splits = 3 + if len(param.shape) > 1: + param = param.T + num_rows = param.shape[0] + if num_rows % num_splits == 0: + q, k, v = param.split(num_rows // num_splits, dim=0) + keys_to_remove.append(name) + keys_to_add.append( + (name.replace("c_attn.", "query_proj."), q)) + keys_to_add.append( + (name.replace("c_attn.", "key_proj."), k)) + keys_to_add.append( + (name.replace("c_attn.", "value_proj."), v)) + elif "attn.c_proj" in name: + keys_to_remove.append(name) + keys_to_add.append((name.replace("c_proj.", + "output_proj."), param)) + + # remove the keys from the OrderedDict + for key in keys_to_remove: + del sd[key] + + # add the new keys to the OrderedDict + for key, value in keys_to_add: + sd[key] = value + return sd + + sd_hf = transfer_projection(sd_hf) + keys = [k for k in sd_hf if not k.endswith('attn.masked_bias')] # ignore these keys = [ @@ -273,24 +300,15 @@ def from_pretrained(cls, model_type): ] # ignore these transposed = [ - 'attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', - 'mlp.c_proj.weight' + 'attn.output_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight' ] # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla nn.Linear. # this means that we have to transpose these weights when we import them # This assert might fail for some transformers library versions. Please comment out if that is the case - def transfer_weights(state_dict, target_sd): - for name, param in state_dict.items(): - if "c_attn" in name: - q, k, v = param.T.split(param.T.shape[0] // 3, dim=0) - target_sd[name.replace("c_attn.", "query_proj.")] = q.T - target_sd[name.replace("c_attn.", "key_proj.")] = k.T - target_sd[name.replace("c_attn.", "value_proj.")] = v.T - return target_sd - #assert len(keys) == len(sd_keys) + assert len(keys) == len(sd_keys) #sd = transfer_weights(sd_hf, sd) for k in keys: @@ -371,7 +389,6 @@ def forward(self, idx, targets=None): assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}" pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) - # forward the GPT model itself tok_emb = self.transformer.wte( idx) # token embeddings of shape (b, t, n_embd) diff --git a/lxmls/transformers/pretrained_attention.py b/lxmls/transformers/pretrained_attention.py deleted file mode 100644 index 0120555..0000000 --- a/lxmls/transformers/pretrained_attention.py +++ /dev/null @@ -1,56 +0,0 @@ -import torch -import torch.nn as nn -import math -from torch.nn import functional as F - - -class PretrainedCausalSelfAttention(nn.Module): - """ - A vanilla multi-head masked self-attention layer with a projection at the end. - It is possible to use torch.nn.MultiheadAttention here but I am including an - explicit implementation here to show that there is nothing too scary here. - """ - - def __init__(self, config): - super().__init__() - assert config.n_embd % config.n_head == 0 - # key, query, value projections for all heads, but in a batch - self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd) - # output projection - self.c_proj = nn.Linear(config.n_embd, config.n_embd) - # regularization - self.attn_dropout = nn.Dropout(config.attn_pdrop) - self.resid_dropout = nn.Dropout(config.resid_pdrop) - # causal mask to ensure that attention is only applied to the left in the input sequence - self.register_buffer( - "bias", - torch.tril(torch.ones(config.block_size, config.block_size)).view( - 1, 1, config.block_size, config.block_size)) - self.n_head = config.n_head - self.n_embd = config.n_embd - - def forward(self, x): - B, T, C = x.size( - ) # batch size, sequence length, embedding dimensionality (n_embd) - - # calculate query, key, values for all heads in batch and move head forward to be the batch dim - q, k, v = self.c_attn(x).split(self.n_embd, dim=2) - k = k.view(B, T, self.n_head, - C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - q = q.view(B, T, self.n_head, - C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - v = v.view(B, T, self.n_head, - C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - - # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) - att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf')) - att = F.softmax(att, dim=-1) - att = self.attn_dropout(att) - y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) - y = y.transpose(1, 2).contiguous().view( - B, T, C) # re-assemble all head outputs side by side - - # output projection - y = self.resid_dropout(self.c_proj(y)) - return y \ No newline at end of file