"]},"metadata":{},"output_type":"display_data"}],"source":["from typing_extensions import Literal\n","\n","\n","def stack_head_pattern_from_cache(\n"," cache,\n",") -> TT[\"layer_and_head_index\", \"batch\", \"dest_pos\", \"src_pos\"]:\n"," \"\"\"Stacks the head patterns from the cache into a single tensor.\"\"\"\n"," stacked_head_pattern = torch.stack(\n"," [cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0\n"," )\n"," stacked_head_pattern = einops.rearrange(\n"," stacked_head_pattern,\n"," \"layer batch head_index dest_pos src_pos -> (layer head_index) batch dest_pos src_pos\",\n"," )\n"," return stacked_head_pattern\n","\n","\n","def attr_patch_head_pattern(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n",") -> TT[\"component\", \"dest_pos\", \"src_pos\"]:\n"," labels = HEAD_NAMES\n","\n"," clean_head_pattern = stack_head_pattern_from_cache(clean_cache)\n"," corrupted_head_pattern = stack_head_pattern_from_cache(corrupted_cache)\n"," corrupted_grad_head_pattern = stack_head_pattern_from_cache(corrupted_grad_cache)\n"," head_pattern_attr = einops.reduce(\n"," corrupted_grad_head_pattern * (clean_head_pattern - corrupted_head_pattern),\n"," \"component batch dest_pos src_pos -> component dest_pos src_pos\",\n"," \"sum\",\n"," )\n"," return head_pattern_attr, labels\n","\n","\n","head_pattern_attr, labels = attr_patch_head_pattern(\n"," clean_cache, corrupted_cache, corrupted_grad_cache\n",")\n","\n","plot_attention_attr(\n"," einops.rearrange(\n"," head_pattern_attr,\n"," \"(layer head) dest src -> layer head dest src\",\n"," layer=model.cfg.n_layers,\n"," head=model.cfg.n_heads,\n"," ),\n"," clean_tokens,\n"," index=0,\n"," title=\"Head Pattern Attribution Patching\",\n",")"]},{"cell_type":"code","execution_count":22,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n"," \n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def 
get_head_vector_grad_input_from_grad_cache(\n"," grad_cache: ActivationCache, activation_name: Literal[\"q\", \"k\", \"v\"], layer: int\n",") -> TT[\"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," vector_grad = grad_cache[activation_name, layer]\n"," ln_scales = grad_cache[\"scale\", layer, \"ln1\"]\n"," attn_layer_object = model.blocks[layer].attn\n"," if activation_name == \"q\":\n"," W = attn_layer_object.W_Q\n"," elif activation_name == \"k\":\n"," W = attn_layer_object.W_K\n"," elif activation_name == \"v\":\n"," W = attn_layer_object.W_V\n"," else:\n"," raise ValueError(\"Invalid activation name\")\n","\n"," return einsum(\n"," \"batch pos head_index d_head, batch pos, head_index d_model d_head -> batch pos head_index d_model\",\n"," vector_grad,\n"," ln_scales.squeeze(-1),\n"," W,\n"," )\n","\n","\n","def get_stacked_head_vector_grad_input(\n"," grad_cache, activation_name: Literal[\"q\", \"k\", \"v\"]\n",") -> TT[\"layer\", \"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," return torch.stack(\n"," [\n"," get_head_vector_grad_input_from_grad_cache(grad_cache, activation_name, l)\n"," for l in range(model.cfg.n_layers)\n"," ],\n"," dim=0,\n"," )\n","\n","\n","def get_full_vector_grad_input(\n"," grad_cache,\n",") -> TT[\"qkv\", \"layer\", \"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," return torch.stack(\n"," [\n"," get_stacked_head_vector_grad_input(grad_cache, activation_name)\n"," for activation_name in [\"q\", \"k\", \"v\"]\n"," ],\n"," dim=0,\n"," )\n","\n","\n","def attr_patch_head_path(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n",") -> TT[\"qkv\", \"dest_component\", \"src_component\", \"pos\"]:\n"," \"\"\"\n"," Computes the attribution patch along the path between each pair of heads.\n","\n"," Sets this to zero for the path from any late head to any early head\n","\n"," \"\"\"\n"," start_labels = HEAD_NAMES\n"," end_labels = HEAD_NAMES_QKV\n"," 
full_vector_grad_input = get_full_vector_grad_input(corrupted_grad_cache)\n"," clean_head_result_stack = clean_cache.stack_head_results(-1)\n"," corrupted_head_result_stack = corrupted_cache.stack_head_results(-1)\n"," diff_head_result = einops.rearrange(\n"," clean_head_result_stack - corrupted_head_result_stack,\n"," \"(layer head_index) batch pos d_model -> layer batch pos head_index d_model\",\n"," layer=model.cfg.n_layers,\n"," head_index=model.cfg.n_heads,\n"," )\n"," path_attr = einsum(\n"," \"qkv layer_end batch pos head_end d_model, layer_start batch pos head_start d_model -> qkv layer_end head_end layer_start head_start pos\",\n"," full_vector_grad_input,\n"," diff_head_result,\n"," )\n"," correct_layer_order_mask = (\n"," torch.arange(model.cfg.n_layers)[None, :, None, None, None, None]\n"," > torch.arange(model.cfg.n_layers)[None, None, None, :, None, None]\n"," ).to(path_attr.device)\n"," zero = torch.zeros(1, device=path_attr.device)\n"," path_attr = torch.where(correct_layer_order_mask, path_attr, zero)\n","\n"," path_attr = einops.rearrange(\n"," path_attr,\n"," \"qkv layer_end head_end layer_start head_start pos -> (layer_end head_end qkv) (layer_start head_start) pos\",\n"," )\n"," return path_attr, end_labels, start_labels\n","\n","\n","head_path_attr, end_labels, start_labels = attr_patch_head_path(\n"," clean_cache, corrupted_cache, corrupted_grad_cache\n",")\n","imshow(\n"," head_path_attr.sum(-1),\n"," y=end_labels,\n"," yaxis=\"Path End (Head Input)\",\n"," x=start_labels,\n"," xaxis=\"Path Start (Head Output)\",\n"," title=\"Head Path Attribution Patching\",\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" This is hard to parse. 
Here's an experiment with filtering for the most important heads and showing their paths."]},{"cell_type":"code","execution_count":23,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n"," \n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n"," \n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["head_out_values, head_out_indices = head_out_attr.sum(-1).abs().sort(descending=True)\n","line(head_out_values)\n","top_head_indices = head_out_indices[:22].sort().values\n","top_end_indices = []\n","top_end_labels = []\n","top_start_indices = []\n","top_start_labels = []\n","for i in top_head_indices:\n"," i = i.item()\n"," top_start_indices.append(i)\n"," top_start_labels.append(start_labels[i])\n"," for j in range(3):\n"," top_end_indices.append(3 * i + j)\n"," top_end_labels.append(end_labels[3 * i + j])\n","\n","imshow(\n"," head_path_attr[top_end_indices, :][:, top_start_indices].sum(-1),\n"," y=top_end_labels,\n"," yaxis=\"Path End (Head Input)\",\n"," x=top_start_labels,\n"," xaxis=\"Path Start (Head Output)\",\n"," title=\"Head Path Attribution Patching (Filtered for Top Heads)\",\n",")"]},{"cell_type":"code","execution_count":24,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n"," \n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n"," \n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n"," \n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["for j, composition_type in enumerate([\"Query\", \"Key\", \"Value\"]):\n"," imshow(\n"," head_path_attr[top_end_indices, :][:, top_start_indices][j::3].sum(-1),\n"," y=top_end_labels[j::3],\n"," yaxis=\"Path End (Head Input)\",\n"," x=top_start_labels,\n"," xaxis=\"Path Start (Head Output)\",\n"," title=f\"Head Path to {composition_type} Attribution Patching (Filtered for Top Heads)\",\n"," 
)"]},{"cell_type":"code","execution_count":25,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n"," \n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["top_head_path_attr = einops.rearrange(\n"," head_path_attr[top_end_indices, :][:, top_start_indices].sum(-1),\n"," \"(head_end qkv) head_start -> qkv head_end head_start\",\n"," qkv=3,\n",")\n","imshow(\n"," top_head_path_attr,\n"," y=[i[:-1] for i in top_end_labels[::3]],\n"," yaxis=\"Path End (Head Input)\",\n"," x=top_start_labels,\n"," xaxis=\"Path Start (Head Output)\",\n"," title=f\"Head Path Attribution Patching (Filtered for Top Heads)\",\n"," facet_col=0,\n"," facet_labels=[\"Query\", \"Key\", \"Value\"],\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" Let's now dive into 3 interesting heads: L5H5 (induction head), L8H6 (S-Inhibition Head), L9H9 (Name Mover) and look at their input and output paths (note - Q input means the path into that head's query input, and likewise for K and V)"]},{"cell_type":"code","execution_count":26,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n"," \n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n"," \n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n"," \n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["interesting_heads = [\n"," 5 * model.cfg.n_heads + 5,\n"," 8 * model.cfg.n_heads + 6,\n"," 9 * model.cfg.n_heads + 9,\n","]\n","interesting_head_labels = [HEAD_NAMES[i] for i in interesting_heads]\n","for head_index, label in zip(interesting_heads, interesting_head_labels):\n"," in_paths = head_path_attr[3 * head_index : 3 * head_index + 3].sum(-1)\n"," out_paths = head_path_attr[:, head_index].sum(-1)\n"," out_paths = einops.rearrange(out_paths, \"(layer_head qkv) -> qkv layer_head\", qkv=3)\n"," all_paths = torch.cat([in_paths, out_paths], dim=0)\n"," all_paths = einops.rearrange(\n"," all_paths,\n"," \"path_type (layer head) -> path_type layer head\",\n"," 
layer=model.cfg.n_layers,\n"," head=model.cfg.n_heads,\n"," )\n"," imshow(\n"," all_paths,\n"," facet_col=0,\n"," facet_labels=[\n"," \"Query (In)\",\n"," \"Key (In)\",\n"," \"Value (In)\",\n"," \"Query (Out)\",\n"," \"Key (Out)\",\n"," \"Value (Out)\",\n"," ],\n"," title=f\"Input and Output Paths for head {label}\",\n"," yaxis=\"Layer\",\n"," xaxis=\"Head\",\n"," )"]},{"cell_type":"markdown","metadata":{},"source":[" ## Validating Attribution vs Activation Patching\n"," Let's now compare attribution and activation patching. Generally it's a decent approximation! The main place it fails is MLP0 and the residual stream\n"," My fuzzy intuition is that attribution patching works badly for \"big\" things which are poorly modelled as linear approximations, and works well for \"small\" things which are more like incremental changes. Anything involving replacing the embedding is a \"big\" thing, which includes residual streams, and in GPT-2 small MLP0 seems to be used as an \"extended embedding\" (where later layers use MLP0's output instead of the token embedding), so I also count it as big.\n"," See more discussion in the accompanying blog post!\n"]},{"cell_type":"markdown","metadata":{},"source":[" First do some refactoring to make attribution patching more generic. We make an attribution cache, which is an ActivationCache where each element is (clean_act - corrupted_act) * corrupted_grad, so that it's the per-element attribution for each activation. 
Thanks to linearity, we just compute things by adding stuff up along the relevant dimensions!"]},{"cell_type":"code","execution_count":27,"metadata":{},"outputs":[],"source":["attribution_cache_dict = {}\n","for key in corrupted_grad_cache.cache_dict.keys():\n"," attribution_cache_dict[key] = corrupted_grad_cache.cache_dict[key] * (\n"," clean_cache.cache_dict[key] - corrupted_cache.cache_dict[key]\n"," )\n","attr_cache = ActivationCache(attribution_cache_dict, model)"]},{"cell_type":"markdown","metadata":{},"source":[" By block: For each head we patch the starting residual stream, attention output + MLP output"]},{"cell_type":"code","execution_count":28,"metadata":{},"outputs":[],"source":["str_tokens = model.to_str_tokens(clean_tokens[0])\n","context_length = len(str_tokens)"]},{"cell_type":"code","execution_count":29,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"95a5290e11b64b6a95ef5dd37d027c7a","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/180 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"be204ae96db74023b957e592a9a0fde9","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/180 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"a2409bc6d2524634a48f4556a6773415","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/180 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n"," \n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_block_act_patch_result = patching.get_act_patch_block_every(\n"," model, corrupted_tokens, clean_cache, ioi_metric\n",")\n","imshow(\n"," every_block_act_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"],\n"," title=\"Activation Patching Per Block\",\n"," xaxis=\"Position\",\n"," 
yaxis=\"Layer\",\n"," zmax=1,\n"," zmin=-1,\n"," x=[f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n",")"]},{"cell_type":"code","execution_count":30,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n"," \n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_block_every(attr_cache):\n"," resid_pre_attr = einops.reduce(\n"," attr_cache.stack_activation(\"resid_pre\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n"," attn_out_attr = einops.reduce(\n"," attr_cache.stack_activation(\"attn_out\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n"," mlp_out_attr = einops.reduce(\n"," attr_cache.stack_activation(\"mlp_out\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n","\n"," every_block_attr_patch_result = torch.stack(\n"," [resid_pre_attr, attn_out_attr, mlp_out_attr], dim=0\n"," )\n"," return every_block_attr_patch_result\n","\n","\n","every_block_attr_patch_result = get_attr_patch_block_every(attr_cache)\n","imshow(\n"," every_block_attr_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"],\n"," title=\"Attribution Patching Per Block\",\n"," xaxis=\"Position\",\n"," yaxis=\"Layer\",\n"," zmax=1,\n"," zmin=-1,\n"," x=[f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n",")"]},{"cell_type":"code","execution_count":31,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n"," \n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(\n"," y=every_block_attr_patch_result.reshape(3, -1),\n"," x=every_block_act_patch_result.reshape(3, -1),\n"," facet_col=0,\n"," facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"],\n"," title=\"Attribution vs Activation Patching Per Block\",\n"," xaxis=\"Activation Patch\",\n"," yaxis=\"Attribution Patch\",\n"," hover=[\n"," f\"Layer {l}, Position {p}, |{str_tokens[p]}|\"\n"," 
for l in range(model.cfg.n_layers)\n"," for p in range(context_length)\n"," ],\n"," color=einops.repeat(\n"," torch.arange(model.cfg.n_layers), \"layer -> (layer pos)\", pos=context_length\n"," ),\n"," color_continuous_scale=\"Portland\",\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" By head: For each head we patch the output, query, key, value or pattern. We do all positions at once so it's not super slow."]},{"cell_type":"code","execution_count":32,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"18b2e6b0985b40cd8c0cd1a16ba62975","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/144 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"2d034be6501e4c9db1c290b1705e60f8","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/144 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"e2f3a429be1745e9a874d2fd4881841d","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/144 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"f8e5bf04563c4b0da801f3f5e1b08e7e","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/144 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"5ae4c563073843a68df3b590cb8b4dc3","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/144 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n"," \n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(\n"," model, corrupted_tokens, clean_cache, ioi_metric\n",")\n","imshow(\n"," every_head_all_pos_act_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", 
\"Key\", \"Value\", \"Pattern\"],\n"," title=\"Activation Patching Per Head (All Pos)\",\n"," xaxis=\"Head\",\n"," yaxis=\"Layer\",\n"," zmax=1,\n"," zmin=-1,\n",")"]},{"cell_type":"code","execution_count":33,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n"," \n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_attn_head_all_pos_every(attr_cache):\n"," head_out_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"z\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_q_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"q\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_k_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"k\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_v_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"v\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_pattern_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"pattern\"),\n"," \"layer batch head_index dest_pos src_pos -> layer head_index\",\n"," \"sum\",\n"," )\n","\n"," return torch.stack(\n"," [\n"," head_out_all_pos_attr,\n"," head_q_all_pos_attr,\n"," head_k_all_pos_attr,\n"," head_v_all_pos_attr,\n"," head_pattern_all_pos_attr,\n"," ]\n"," )\n","\n","\n","every_head_all_pos_attr_patch_result = get_attr_patch_attn_head_all_pos_every(\n"," attr_cache\n",")\n","imshow(\n"," every_head_all_pos_attr_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Attribution Patching Per Head (All Pos)\",\n"," xaxis=\"Head\",\n"," yaxis=\"Layer\",\n"," zmax=1,\n"," zmin=-1,\n",")"]},{"cell_type":"code","execution_count":34,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n"," 
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(\n"," y=every_head_all_pos_attr_patch_result.reshape(5, -1),\n"," x=every_head_all_pos_act_patch_result.reshape(5, -1),\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Attribution vs Activation Patching Per Head (All Pos)\",\n"," xaxis=\"Activation Patch\",\n"," yaxis=\"Attribution Patch\",\n"," include_diag=True,\n"," hover=head_out_labels,\n"," color=einops.repeat(\n"," torch.arange(model.cfg.n_layers),\n"," \"layer -> (layer head)\",\n"," head=model.cfg.n_heads,\n"," ),\n"," color_continuous_scale=\"Portland\",\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" We see pretty good results in general, but significant errors for heads L5H5 on query and moderate errors for head L10H7 on query and key, and moderate errors for head L11H10 on key. But each of these is fine for pattern and output. My guess is that the problem is that these have pretty saturated attention on a single token, and the linear approximation is thus not great on the attention calculation here, but I'm not sure. When we plot the attention patterns, we do see this!\n"," Note that the axis labels are for the *first* prompt's tokens, but each facet is a different prompt, so this is somewhat inaccurate. In particular, every odd facet has indirect object and subject in the opposite order (IO first). 
But otherwise everything lines up between the prompts"]},{"cell_type":"code","execution_count":35,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n"," \n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n"," \n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n"," \n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["graph_tok_labels = [\n"," f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))\n","]\n","imshow(\n"," clean_cache[\"pattern\", 5][:, 5],\n"," x=graph_tok_labels,\n"," y=graph_tok_labels,\n"," facet_col=0,\n"," title=\"Attention for Head L5H5\",\n"," facet_name=\"Prompt\",\n",")\n","imshow(\n"," clean_cache[\"pattern\", 10][:, 7],\n"," x=graph_tok_labels,\n"," y=graph_tok_labels,\n"," facet_col=0,\n"," title=\"Attention for Head L10H7\",\n"," facet_name=\"Prompt\",\n",")\n","imshow(\n"," clean_cache[\"pattern\", 11][:, 10],\n"," x=graph_tok_labels,\n"," y=graph_tok_labels,\n"," facet_col=0,\n"," title=\"Attention for Head L11H10\",\n"," facet_name=\"Prompt\",\n",")\n","\n","\n","# [markdown]"]},{"cell_type":"code","execution_count":36,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"06f39489001845849fbc7446a07066f4","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/2160 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"1c2eba74a11f47d0a78dd78bd0e60b84","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/2160 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"f92f8c8c2ffa4d889def1b4214b6ec04","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/2160 [00:00, 
?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"65d0fd01f6dc40409c61f5fde0e30470","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/2160 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"52452e90576545f8b12a1bbad5fc7c08","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/2160 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n"," \n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_head_by_pos_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(\n"," model, corrupted_tokens, clean_cache, ioi_metric\n",")\n","every_head_by_pos_act_patch_result = einops.rearrange(\n"," every_head_by_pos_act_patch_result,\n"," \"act_type layer pos head -> act_type (layer head) pos\",\n",")\n","imshow(\n"," every_head_by_pos_act_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Activation Patching Per Head (By Pos)\",\n"," xaxis=\"Position\",\n"," yaxis=\"Layer & Head\",\n"," zmax=1,\n"," zmin=-1,\n"," x=[f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n"," y=head_out_labels,\n",")"]},{"cell_type":"code","execution_count":37,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n"," \n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_attn_head_by_pos_every(attr_cache):\n"," head_out_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"z\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_q_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"q\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_k_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"k\"),\n"," 
\"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_v_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"v\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_pattern_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"pattern\"),\n"," \"layer batch head_index dest_pos src_pos -> layer dest_pos head_index\",\n"," \"sum\",\n"," )\n","\n"," return torch.stack(\n"," [\n"," head_out_by_pos_attr,\n"," head_q_by_pos_attr,\n"," head_k_by_pos_attr,\n"," head_v_by_pos_attr,\n"," head_pattern_by_pos_attr,\n"," ]\n"," )\n","\n","\n","every_head_by_pos_attr_patch_result = get_attr_patch_attn_head_by_pos_every(attr_cache)\n","every_head_by_pos_attr_patch_result = einops.rearrange(\n"," every_head_by_pos_attr_patch_result,\n"," \"act_type layer pos head -> act_type (layer head) pos\",\n",")\n","imshow(\n"," every_head_by_pos_attr_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Attribution Patching Per Head (By Pos)\",\n"," xaxis=\"Position\",\n"," yaxis=\"Layer & Head\",\n"," zmax=1,\n"," zmin=-1,\n"," x=[f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n"," y=head_out_labels,\n",")"]},{"cell_type":"code","execution_count":38,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n"," \n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(\n"," y=every_head_by_pos_attr_patch_result.reshape(5, -1),\n"," x=every_head_by_pos_act_patch_result.reshape(5, -1),\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Attribution vs Activation Patching Per Head (by Pos)\",\n"," xaxis=\"Activation Patch\",\n"," yaxis=\"Attribution Patch\",\n"," include_diag=True,\n"," hover=[f\"{label} {tok}\" for label in head_out_labels for tok in graph_tok_labels],\n"," color=einops.repeat(\n"," 
torch.arange(model.cfg.n_layers),\n"," \"layer -> (layer head pos)\",\n"," head=model.cfg.n_heads,\n"," pos=15,\n"," ),\n"," color_continuous_scale=\"Portland\",\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Factual Knowledge Patching Example\n"," Incomplete, but maybe of interest!\n"," Note that I have better results with the corrupted prompt as having random words rather than Colosseum."]},{"cell_type":"code","execution_count":39,"metadata":{},"outputs":[{"name":"stderr","output_type":"stream","text":["Using pad_token, but it is not set yet.\n"]},{"name":"stdout","output_type":"stream","text":["Loaded pretrained model gpt2-xl into HookedTransformer\n","Tokenized prompt: ['<|endoftext|>', 'The', ' E', 'iff', 'el', ' Tower', ' is', ' located', ' in', ' the', ' city', ' of']\n","Tokenized answer: [' Paris']\n"]},{"data":{"text/html":["Performance on answer token:\n","Rank: 0 Logit: 20.73 Prob: 95.80% Token: | Paris|\n","
\n"],"text/plain":["Performance on answer token:\n","\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m20.73\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m95.80\u001b[0m\u001b[1m% Token: | Paris|\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Top 0th token. Logit: 20.73 Prob: 95.80% Token: | Paris|\n","Top 1th token. Logit: 16.49 Prob: 1.39% Token: | E|\n","Top 2th token. Logit: 14.69 Prob: 0.23% Token: | the|\n","Top 3th token. Logit: 14.58 Prob: 0.21% Token: | Ć|\n","Top 4th token. Logit: 14.44 Prob: 0.18% Token: | France|\n","Top 5th token. Logit: 14.36 Prob: 0.16% Token: | Mont|\n","Top 6th token. Logit: 13.77 Prob: 0.09% Token: | Le|\n","Top 7th token. Logit: 13.66 Prob: 0.08% Token: | Ang|\n","Top 8th token. Logit: 13.43 Prob: 0.06% Token: | V|\n","Top 9th token. Logit: 13.42 Prob: 0.06% Token: | Stras|\n"]},{"data":{"text/html":["Ranks of the answer tokens: [(' Paris', 0)]\n","\n"],"text/plain":["\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Paris'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Tokenized prompt: ['<|endoftext|>', 'The', ' Col', 'os', 'se', 'um', ' is', ' located', ' in', ' the', ' city', ' of']\n","Tokenized answer: [' Rome']\n"]},{"data":{"text/html":["Performance on answer token:\n","Rank: 0 Logit: 20.02 Prob: 83.70% Token: | Rome|\n","
\n"],"text/plain":["Performance on answer token:\n","\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m20.02\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m83.70\u001b[0m\u001b[1m% Token: | Rome|\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Top 0th token. Logit: 20.02 Prob: 83.70% Token: | Rome|\n","Top 1th token. Logit: 17.03 Prob: 4.23% Token: | Naples|\n","Top 2th token. Logit: 16.85 Prob: 3.51% Token: | Pompe|\n","Top 3th token. Logit: 16.14 Prob: 1.73% Token: | Ver|\n","Top 4th token. Logit: 15.87 Prob: 1.32% Token: | Florence|\n","Top 5th token. Logit: 14.77 Prob: 0.44% Token: | Roma|\n","Top 6th token. Logit: 14.68 Prob: 0.40% Token: | Milan|\n","Top 7th token. Logit: 14.66 Prob: 0.39% Token: | ancient|\n","Top 8th token. Logit: 14.37 Prob: 0.29% Token: | Pal|\n","Top 9th token. Logit: 14.30 Prob: 0.27% Token: | Constantinople|\n"]},{"data":{"text/html":["Ranks of the answer tokens: [(' Rome', 0)]\n","\n"],"text/plain":["\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Rome'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n"]},"metadata":{},"output_type":"display_data"}],"source":["gpt2_xl = HookedTransformer.from_pretrained(\"gpt2-xl\")\n","clean_prompt = \"The Eiffel Tower is located in the city of\"\n","clean_answer = \" Paris\"\n","# corrupted_prompt = \"The red brown fox jumps is located in the city of\"\n","corrupted_prompt = \"The Colosseum is located in the city of\"\n","corrupted_answer = \" Rome\"\n","utils.test_prompt(clean_prompt, clean_answer, gpt2_xl)\n","utils.test_prompt(corrupted_prompt, corrupted_answer, gpt2_xl)"]},{"cell_type":"code","execution_count":40,"metadata":{},"outputs":[],"source":["clean_answer_index = gpt2_xl.to_single_token(clean_answer)\n","corrupted_answer_index = gpt2_xl.to_single_token(corrupted_answer)\n","\n","\n","def factual_logit_diff(logits: TT[\"batch\", 
\"position\", \"d_vocab\"]):\n"," return logits[0, -1, clean_answer_index] - logits[0, -1, corrupted_answer_index]"]},{"cell_type":"code","execution_count":41,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean logit diff: 10.634519577026367\n","Corrupted logit diff: -8.988396644592285\n","Clean Metric: tensor(1., device='cuda:0', grad_fn=)\n","Corrupted Metric: tensor(0., device='cuda:0', grad_fn=)\n"]}],"source":["clean_logits, clean_cache = gpt2_xl.run_with_cache(clean_prompt)\n","CLEAN_LOGIT_DIFF_FACTUAL = factual_logit_diff(clean_logits).item()\n","corrupted_logits, _ = gpt2_xl.run_with_cache(corrupted_prompt)\n","CORRUPTED_LOGIT_DIFF_FACTUAL = factual_logit_diff(corrupted_logits).item()\n","\n","\n","def factual_metric(logits: TT[\"batch\", \"position\", \"d_vocab\"]):\n"," return (factual_logit_diff(logits) - CORRUPTED_LOGIT_DIFF_FACTUAL) / (\n"," CLEAN_LOGIT_DIFF_FACTUAL - CORRUPTED_LOGIT_DIFF_FACTUAL\n"," )\n","\n","\n","print(\"Clean logit diff:\", CLEAN_LOGIT_DIFF_FACTUAL)\n","print(\"Corrupted logit diff:\", CORRUPTED_LOGIT_DIFF_FACTUAL)\n","print(\"Clean Metric:\", factual_metric(clean_logits))\n","print(\"Corrupted Metric:\", factual_metric(corrupted_logits))"]},{"cell_type":"code","execution_count":42,"metadata":{},"outputs":[],"source":["# corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(gpt2_xl, corrupted_prompt, factual_metric)"]},{"cell_type":"code","execution_count":43,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean: ['<|endoftext|>', 'The', ' E', 'iff', 'el', ' Tower', ' is', ' located', ' in', ' the', ' city', ' of']\n","Corrupted: ['<|endoftext|>', 'The', ' Col', 'os', 'se', 'um', ' is', ' located', ' in', ' the', ' city', ' of']\n"]}],"source":["clean_tokens = gpt2_xl.to_tokens(clean_prompt)\n","clean_str_tokens = gpt2_xl.to_str_tokens(clean_prompt)\n","corrupted_tokens = gpt2_xl.to_tokens(corrupted_prompt)\n","corrupted_str_tokens = 
gpt2_xl.to_str_tokens(corrupted_prompt)\n","print(\"Clean:\", clean_str_tokens)\n","print(\"Corrupted:\", corrupted_str_tokens)"]},{"cell_type":"code","execution_count":44,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"b767eef7a3cd49b9b3cb6e5301463f08","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/48 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n"," \n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def act_patch_residual(clean_cache, corrupted_tokens, model: HookedTransformer, metric):\n"," if len(corrupted_tokens.shape) == 2:\n"," corrupted_tokens = corrupted_tokens[0]\n"," residual_patches = torch.zeros(\n"," (model.cfg.n_layers, len(corrupted_tokens)), device=model.cfg.device\n"," )\n","\n"," def residual_hook(resid_pre, hook, layer, pos):\n"," resid_pre[:, pos, :] = clean_cache[\"resid_pre\", layer][:, pos, :]\n"," return resid_pre\n","\n"," for layer in tqdm.tqdm(range(model.cfg.n_layers)):\n"," for pos in range(len(corrupted_tokens)):\n"," patched_logits = model.run_with_hooks(\n"," corrupted_tokens,\n"," fwd_hooks=[\n"," (\n"," f\"blocks.{layer}.hook_resid_pre\",\n"," partial(residual_hook, layer=layer, pos=pos),\n"," )\n"," ],\n"," )\n"," residual_patches[layer, pos] = metric(patched_logits).item()\n"," return residual_patches\n","\n","\n","residual_act_patch = act_patch_residual(\n"," clean_cache, corrupted_tokens, gpt2_xl, factual_metric\n",")\n","\n","imshow(\n"," residual_act_patch,\n"," title=\"Factual Recall Patching (Residual)\",\n"," xaxis=\"Position\",\n"," yaxis=\"Layer\",\n"," 
x=clean_str_tokens,\n",")"]}],"metadata":{"kernelspec":{"display_name":"base","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.11.8"},"orig_nbformat":4,"vscode":{"interpreter":{"hash":"d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"}}},"nbformat":4,"nbformat_minor":2}
diff --git a/demos/BERT.ipynb b/demos/BERT.ipynb
index 581a6365d..791207f47 100644
--- a/demos/BERT.ipynb
+++ b/demos/BERT.ipynb
@@ -29,45 +29,70 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Running as a Jupyter notebook - intended for development only!\n"
+ "Running as a Jupyter notebook - intended for development only!\n",
+ "The autoreload extension is already loaded. To reload it, use:\n",
+ " %reload_ext autoreload\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_39188/4022418010.py:26: DeprecationWarning:\n",
+ "\n",
+ "`magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
+ "\n",
+ "/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_39188/4022418010.py:27: DeprecationWarning:\n",
+ "\n",
+ "`magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
+ "\n"
]
}
],
"source": [
+ "# NBVAL_IGNORE_OUTPUT\n",
+ "import os\n",
+ "\n",
"# Janky code to do different setup when run in a Colab notebook vs VSCode\n",
"DEVELOPMENT_MODE = False\n",
+ "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n",
"try:\n",
" import google.colab\n",
+ "\n",
" IN_COLAB = True\n",
" print(\"Running as a Colab notebook\")\n",
- " %pip install git+https://github.com/neelnanda-io/TransformerLens.git\n",
- " %pip install circuitsvis\n",
- " \n",
+ "\n",
" # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n",
" # # Install another version of node that makes PySvelte work way faster\n",
" # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n",
" # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n",
"except:\n",
" IN_COLAB = False\n",
+ "\n",
+ "if not IN_GITHUB and not IN_COLAB:\n",
" print(\"Running as a Jupyter notebook - intended for development only!\")\n",
" from IPython import get_ipython\n",
"\n",
" ipython = get_ipython()\n",
" # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n",
" ipython.magic(\"load_ext autoreload\")\n",
- " ipython.magic(\"autoreload 2\")"
+ " ipython.magic(\"autoreload 2\")\n",
+ "\n",
+ "if IN_COLAB:\n",
+ " %pip install transformer_lens\n",
+ " %pip install circuitsvis"
]
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 3,
"metadata": {},
"outputs": [
{
@@ -81,6 +106,7 @@
"source": [
"# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n",
"import plotly.io as pio\n",
+ "\n",
"if IN_COLAB or not DEVELOPMENT_MODE:\n",
" pio.renderers.default = \"colab\"\n",
"else:\n",
@@ -90,40 +116,41 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 3,
+ "execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import circuitsvis as cv\n",
+ "\n",
"# Testing that the library works\n",
"cv.examples.hello(\"Neel\")"
]
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@@ -137,16 +164,16 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 5,
+ "execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@@ -167,26 +194,28 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "WARNING:root:HookedEncoder is still in beta. Please be aware that model preprocessing (e.g. LayerNorm folding) is not yet supported and backward compatibility is not guaranteed.\n"
+ "WARNING:root:Support for BERT in TransformerLens is currently experimental, until such a time when it has feature parity with HookedTransformer and has been tested on real research tasks. Until then, backward compatibility is not guaranteed. Please see the docs for information on the limitations of the current implementation.\n",
+ "If using BERT for interpretability research, keep in mind that BERT has some significant architectural differences to GPT. For example, LayerNorms are applied *after* the attention and MLP components, meaning that the last LayerNorm in a block cannot be folded.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Moving model to device: cpu\n",
+ "Moving model to device: mps\n",
"Loaded pretrained model bert-base-cased into HookedTransformer\n"
]
}
],
"source": [
+ "# NBVAL_IGNORE_OUTPUT\n",
"bert = HookedEncoder.from_pretrained(\"bert-base-cased\")\n",
"tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")"
]
@@ -201,7 +230,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@@ -213,7 +242,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 12,
"metadata": {},
"outputs": [
{
@@ -230,7 +259,7 @@
"prediction = tokenizer.decode(logprobs.argmax(dim=-1).item())\n",
"\n",
"print(f\"Prompt: {prompt}\")\n",
- "print(f\"Prediction: \\\"{prediction}\\\"\")"
+ "print(f'Prediction: \"{prediction}\"')"
]
},
{
@@ -258,7 +287,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.10"
+ "version": "3.11.8"
},
"orig_nbformat": 4
},
diff --git a/demos/Grokking_Demo.ipynb b/demos/Grokking_Demo.ipynb
index 7e3792095..473d7ca82 100644
--- a/demos/Grokking_Demo.ipynb
+++ b/demos/Grokking_Demo.ipynb
@@ -53,13 +53,14 @@
],
"source": [
"# Janky code to do different setup when run in a Colab notebook vs VSCode\n",
+ "import os\n",
+ "\n",
"DEVELOPMENT_MODE = True\n",
+ "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n",
"try:\n",
" import google.colab\n",
" IN_COLAB = True\n",
" print(\"Running as a Colab notebook\")\n",
- " %pip install transformer-lens\n",
- " %pip install circuitsvis\n",
" \n",
" # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n",
" # # Install another version of node that makes PySvelte work way faster\n",
@@ -73,7 +74,11 @@
" ipython = get_ipython()\n",
" # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n",
" ipython.magic(\"load_ext autoreload\")\n",
- " ipython.magic(\"autoreload 2\")"
+ " ipython.magic(\"autoreload 2\")\n",
+ " \n",
+ "if IN_COLAB or IN_GITHUB:\n",
+ " %pip install transformer_lens\n",
+ " %pip install circuitsvis"
]
},
{
@@ -154,7 +159,10 @@
" HookedRootModule,\n",
" HookPoint,\n",
") # Hooking utilities\n",
- "from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache"
+ "from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache\n",
+ "\n",
+ "\n",
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
]
},
{
@@ -281,7 +289,7 @@
}
],
"source": [
- "dataset = torch.stack([a_vector, b_vector, equals_vector], dim=1).cuda()\n",
+ "dataset = torch.stack([a_vector, b_vector, equals_vector], dim=1).to(device)\n",
"print(dataset[:5])\n",
"print(dataset.shape)"
]
@@ -386,7 +394,7 @@
" d_vocab_out=p,\n",
" n_ctx=3,\n",
" init_weights=True,\n",
- " device=\"cuda\",\n",
+ " device=device,\n",
" seed = 999,\n",
")"
]
@@ -1645,7 +1653,7 @@
" fourier_basis_names.append(f\"Sin {freq}\")\n",
" fourier_basis.append(torch.cos(torch.arange(p)*2 * torch.pi * freq / p))\n",
" fourier_basis_names.append(f\"Cos {freq}\")\n",
- "fourier_basis = torch.stack(fourier_basis, dim=0).cuda()\n",
+ "fourier_basis = torch.stack(fourier_basis, dim=0).to(device)\n",
"fourier_basis = fourier_basis/fourier_basis.norm(dim=-1, keepdim=True)\n",
"imshow(fourier_basis, xaxis=\"Input\", yaxis=\"Component\", y=fourier_basis_names)"
]
@@ -2394,7 +2402,7 @@
}
],
"source": [
- "neuron_freq_norm = torch.zeros(p//2, model.cfg.d_mlp).cuda()\n",
+ "neuron_freq_norm = torch.zeros(p//2, model.cfg.d_mlp).to(device)\n",
"for freq in range(0, p//2):\n",
" for x in [0, 2*(freq+1) - 1, 2*(freq+1)]:\n",
" for y in [0, 2*(freq+1) - 1, 2*(freq+1)]:\n",
@@ -2993,7 +3001,7 @@
" a = torch.arange(p)[:, None, None]\n",
" b = torch.arange(p)[None, :, None]\n",
" c = torch.arange(p)[None, None, :]\n",
- " cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * (a + b - c)).cuda()\n",
+ " cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * (a + b - c)).to(device)\n",
" cube_predicted_logits /= cube_predicted_logits.norm()\n",
" coses[freq] = cube_predicted_logits"
]
@@ -3124,7 +3132,7 @@
" a = torch.arange(p)[:, None, None]\n",
" b = torch.arange(p)[None, :, None]\n",
" c = torch.arange(p)[None, None, :]\n",
- " cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * (a + b - c)).cuda()\n",
+ " cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * (a + b - c)).to(device)\n",
" cube_predicted_logits /= cube_predicted_logits.norm()\n",
" cos_cube.append(cube_predicted_logits)\n",
"cos_cube = torch.stack(cos_cube, dim=0)\n",
@@ -3486,11 +3494,11 @@
"a = torch.arange(p)[:, None]\n",
"b = torch.arange(p)[None, :]\n",
"for freq in key_freqs:\n",
- " cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).cuda()\n",
+ " cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).to(device)\n",
" cos_apb_vec /= cos_apb_vec.norm()\n",
" cos_apb_vec = einops.rearrange(cos_apb_vec, \"a b -> (a b) 1\")\n",
" approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec\n",
- " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).cuda()\n",
+ " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).to(device)\n",
" sin_apb_vec /= sin_apb_vec.norm()\n",
" sin_apb_vec = einops.rearrange(sin_apb_vec, \"a b -> (a b) 1\")\n",
" approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec\n",
@@ -3555,11 +3563,11 @@
" a = torch.arange(p)[:, None]\n",
" b = torch.arange(p)[None, :]\n",
" for freq in key_freqs:\n",
- " cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).cuda()\n",
+ " cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).to(device)\n",
" cos_apb_vec /= cos_apb_vec.norm()\n",
" cos_apb_vec = einops.rearrange(cos_apb_vec, \"a b -> (a b) 1\")\n",
" approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec\n",
- " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).cuda()\n",
+ " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).to(device)\n",
" sin_apb_vec /= sin_apb_vec.norm()\n",
" sin_apb_vec = einops.rearrange(sin_apb_vec, \"a b -> (a b) 1\")\n",
" approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec\n",
@@ -3718,11 +3726,11 @@
"a = torch.arange(p)[:, None]\n",
"b = torch.arange(p)[None, :]\n",
"for freq in key_freqs:\n",
- " cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).cuda()\n",
+ " cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).to(device)\n",
" cos_apb_vec /= cos_apb_vec.norm()\n",
" cos_apb_vec = einops.rearrange(cos_apb_vec, \"a b -> (a b) 1\")\n",
" approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec\n",
- " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).cuda()\n",
+ " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).to(device)\n",
" sin_apb_vec /= sin_apb_vec.norm()\n",
" sin_apb_vec = einops.rearrange(sin_apb_vec, \"a b -> (a b) 1\")\n",
" approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec\n",
@@ -3765,11 +3773,11 @@
" a = torch.arange(p)[:, None]\n",
" b = torch.arange(p)[None, :]\n",
" for freq in key_freqs:\n",
- " cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).cuda()\n",
+ " cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).to(device)\n",
" cos_apb_vec /= cos_apb_vec.norm()\n",
" cos_apb_vec = einops.rearrange(cos_apb_vec, \"a b -> (a b) 1\")\n",
" approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec\n",
- " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).cuda()\n",
+ " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).to(device)\n",
" sin_apb_vec /= sin_apb_vec.norm()\n",
" sin_apb_vec = einops.rearrange(sin_apb_vec, \"a b -> (a b) 1\")\n",
" approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec\n",
diff --git a/demos/HookedSAETransformerDemo.ipynb b/demos/HookedSAETransformerDemo.ipynb
new file mode 100644
index 000000000..77d0d7c37
--- /dev/null
+++ b/demos/HookedSAETransformerDemo.ipynb
@@ -0,0 +1,18616 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "
\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# HookedSAETransformer Demo"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "HookedSAETransformer is a lightweight extension of HookedTransformer that allows you to \"splice in\" Sparse Autoencoders. This makes it easy to do exploratory analysis such as: running inference with SAEs attached, caching SAE feature activations, and intervening on SAE activations with hooks.\n",
+ "\n",
+ "I (Connor Kissane) implemented this to accelerate research on [Attention SAEs](https://www.lesswrong.com/posts/DtdzGwFh9dCfsekZZ/sparse-autoencoders-work-on-attention-layer-outputs) based on suggestions from Arthur Conmy and Neel Nanda, and found that it was well worth the time and effort. I hope other researchers will also find the library useful! This notebook demonstrates how it works and how to use it."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Setup"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Running as a Jupyter notebook - intended for development only!\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/tmp/ipykernel_10435/2185356984.py:16: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
+ " ipython.magic(\"load_ext autoreload\")\n",
+ "/tmp/ipykernel_10435/2185356984.py:17: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
+ " ipython.magic(\"autoreload 2\")\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Janky code to do different setup when run in a Colab notebook vs VSCode\n",
+ "DEVELOPMENT_MODE = False\n",
+ "try:\n",
+ " import google.colab\n",
+ " IN_COLAB = True\n",
+ " print(\"Running as a Colab notebook\")\n",
+ " %pip install git+https://github.com/ckkissane/TransformerLens@hooked-sae-transformer\n",
+ " \n",
+ "except:\n",
+ " IN_COLAB = False\n",
+ " print(\"Running as a Jupyter notebook - intended for development only!\")\n",
+ " from IPython import get_ipython\n",
+ "\n",
+ " ipython = get_ipython()\n",
+ " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n",
+ " ipython.magic(\"load_ext autoreload\")\n",
+ " ipython.magic(\"autoreload 2\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import transformer_lens.utils as utils\n",
+ "\n",
+ "import plotly.express as px\n",
+ "import tqdm\n",
+ "from functools import partial\n",
+ "import einops\n",
+ "import plotly.graph_objects as go\n",
+ "\n",
+ "update_layout_set = {\n",
+ " \"xaxis_range\", \"yaxis_range\", \"hovermode\", \"xaxis_title\", \"yaxis_title\", \"colorbar\", \"colorscale\", \"coloraxis\",\n",
+ " \"title_x\", \"bargap\", \"bargroupgap\", \"xaxis_tickformat\", \"yaxis_tickformat\", \"title_y\", \"legend_title_text\", \"xaxis_showgrid\",\n",
+ " \"xaxis_gridwidth\", \"xaxis_gridcolor\", \"yaxis_showgrid\", \"yaxis_gridwidth\"\n",
+ "}\n",
+ "\n",
+ "def imshow(tensor, renderer=None, xaxis=\"\", yaxis=\"\", **kwargs):\n",
+ " if isinstance(tensor, list):\n",
+ " tensor = torch.stack(tensor)\n",
+ " kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}\n",
+ " kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}\n",
+ " if \"facet_labels\" in kwargs_pre:\n",
+ " facet_labels = kwargs_pre.pop(\"facet_labels\")\n",
+ " else:\n",
+ " facet_labels = None\n",
+ " if \"color_continuous_scale\" not in kwargs_pre:\n",
+ " kwargs_pre[\"color_continuous_scale\"] = \"RdBu\"\n",
+ " fig = px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0,labels={\"x\":xaxis, \"y\":yaxis}, **kwargs_pre).update_layout(**kwargs_post)\n",
+ " if facet_labels:\n",
+ " for i, label in enumerate(facet_labels):\n",
+ " fig.layout.annotations[i]['text'] = label\n",
+ "\n",
+ " fig.show(renderer)\n",
+ "\n",
+ "def scatter(x, y, xaxis=\"\", yaxis=\"\", caxis=\"\", renderer=None, return_fig=False, **kwargs):\n",
+ " x = utils.to_numpy(x)\n",
+ " y = utils.to_numpy(y)\n",
+ " fig = px.scatter(y=y, x=x, labels={\"x\":xaxis, \"y\":yaxis, \"color\":caxis}, **kwargs)\n",
+ " if return_fig:\n",
+ " return fig\n",
+ " fig.show(renderer)\n",
+ "\n",
+ "from typing import List\n",
+ "def show_avg_logit_diffs(x_axis: List[str], per_prompt_logit_diffs: List[torch.tensor]):\n",
+ "\n",
+ "\n",
+ " y_data = [per_prompt_logit_diff.mean().item() for per_prompt_logit_diff in per_prompt_logit_diffs]\n",
+ " error_y_data = [per_prompt_logit_diff.std().item() for per_prompt_logit_diff in per_prompt_logit_diffs] \n",
+ "\n",
+ " fig = go.Figure(data=[go.Bar(\n",
+ " x=x_axis,\n",
+ " y=y_data,\n",
+ " error_y=dict(\n",
+ " type='data', # specifies that the actual values are given\n",
+ " array=error_y_data, # the magnitudes of the errors\n",
+ " visible=True # make error bars visible\n",
+ " ),\n",
+ " )])\n",
+ "\n",
+ " # Customize layout\n",
+ " fig.update_layout(title_text=f'Logit Diff after Interventions',\n",
+ " xaxis_title_text='Intervention',\n",
+ " yaxis_title_text='Logit diff',\n",
+ " plot_bgcolor='white')\n",
+ "\n",
+ " # Show the figure\n",
+ " fig.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+ "torch.set_grad_enabled(False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Loading and Running Models"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Just like a [HookedTransformer](https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Loading-and-Running-Models), we can load in any model that's supported in TransformerLens with the `HookedSAETransformer.from_pretrained(MODEL_NAME)`. In this demo we'll use GPT-2 small."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Using pad_token, but it is not set yet.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Loaded pretrained model gpt2-small into HookedTransformer\n",
+ "Moving model to device: cuda\n"
+ ]
+ }
+ ],
+ "source": [
+ "from transformer_lens import HookedSAETransformer\n",
+ "model: HookedSAETransformer = HookedSAETransformer.from_pretrained(\"gpt2-small\").to(device)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "By default HookedSAETransformer will behave exactly like a HookedTransformer. We'll explore the main features of HookedSAETransformer on the classic IOI task, so let's first sanity check that GPT2-small can do the IOI task without any SAEs attached:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "['When John and Mary went to the shops, Mary gave the bag to', 'When John and Mary went to the shops, John gave the bag to', 'When Tom and James went to the park, James gave the ball to', 'When Tom and James went to the park, Tom gave the ball to', 'When Dan and Sid went to the shops, Sid gave an apple to', 'When Dan and Sid went to the shops, Dan gave an apple to', 'After Martin and Amy went to the park, Amy gave a drink to', 'After Martin and Amy went to the park, Martin gave a drink to']\n",
+ "[(' John', ' Mary'), (' Mary', ' John'), (' Tom', ' James'), (' James', ' Tom'), (' Dan', ' Sid'), (' Sid', ' Dan'), (' Martin', ' Amy'), (' Amy', ' Martin')]\n"
+ ]
+ }
+ ],
+ "source": [
+ "prompt_format = [\n",
+ " \"When John and Mary went to the shops,{} gave the bag to\",\n",
+ " \"When Tom and James went to the park,{} gave the ball to\",\n",
+ " \"When Dan and Sid went to the shops,{} gave an apple to\",\n",
+ " \"After Martin and Amy went to the park,{} gave a drink to\",\n",
+ "]\n",
+ "names = [\n",
+ " (\" John\", \" Mary\",),\n",
+ " (\" Tom\", \" James\"),\n",
+ " (\" Dan\", \" Sid\"),\n",
+ " (\" Martin\", \" Amy\"),\n",
+ "]\n",
+ "# List of prompts\n",
+ "prompts = []\n",
+ "# List of answers, in the format (correct, incorrect)\n",
+ "answers = []\n",
+ "# List of the token (ie an integer) corresponding to each answer, in the format (correct_token, incorrect_token)\n",
+ "answer_tokens = []\n",
+ "for i in range(len(prompt_format)):\n",
+ " for j in range(2):\n",
+ " answers.append((names[i][j], names[i][1 - j]))\n",
+ " answer_tokens.append(\n",
+ " (\n",
+ " model.to_single_token(answers[-1][0]),\n",
+ " model.to_single_token(answers[-1][1]),\n",
+ " )\n",
+ " )\n",
+ " # Insert the *incorrect* answer to the prompt, making the correct answer the indirect object.\n",
+ " prompts.append(prompt_format[i].format(answers[-1][1]))\n",
+ "answer_tokens = torch.tensor(answer_tokens).to(device)\n",
+ "print(prompts)\n",
+ "print(answers)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Original average logit diff: 3.5518884658813477\n",
+ "Original per prompt logit diff: tensor([3.2016, 3.3367, 2.7095, 3.7975, 1.7204, 5.2812, 2.6008, 5.7674],\n",
+ " device='cuda:0')\n"
+ ]
+ }
+ ],
+ "source": [
+ "def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):\n",
+ " # Only the final logits are relevant for the answer\n",
+ " final_logits = logits[:, -1, :]\n",
+ " answer_logits = final_logits.gather(dim=-1, index=answer_tokens)\n",
+ " answer_logit_diff = answer_logits[:, 0] - answer_logits[:, 1]\n",
+ " if per_prompt:\n",
+ " return answer_logit_diff\n",
+ " else:\n",
+ " return answer_logit_diff.mean()\n",
+ " \n",
+ "tokens = model.to_tokens(prompts, prepend_bos=True)\n",
+ "original_logits, cache = model.run_with_cache(tokens)\n",
+ "original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)\n",
+ "print(f\"Original average logit diff: {original_average_logit_diff}\")\n",
+ "original_per_prompt_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True)\n",
+ "print(f\"Original per prompt logit diff: {original_per_prompt_logit_diff}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# HookedSAEs"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In order to use the key features of HookedSAETransformer, we first need to load in SAEs.\n",
+ "\n",
+ "HookedSAE is an SAE class we've implemented to have TransformerLens hooks around the SAE activations. While we will use it out of the box, it is designed to be hackable: you can copy and paste the HookedSAE class into a notebook and completely change the architecture / hook names, and as long as it reconstructs the activations, it should still work.\n",
+ "\n",
+ "You can initialize a HookedSAE with a HookedSAEConfig:\n",
+ "```\n",
+ "cfg = HookedSAEConfig(\n",
+ " d_sae (int): The size of the dictionary.\n",
+ " d_in (int): The dimension of the input activations for the SAE\n",
+ " hook_name (str): The hook name of the activation the SAE was trained on (eg. blocks.0.attn.hook_z)\n",
+ ")\n",
+ "hooked_sae = HookedSAE(cfg)\n",
+ "```\n",
+ "\n",
+ "Note you'll likely have to write some basic conversion code to match configs / state dicts to the HookedSAE when loading in an open sourced SAE (eg from HuggingFace). We'll use our GPT-2 Small [Attention SAEs](https://www.alignmentforum.org/posts/FSTRedtjuHa4Gfdbr/attention-saes-scale-to-gpt-2-small) to demonstrate. For convenience, we'll load in all of our attention SAEs from HuggingFace, convert them to HookedSAEs, and store them in a dictionary that maps each hook_name (str) to the corresponding HookedSAE.\n",
+ "\n",
+ "\n",
+ "\n",
+ "Later we'll show how to add HookedSAEs to the HookedSAETransformer (replacing model activations with their SAE reconstructions). When you add a HookedSAE, HookedSAETransformer just treats this a black box that takes some activation as an input, and outputs a tensor of the same shape. \n",
+ "\n",
+ "With this in mind, the HookedSAE is designed to be simple and hackable. Think of it as a convenient default class that you can copy and edit. As long as it takes a TransformerLens activation as input, and outputs a tensor of the same shape, you should be able to add it to your HookedSAETransformer.\n",
+ "\n",
+ "You probably don't even need to use the HookedSAE class, although it's recommended. The sae can be any pytorch module that takes in some activation at hook_name and outputs a tensor of the same shape. The two assumptions that HookedSAETransformer makes when adding SAEs are:\n",
+ "1. The SAE class has a cfg attribute, sae.cfg.hook_name (str), for the activation that the SAE was trained to reconstruct (in TransformerLens notation e.g. 'blocks.0.attn.hook_z')\n",
+ "2. The SAE takes that activation as input, and outputs a tensor of the same shape.\n",
+ "\n",
+ "The main benefit of HookedSAE is that it's a subclass of HookedRootModule, so we can add hooks to SAE activations. This makes it easy to leverage existing TL functionality like run_with_cache and run_with_hooks with SAEs.\n",
+ "\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "dict_keys(['blocks.0.attn.hook_z', 'blocks.1.attn.hook_z', 'blocks.2.attn.hook_z', 'blocks.3.attn.hook_z', 'blocks.4.attn.hook_z', 'blocks.5.attn.hook_z', 'blocks.6.attn.hook_z', 'blocks.7.attn.hook_z', 'blocks.8.attn.hook_z', 'blocks.9.attn.hook_z', 'blocks.10.attn.hook_z', 'blocks.11.attn.hook_z'])\n"
+ ]
+ }
+ ],
+ "source": [
+ "from transformer_lens import HookedSAE, HookedSAEConfig\n",
+ "from transformer_lens.utils import download_file_from_hf\n",
+ "def attn_sae_cfg_to_hooked_sae_cfg(attn_sae_cfg):\n",
+ " new_cfg = {\n",
+ " \"d_sae\": attn_sae_cfg[\"dict_size\"],\n",
+ " \"d_in\": attn_sae_cfg[\"act_size\"],\n",
+ " \"hook_name\": attn_sae_cfg[\"act_name\"],\n",
+ " }\n",
+ " return HookedSAEConfig.from_dict(new_cfg)\n",
+ "\n",
+ "auto_encoder_runs = [\n",
+ " \"gpt2-small_L0_Hcat_z_lr1.20e-03_l11.80e+00_ds24576_bs4096_dc1.00e-06_rsanthropic_rie25000_nr4_v9\",\n",
+ " \"gpt2-small_L1_Hcat_z_lr1.20e-03_l18.00e-01_ds24576_bs4096_dc1.00e-06_rsanthropic_rie25000_nr4_v5\",\n",
+ " \"gpt2-small_L2_Hcat_z_lr1.20e-03_l11.00e+00_ds24576_bs4096_dc1.00e-06_rsanthropic_rie25000_nr4_v4\",\n",
+ " \"gpt2-small_L3_Hcat_z_lr1.20e-03_l19.00e-01_ds24576_bs4096_dc1.00e-06_rsanthropic_rie25000_nr4_v9\",\n",
+ " \"gpt2-small_L4_Hcat_z_lr1.20e-03_l11.10e+00_ds24576_bs4096_dc1.00e-06_rsanthropic_rie25000_nr4_v7\",\n",
+ " \"gpt2-small_L5_Hcat_z_lr1.20e-03_l11.00e+00_ds49152_bs4096_dc1.00e-06_rsanthropic_rie25000_nr4_v9\",\n",
+ " \"gpt2-small_L6_Hcat_z_lr1.20e-03_l11.10e+00_ds24576_bs4096_dc1.00e-06_rsanthropic_rie25000_nr4_v9\",\n",
+ " \"gpt2-small_L7_Hcat_z_lr1.20e-03_l11.10e+00_ds49152_bs4096_dc1.00e-06_rsanthropic_rie25000_nr4_v9\",\n",
+ " \"gpt2-small_L8_Hcat_z_lr1.20e-03_l11.30e+00_ds24576_bs4096_dc1.00e-05_rsanthropic_rie25000_nr4_v6\",\n",
+ " \"gpt2-small_L9_Hcat_z_lr1.20e-03_l11.20e+00_ds24576_bs4096_dc1.00e-06_rsanthropic_rie25000_nr4_v9\",\n",
+ " \"gpt2-small_L10_Hcat_z_lr1.20e-03_l11.30e+00_ds24576_bs4096_dc1.00e-05_rsanthropic_rie25000_nr4_v9\",\n",
+ " \"gpt2-small_L11_Hcat_z_lr1.20e-03_l13.00e+00_ds24576_bs4096_dc3.16e-06_rsanthropic_rie25000_nr4_v9\",\n",
+ "]\n",
+ "\n",
+ "hf_repo = \"ckkissane/attn-saes-gpt2-small-all-layers\"\n",
+ "\n",
+ "hook_name_to_sae = {}\n",
+ "for auto_encoder_run in auto_encoder_runs:\n",
+ " attn_sae_cfg = download_file_from_hf(hf_repo, f\"{auto_encoder_run}_cfg.json\")\n",
+ " cfg = attn_sae_cfg_to_hooked_sae_cfg(attn_sae_cfg)\n",
+ " \n",
+ " state_dict = download_file_from_hf(hf_repo, f\"{auto_encoder_run}.pt\", force_is_torch=True)\n",
+ " \n",
+ " hooked_sae = HookedSAE(cfg)\n",
+ " hooked_sae.load_state_dict(state_dict)\n",
+ " \n",
+ " hook_name_to_sae[cfg.hook_name] = hooked_sae\n",
+ "print(hook_name_to_sae.keys())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Run with SAEs"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The key feature of HookedSAETransformer is being able to \"splice in\" SAEs, replacing model activations with their SAE reconstructions. \n",
+ "\n",
+ "To run a forward pass with SAEs attached use `model.run_with_saes(tokens, saes=saes)`, where saes is a list of HookedSAEs that you want to add for just this forward pass. These will be reset immediately after the forward pass, returning the model to its original state.\n",
+ "\n",
+ "I expect this to be particularly useful for evaluating SAEs (eg [Gurnee](https://www.alignmentforum.org/posts/rZPiuFxESMxCDHe4B/sae-reconstruction-errors-are-empirically-pathological)), including evaluating how SAE reconstructions affect the models ability to perform certain tasks (eg [Makelov et al.](https://openreview.net/forum?id=MHIX9H8aYF&referrer=%5Bthe%20profile%20of%20Neel%20Nanda%5D(%2Fprofile%3Fid%3D~Neel_Nanda1)))\n",
+ "\n",
+ "To demonstrate, let's use `run_with_saes` to evaluate many combinations of SAEs on different cross sections of the IOI circuit.\n",
+ "\n",
+ "\n",
+ "\n",
+ "Under the hood, TransformerLens already wraps activations with a HookPoint object. HookPoint is a dummy pytorch module that acts as an identity function by default, and is only used to access the activation with PyTorch hooks. When you run_with_saes, HookedSAETransformer temporarily replaces these HookPoints with the given HookedSAEs, which take the activation as input and replace it with the HookedSAE output (the reconstructed activation) during the forward pass. \n",
+ "\n",
+ "Since HookedSAE is a subclass of HookedRootModule, we also are able to add PyTorch hooks to the corresponding SAE activations, as we'll use later.\n",
+ "\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ " \n",
+ " "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.plotly.v1+json": {
+ "config": {
+ "plotlyServerURL": "https://plot.ly"
+ },
+ "data": [
+ {
+ "error_y": {
+ "array": [
+ 1.3678313493728638,
+ 1.6846193075180054,
+ 1.3839112520217896,
+ 1.6782633066177368,
+ 0.8939867615699768,
+ 2.2888872623443604
+ ],
+ "type": "data",
+ "visible": true
+ },
+ "type": "bar",
+ "x": [
+ "Clean Baseline",
+ "With SAEs L[0, 3]",
+ "With SAEs L[2, 4]",
+ "With SAEs L[5, 6]",
+ "With SAEs L[7, 8]",
+ "With SAEs L[9, 10, 11]"
+ ],
+ "y": [
+ 3.5518884658813477,
+ 2.580843925476074,
+ 3.3641157150268555,
+ 3.3500614166259766,
+ 1.5024915933609009,
+ 7.072007179260254
+ ]
+ }
+ ],
+ "layout": {
+ "plot_bgcolor": "white",
+ "template": {
+ "data": {
+ "bar": [
+ {
+ "error_x": {
+ "color": "#2a3f5f"
+ },
+ "error_y": {
+ "color": "#2a3f5f"
+ },
+ "marker": {
+ "line": {
+ "color": "#E5ECF6",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "bar"
+ }
+ ],
+ "barpolar": [
+ {
+ "marker": {
+ "line": {
+ "color": "#E5ECF6",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "barpolar"
+ }
+ ],
+ "carpet": [
+ {
+ "aaxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "white",
+ "linecolor": "white",
+ "minorgridcolor": "white",
+ "startlinecolor": "#2a3f5f"
+ },
+ "baxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "white",
+ "linecolor": "white",
+ "minorgridcolor": "white",
+ "startlinecolor": "#2a3f5f"
+ },
+ "type": "carpet"
+ }
+ ],
+ "choropleth": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "choropleth"
+ }
+ ],
+ "contour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "contour"
+ }
+ ],
+ "contourcarpet": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "contourcarpet"
+ }
+ ],
+ "heatmap": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmap"
+ }
+ ],
+ "heatmapgl": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmapgl"
+ }
+ ],
+ "histogram": [
+ {
+ "marker": {
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "histogram"
+ }
+ ],
+ "histogram2d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2d"
+ }
+ ],
+ "histogram2dcontour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2dcontour"
+ }
+ ],
+ "mesh3d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "mesh3d"
+ }
+ ],
+ "parcoords": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "parcoords"
+ }
+ ],
+ "pie": [
+ {
+ "automargin": true,
+ "type": "pie"
+ }
+ ],
+ "scatter": [
+ {
+ "fillpattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ },
+ "type": "scatter"
+ }
+ ],
+ "scatter3d": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatter3d"
+ }
+ ],
+ "scattercarpet": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattercarpet"
+ }
+ ],
+ "scattergeo": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergeo"
+ }
+ ],
+ "scattergl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergl"
+ }
+ ],
+ "scattermapbox": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattermapbox"
+ }
+ ],
+ "scatterpolar": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolar"
+ }
+ ],
+ "scatterpolargl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolargl"
+ }
+ ],
+ "scatterternary": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterternary"
+ }
+ ],
+ "surface": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "surface"
+ }
+ ],
+ "table": [
+ {
+ "cells": {
+ "fill": {
+ "color": "#EBF0F8"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "header": {
+ "fill": {
+ "color": "#C8D4E3"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "type": "table"
+ }
+ ]
+ },
+ "layout": {
+ "annotationdefaults": {
+ "arrowcolor": "#2a3f5f",
+ "arrowhead": 0,
+ "arrowwidth": 1
+ },
+ "autotypenumbers": "strict",
+ "coloraxis": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "colorscale": {
+ "diverging": [
+ [
+ 0,
+ "#8e0152"
+ ],
+ [
+ 0.1,
+ "#c51b7d"
+ ],
+ [
+ 0.2,
+ "#de77ae"
+ ],
+ [
+ 0.3,
+ "#f1b6da"
+ ],
+ [
+ 0.4,
+ "#fde0ef"
+ ],
+ [
+ 0.5,
+ "#f7f7f7"
+ ],
+ [
+ 0.6,
+ "#e6f5d0"
+ ],
+ [
+ 0.7,
+ "#b8e186"
+ ],
+ [
+ 0.8,
+ "#7fbc41"
+ ],
+ [
+ 0.9,
+ "#4d9221"
+ ],
+ [
+ 1,
+ "#276419"
+ ]
+ ],
+ "sequential": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "sequentialminus": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ]
+ },
+ "colorway": [
+ "#636efa",
+ "#EF553B",
+ "#00cc96",
+ "#ab63fa",
+ "#FFA15A",
+ "#19d3f3",
+ "#FF6692",
+ "#B6E880",
+ "#FF97FF",
+ "#FECB52"
+ ],
+ "font": {
+ "color": "#2a3f5f"
+ },
+ "geo": {
+ "bgcolor": "white",
+ "lakecolor": "white",
+ "landcolor": "#E5ECF6",
+ "showlakes": true,
+ "showland": true,
+ "subunitcolor": "white"
+ },
+ "hoverlabel": {
+ "align": "left"
+ },
+ "hovermode": "closest",
+ "mapbox": {
+ "style": "light"
+ },
+ "paper_bgcolor": "white",
+ "plot_bgcolor": "#E5ECF6",
+ "polar": {
+ "angularaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "bgcolor": "#E5ECF6",
+ "radialaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ }
+ },
+ "scene": {
+ "xaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ },
+ "yaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ },
+ "zaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ }
+ },
+ "shapedefaults": {
+ "line": {
+ "color": "#2a3f5f"
+ }
+ },
+ "ternary": {
+ "aaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "baxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "bgcolor": "#E5ECF6",
+ "caxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ }
+ },
+ "title": {
+ "x": 0.05
+ },
+ "xaxis": {
+ "automargin": true,
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "white",
+ "zerolinewidth": 2
+ },
+ "yaxis": {
+ "automargin": true,
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "white",
+ "zerolinewidth": 2
+ }
+ }
+ },
+ "title": {
+ "text": "Logit Diff after Interventions"
+ },
+ "xaxis": {
+ "title": {
+ "text": "Intervention"
+ }
+ },
+ "yaxis": {
+ "title": {
+ "text": "Logit diff"
+ }
+ }
+ }
+ },
+ "text/html": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "all_layers = [[0, 3], [2, 4], [5,6], [7, 8], [9, 10, 11]]\n",
+ "x_axis = ['Clean Baseline']\n",
+ "per_prompt_logit_diffs = [\n",
+ " original_per_prompt_logit_diff, \n",
+ "]\n",
+ "\n",
+ "for layers in all_layers:\n",
+ " hooked_saes = [hook_name_to_sae[utils.get_act_name('z', layer)] for layer in layers]\n",
+ " logits_with_saes = model.run_with_saes(tokens, saes=hooked_saes)\n",
+ " average_logit_diff_with_saes = logits_to_ave_logit_diff(logits_with_saes, answer_tokens)\n",
+ " per_prompt_diff_with_saes = logits_to_ave_logit_diff(logits_with_saes, answer_tokens, per_prompt=True)\n",
+ " \n",
+ " x_axis.append(f\"With SAEs L{layers}\")\n",
+ " per_prompt_logit_diffs.append(per_prompt_diff_with_saes)\n",
+ "\n",
+ "show_avg_logit_diffs(x_axis, per_prompt_logit_diffs)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Run with cache (with SAEs)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We often want to see what SAE features are active on a given prompt. With HookedSAETransformer, you can cache HookedSAE activations (and all the other standard activations) with `logits, cache = model.run_with_cache_with_saes(tokens, saes=saes)`. Just as `run_with_saes` is a wrapper around the standard forward pass, `run_with_cache_with_saes` is a wrapper around `run_with_cache`, and will also only add these saes for one forward pass before returning the model to its original state. \n",
+ "\n",
+ "To access SAE activations from the cache, the corresponding hook names will generally be the HookedTransformer hook_name (eg blocks.5.attn.hook_z) + the HookedSAE hook name preceded by a period (eg .hook_sae_acts_post).\n",
+ "\n",
+ "`run_with_cache_with_saes` makes it easy to explore which SAE features are active across any input. Let's explore the active features at the S2 position for our L5 Attention SAE across all of our IOI prompts:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.plotly.v1+json": {
+ "config": {
+ "plotlyServerURL": "https://plot.ly"
+ },
+ "data": [
+ {
+ "coloraxis": "coloraxis",
+ "hovertemplate": "Feature Id: %{x}<br>Prompt: %{y}<br>color: %{z}<extra></extra>",
+ "name": "0",
+ "type": "heatmap",
+ "x": [
+ "46",
+ "345",
+ "702",
+ "1372",
+ "1755",
+ "1965",
+ "2457",
+ "2496",
+ "2646",
+ "2999",
+ "3047",
+ "4569",
+ "5132",
+ "5203",
+ "5508",
+ "5940",
+ "6144",
+ "6371",
+ "6515",
+ "6558",
+ "6812",
+ "7092",
+ "7515",
+ "7907",
+ "8063",
+ "8623",
+ "8737",
+ "8768",
+ "9096",
+ "9102",
+ "9186",
+ "9463",
+ "9746",
+ "9913",
+ "10581",
+ "10894",
+ "12109",
+ "12485",
+ "12764",
+ "12866",
+ "13063",
+ "13624",
+ "13707",
+ "13777",
+ "14844",
+ "15050",
+ "15170",
+ "15696",
+ "16178",
+ "16892",
+ "17156",
+ "17259",
+ "17497",
+ "17854",
+ "18043",
+ "18210",
+ "18318",
+ "18385",
+ "18440",
+ "18920",
+ "19183",
+ "19263",
+ "19442",
+ "19524",
+ "19573",
+ "20838",
+ "21151",
+ "21657",
+ "22108",
+ "23578",
+ "24091",
+ "24217",
+ "25792",
+ "26373",
+ "26410",
+ "27535",
+ "27787",
+ "27811",
+ "27960",
+ "28061",
+ "28241",
+ "28242",
+ "28254",
+ "28349",
+ "28977",
+ "29027",
+ "29482",
+ "29603",
+ "29700",
+ "29822",
+ "32177",
+ "32920",
+ "33320",
+ "33730",
+ "33966",
+ "34177",
+ "34334",
+ "34947",
+ "35403",
+ "35425",
+ "35579",
+ "35665",
+ "35815",
+ "36109",
+ "36172",
+ "36451",
+ "36767",
+ "36917",
+ "38570",
+ "39962",
+ "40409",
+ "40418",
+ "40661",
+ "41162",
+ "41185",
+ "41552",
+ "42024",
+ "42161",
+ "42437",
+ "42577",
+ "42882",
+ "42931",
+ "43035",
+ "43414",
+ "43643",
+ "43662",
+ "44203",
+ "44256",
+ "44452",
+ "44652",
+ "45179",
+ "45814",
+ "45984",
+ "46880",
+ "47117",
+ "47170",
+ "47231",
+ "47313",
+ "47680",
+ "48063",
+ "48703"
+ ],
+ "xaxis": "x",
+ "yaxis": "y",
+ "z": [
+ [
+ 0.23392018675804138,
+ 0,
+ 0,
+ 0.04335343837738037,
+ 0.44275617599487305,
+ 0,
+ 0,
+ 0.07259953022003174,
+ 0,
+ 0.6985604763031006,
+ 1.262436866760254,
+ 0,
+ 0.04656928777694702,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.45666736364364624,
+ 0.10434150695800781,
+ 0.30980953574180603,
+ 0.3319076895713806,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1.7836596965789795,
+ 0,
+ 0,
+ 0.142583429813385,
+ 0.046830952167510986,
+ 0.3180348575115204,
+ 0.2927079200744629,
+ 0.12267106771469116,
+ 2.5688514709472656,
+ 0.2917236089706421,
+ 0.12333670258522034,
+ 0,
+ 0.1778419017791748,
+ 0,
+ 0.023626387119293213,
+ 0.02943563461303711,
+ 0,
+ 0.048882365226745605,
+ 0.13625454902648926,
+ 0,
+ 0,
+ 0.2634885013103485,
+ 0,
+ 0,
+ 0,
+ 0.21662655472755432,
+ 0,
+ 0,
+ 0,
+ 0.06997489929199219,
+ 0.006345987319946289,
+ 0,
+ 0.16112494468688965,
+ 0.4190089702606201,
+ 0,
+ 2.3819468021392822,
+ 1.0431660413742065,
+ 0,
+ 0.08364987373352051,
+ 0,
+ 0,
+ 0.3451769948005676,
+ 0.7391350865364075,
+ 0.4456520080566406,
+ 0.0019606351852416992,
+ 0.39914217591285706,
+ 0,
+ 0,
+ 0,
+ 0.29958274960517883,
+ 0.44243645668029785,
+ 0,
+ 0.1259920299053192,
+ 0.8349504470825195,
+ 0.37993764877319336,
+ 0.2633737325668335,
+ 0.08324140310287476,
+ 0,
+ 0,
+ 0.10421907901763916,
+ 0,
+ 0,
+ 0,
+ 0.36972635984420776,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.5578295588493347,
+ 0,
+ 0.9233021140098572,
+ 0,
+ 0.10010790824890137,
+ 0,
+ 0.45082613825798035,
+ 0,
+ 0,
+ 0,
+ 0.21043556928634644,
+ 0.12981292605400085,
+ 0.11557984352111816,
+ 0,
+ 0,
+ 0.17571094632148743,
+ 0.2823787331581116,
+ 0.1122598648071289,
+ 0,
+ 0,
+ 0.012049257755279541,
+ 0,
+ 0,
+ 0,
+ 2.417463541030884,
+ 0.0547795295715332,
+ 0.05216425657272339,
+ 0,
+ 0.6592545509338379,
+ 0.003663182258605957,
+ 0,
+ 0,
+ 0.04937589168548584,
+ 0.025814831256866455,
+ 0,
+ 0.8019273281097412,
+ 0,
+ 0.10218703746795654
+ ],
+ [
+ 0,
+ 0,
+ 0.3230956792831421,
+ 0,
+ 0,
+ 0,
+ 0.026041746139526367,
+ 0.31818556785583496,
+ 0,
+ 0.4900796413421631,
+ 0.04911249876022339,
+ 0,
+ 0,
+ 0.07309412956237793,
+ 0.08089971542358398,
+ 0.17180073261260986,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 2.3956947326660156,
+ 0,
+ 0,
+ 0.15781426429748535,
+ 0,
+ 0.5073252320289612,
+ 0.21765804290771484,
+ 0,
+ 0,
+ 1.618570327758789,
+ 0,
+ 0.22485831379890442,
+ 0.0830467939376831,
+ 0.7055595517158508,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.23371747136116028,
+ 0,
+ 0,
+ 0.6983060240745544,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.30831730365753174,
+ 0,
+ 0.417669415473938,
+ 0.05292201042175293,
+ 0,
+ 0,
+ 0,
+ 1.3391070365905762,
+ 0,
+ 0.41352108120918274,
+ 0,
+ 0,
+ 0,
+ 0.037178993225097656,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.2702980041503906,
+ 0,
+ 0,
+ 0.18745100498199463,
+ 1.3330132961273193,
+ 0.5793700814247131,
+ 0.33893001079559326,
+ 0,
+ 0.11196631193161011,
+ 1.720167636871338,
+ 0.17581266164779663,
+ 0.42567259073257446,
+ 0,
+ 0,
+ 0.23682871460914612,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1.8280882835388184,
+ 0.1617840826511383,
+ 0,
+ 0.13557660579681396,
+ 0.5832244157791138,
+ 0,
+ 0,
+ 0.03256487846374512,
+ 0,
+ 0,
+ 0.03892314434051514,
+ 0,
+ 0,
+ 0,
+ 0.30978846549987793,
+ 0,
+ 0,
+ 0.36915141344070435,
+ 0,
+ 0.5477294325828552,
+ 0,
+ 0,
+ 0.06339260935783386,
+ 0.1851767599582672,
+ 0.5839155912399292,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.12337607145309448,
+ 0,
+ 0,
+ 1.0378936529159546,
+ 0,
+ 0,
+ 0,
+ 0.01616498827934265,
+ 0.20259439945220947,
+ 0,
+ 0,
+ 0.3087460398674011,
+ 0.618510365486145,
+ 0.24435847997665405,
+ 0,
+ 0.4668591022491455,
+ 0.1788468360900879,
+ 0.200361967086792,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.7064645290374756
+ ],
+ [
+ 0.2921750843524933,
+ 0,
+ 0,
+ 0.2805737257003784,
+ 0,
+ 0,
+ 0,
+ 0.3694216012954712,
+ 0,
+ 1.1156601905822754,
+ 1.2807728052139282,
+ 0,
+ 0.09175515174865723,
+ 0,
+ 0,
+ 0,
+ 0.10458803176879883,
+ 0,
+ 0.021218180656433105,
+ 0,
+ 0,
+ 0.01699376106262207,
+ 0.09601330757141113,
+ 0.054788172245025635,
+ 0,
+ 0.030488133430480957,
+ 0.021512210369110107,
+ 0.2717320919036865,
+ 0.29357004165649414,
+ 0.6420693397521973,
+ 0.05249035358428955,
+ 0,
+ 0.06201601028442383,
+ 0,
+ 0.4122554659843445,
+ 1.821354866027832,
+ 0.01981794834136963,
+ 0,
+ 0.14063221216201782,
+ 0.05093127489089966,
+ 0,
+ 0.32148706912994385,
+ 0.15257668495178223,
+ 2.418062686920166,
+ 0.17348229885101318,
+ 0.08421656489372253,
+ 0,
+ 0.4551248550415039,
+ 0,
+ 0.015430927276611328,
+ 0.24434363842010498,
+ 0,
+ 0.06232607364654541,
+ 0,
+ 0.04422914981842041,
+ 0.8720088005065918,
+ 0.3721686899662018,
+ 0,
+ 0,
+ 0,
+ 0.340120404958725,
+ 0,
+ 0,
+ 0.07813769578933716,
+ 0,
+ 0.0882720947265625,
+ 0.19706517457962036,
+ 0.4056885242462158,
+ 0.19529414176940918,
+ 0,
+ 2.928431510925293,
+ 1.1402223110198975,
+ 0,
+ 0.026796698570251465,
+ 0.0033188462257385254,
+ 0,
+ 0.3370524048805237,
+ 0.47657889127731323,
+ 0,
+ 0.10358679294586182,
+ 0.27619925141334534,
+ 0,
+ 0,
+ 0,
+ 0.40909066796302795,
+ 0.2599871754646301,
+ 0,
+ 0.275011271238327,
+ 0.5349323749542236,
+ 0.07697033882141113,
+ 0.17431437969207764,
+ 0,
+ 0,
+ 0,
+ 0.09000074863433838,
+ 0,
+ 0,
+ 0,
+ 0.276567280292511,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.5655339360237122,
+ 0,
+ 0.8971189856529236,
+ 0,
+ 0.5199201107025146,
+ 0,
+ 0.6301102638244629,
+ 0.013657361268997192,
+ 0.04469645023345947,
+ 0.038062095642089844,
+ 0.4305816888809204,
+ 0,
+ 0.04173767566680908,
+ 0,
+ 0,
+ 0,
+ 0.8985729217529297,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.08318889141082764,
+ 0.006362795829772949,
+ 2.069222927093506,
+ 0,
+ 0.7068352103233337,
+ 0,
+ 0.8527798652648926,
+ 0,
+ 0,
+ 0.4707651138305664,
+ 0,
+ 0,
+ 0,
+ 0.7790955305099487,
+ 0.021227538585662842,
+ 0.01846003532409668
+ ],
+ [
+ 0,
+ 0,
+ 0.2200499176979065,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.2433047890663147,
+ 0.2504638135433197,
+ 0.712148904800415,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.1410943865776062,
+ 0,
+ 0,
+ 0,
+ 0.11292147636413574,
+ 0,
+ 0,
+ 2.360842704772949,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1.2830760478973389,
+ 0,
+ 0,
+ 0,
+ 0.6308119893074036,
+ 0,
+ 0.4040885865688324,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.5223236680030823,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.23784160614013672,
+ 0,
+ 0.04762387275695801,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.5758676528930664,
+ 0.01025208830833435,
+ 0.24556085467338562,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1.1104614734649658,
+ 1.079118251800537,
+ 0,
+ 0,
+ 0.14462929964065552,
+ 1.9186956882476807,
+ 0,
+ 0.30735498666763306,
+ 0,
+ 0,
+ 0.07669633626937866,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1.3975048065185547,
+ 0,
+ 0,
+ 0.3461639881134033,
+ 0.5062156915664673,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.19610454142093658,
+ 0.218009352684021,
+ 0,
+ 0,
+ 0.07953745126724243,
+ 0,
+ 0.1416093111038208,
+ 0,
+ 0,
+ 0,
+ 0.18305465579032898,
+ 0.10310900211334229,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.45315277576446533,
+ 0,
+ 0,
+ 0,
+ 0.09076884388923645,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.04246491193771362,
+ 0,
+ 0.1807355284690857,
+ 0,
+ 0.3002055883407593,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0
+ ],
+ [
+ 0.02005404233932495,
+ 0,
+ 0,
+ 0.07601284980773926,
+ 0,
+ 0,
+ 0,
+ 0.012166053056716919,
+ 0,
+ 1.0662918090820312,
+ 1.4810535907745361,
+ 0,
+ 0.014786958694458008,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.1491186022758484,
+ 0,
+ 0,
+ 0,
+ 0.38226866722106934,
+ 0.43110355734825134,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1.6819074153900146,
+ 0,
+ 0.7939910888671875,
+ 0.28643298149108887,
+ 0,
+ 0,
+ 0.011532962322235107,
+ 0,
+ 1.2869157791137695,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.16446048021316528,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.03375712037086487,
+ 0,
+ 0,
+ 0,
+ 0.1915181577205658,
+ 0,
+ 0,
+ 0.10225892066955566,
+ 0,
+ 0,
+ 0,
+ 0.7338485717773438,
+ 0,
+ 0,
+ 1.3715617656707764,
+ 1.6115869283676147,
+ 0,
+ 0.7128411531448364,
+ 0,
+ 0,
+ 0.2161598801612854,
+ 0.5098914504051208,
+ 0,
+ 0,
+ 0.04084053635597229,
+ 0,
+ 0,
+ 0,
+ 0.17978456616401672,
+ 0,
+ 0,
+ 0.1365671455860138,
+ 0.27122950553894043,
+ 0.2945059537887573,
+ 0.2824629545211792,
+ 0,
+ 0,
+ 0,
+ 0.0464092493057251,
+ 0,
+ 0,
+ 0.04672741889953613,
+ 0.6179839968681335,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.045598745346069336,
+ 0,
+ 1.0172381401062012,
+ 0,
+ 0.07242608070373535,
+ 0,
+ 0.5165215730667114,
+ 0,
+ 0,
+ 0,
+ 0.5004003047943115,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.3409433960914612,
+ 0,
+ 0.1579979658126831,
+ 0.09901612997055054,
+ 0,
+ 0,
+ 0,
+ 0,
+ 2.413944721221924,
+ 0,
+ 0.20971286296844482,
+ 0.07062971591949463,
+ 0.26070594787597656,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.020640969276428223,
+ 1.0534553527832031,
+ 0,
+ 0
+ ],
+ [
+ 0,
+ 0,
+ 0.046907246112823486,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.20885008573532104,
+ 0.25957152247428894,
+ 1.0767037868499756,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.23976856470108032,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 2.762990951538086,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.29466086626052856,
+ 0,
+ 0,
+ 0.09433537721633911,
+ 1.2446393966674805,
+ 0,
+ 0,
+ 0,
+ 0.6668079495429993,
+ 0,
+ 0.7482341527938843,
+ 0,
+ 0,
+ 0.005075186491012573,
+ 0,
+ 0,
+ 0.4049275517463684,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.09314888715744019,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.4028928279876709,
+ 0,
+ 0.3687801659107208,
+ 0,
+ 0.10555410385131836,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1.066054105758667,
+ 1.4596349000930786,
+ 0,
+ 0,
+ 0,
+ 2.3358588218688965,
+ 0,
+ 0.5390753149986267,
+ 0,
+ 0,
+ 0.12931063771247864,
+ 0,
+ 0.10619288682937622,
+ 0,
+ 0,
+ 0,
+ 0.41271400451660156,
+ 0,
+ 0,
+ 0.23865878582000732,
+ 0.7501264810562134,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.2947666645050049,
+ 0,
+ 0,
+ 0,
+ 0.05958199501037598,
+ 0.20450782775878906,
+ 0,
+ 0,
+ 0,
+ 0.13838836550712585,
+ 0.13835513591766357,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.45820748805999756,
+ 0,
+ 0,
+ 0,
+ 0.19962045550346375,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.20416772365570068,
+ 0.46223968267440796,
+ 0,
+ 0.22815394401550293,
+ 0,
+ 0.1125795841217041,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.3023688793182373
+ ],
+ [
+ 0.28365251421928406,
+ 0,
+ 0,
+ 0.41595208644866943,
+ 0,
+ 0.15376341342926025,
+ 0,
+ 0.22517156600952148,
+ 0,
+ 0.7871096134185791,
+ 1.3084614276885986,
+ 0.2012956142425537,
+ 0,
+ 0,
+ 0,
+ 0.2532406449317932,
+ 0.009012699127197266,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.7235959768295288,
+ 0.021468758583068848,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.8338297009468079,
+ 0.3022422790527344,
+ 0.6702529191970825,
+ 0.5416026711463928,
+ 0,
+ 0,
+ 0,
+ 0.2034381628036499,
+ 1.9052581787109375,
+ 0,
+ 0.23752644658088684,
+ 0,
+ 0,
+ 0,
+ 0.8470145463943481,
+ 0,
+ 2.820002555847168,
+ 0,
+ 0.16275432705879211,
+ 0.06714236736297607,
+ 0.12017238140106201,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.486280620098114,
+ 0,
+ 0,
+ 0.3096342086791992,
+ 0.3064201772212982,
+ 0,
+ 0.09773910045623779,
+ 0,
+ 0.4613642394542694,
+ 0,
+ 0.021892428398132324,
+ 0,
+ 0.18887782096862793,
+ 0.18538141250610352,
+ 0,
+ 0.42975664138793945,
+ 0.9873132705688477,
+ 0,
+ 2.163774013519287,
+ 1.2928048372268677,
+ 0,
+ 0.2320784330368042,
+ 0.0062233805656433105,
+ 0,
+ 1.2478563785552979,
+ 0.5479208827018738,
+ 0,
+ 0.06501156091690063,
+ 0.3741762936115265,
+ 0,
+ 0,
+ 0.31712013483047485,
+ 0.5228050947189331,
+ 0.3981531858444214,
+ 0,
+ 0,
+ 0.4854400157928467,
+ 0.3341655731201172,
+ 0.39207732677459717,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.3316766023635864,
+ 0,
+ 0,
+ 0.33435362577438354,
+ 0.1380615234375,
+ 0.7183249592781067,
+ 0.041296958923339844,
+ 0.7634149193763733,
+ 0,
+ 0.4028007984161377,
+ 0,
+ 0.6915435791015625,
+ 0,
+ 0,
+ 0,
+ 0.3831353187561035,
+ 0.05798754096031189,
+ 0.15244710445404053,
+ 0,
+ 0.03230410814285278,
+ 0.2039397656917572,
+ 0.6142292022705078,
+ 0.15542924404144287,
+ 0.07628917694091797,
+ 0.0812273919582367,
+ 0.15177401900291443,
+ 0.10224854946136475,
+ 0,
+ 0,
+ 2.8106069564819336,
+ 0.3994237184524536,
+ 0.6397127509117126,
+ 0,
+ 0.8949670791625977,
+ 0,
+ 0,
+ 0.18832790851593018,
+ 0.1450880765914917,
+ 0,
+ 0,
+ 0.6900937557220459,
+ 0,
+ 0.14745783805847168
+ ],
+ [
+ 0.12055802345275879,
+ 0.023864269256591797,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.3327372670173645,
+ 0.1789897382259369,
+ 1.1445300579071045,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.4361664652824402,
+ 0.09996795654296875,
+ 0.10051405429840088,
+ 0,
+ 0.4030296802520752,
+ 0.06672021746635437,
+ 0.6339577436447144,
+ 3.3947582244873047,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.9711236357688904,
+ 0,
+ 0.38066884875297546,
+ 0.4158353805541992,
+ 1.5344438552856445,
+ 0,
+ 0.19816407561302185,
+ 0,
+ 0.6646860241889954,
+ 0,
+ 0.16733816266059875,
+ 0,
+ 0,
+ 0,
+ 0.322623074054718,
+ 0,
+ 0.7314171195030212,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.043955981731414795,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.9436180591583252,
+ 0,
+ 0.29259607195854187,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.1570979356765747,
+ 0,
+ 0,
+ 0,
+ 1.1782727241516113,
+ 1.2431498765945435,
+ 0.32878363132476807,
+ 0,
+ 0.419150173664093,
+ 2.3304405212402344,
+ 0.8566346764564514,
+ 0,
+ 0,
+ 0,
+ 0.3841046392917633,
+ 0.10476112365722656,
+ 0,
+ 0.18140661716461182,
+ 0,
+ 0,
+ 0.6665420532226562,
+ 0,
+ 0,
+ 0.22877633571624756,
+ 0.9225524663925171,
+ 0,
+ 0.15886402130126953,
+ 0,
+ 0,
+ 0.02094721794128418,
+ 0,
+ 0,
+ 0,
+ 0.3046541213989258,
+ 0.2845715284347534,
+ 0,
+ 0,
+ 0.4244043231010437,
+ 0.164473295211792,
+ 0.30073386430740356,
+ 0.7123112678527832,
+ 0.1730642318725586,
+ 0,
+ 0.4041661322116852,
+ 0.39166414737701416,
+ 0,
+ 0,
+ 0.2103893756866455,
+ 0.007811635732650757,
+ 0.010994672775268555,
+ 0.03914850950241089,
+ 0,
+ 0,
+ 0.8430832624435425,
+ 0,
+ 0,
+ 0,
+ 0.15830591320991516,
+ 0.29398930072784424,
+ 0,
+ 0,
+ 0,
+ 0.5994948148727417,
+ 0.1704254150390625,
+ 0,
+ 0.4673898220062256,
+ 0,
+ 0.3204514980316162,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.8447363376617432
+ ]
+ ]
+ }
+ ],
+ "layout": {
+ "coloraxis": {
+ "cmid": 0,
+ "colorscale": [
+ [
+ 0,
+ "rgb(103,0,31)"
+ ],
+ [
+ 0.1,
+ "rgb(178,24,43)"
+ ],
+ [
+ 0.2,
+ "rgb(214,96,77)"
+ ],
+ [
+ 0.3,
+ "rgb(244,165,130)"
+ ],
+ [
+ 0.4,
+ "rgb(253,219,199)"
+ ],
+ [
+ 0.5,
+ "rgb(247,247,247)"
+ ],
+ [
+ 0.6,
+ "rgb(209,229,240)"
+ ],
+ [
+ 0.7,
+ "rgb(146,197,222)"
+ ],
+ [
+ 0.8,
+ "rgb(67,147,195)"
+ ],
+ [
+ 0.9,
+ "rgb(33,102,172)"
+ ],
+ [
+ 1,
+ "rgb(5,48,97)"
+ ]
+ ]
+ },
+ "template": {
+ "data": {
+ "bar": [
+ {
+ "error_x": {
+ "color": "#2a3f5f"
+ },
+ "error_y": {
+ "color": "#2a3f5f"
+ },
+ "marker": {
+ "line": {
+ "color": "#E5ECF6",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "bar"
+ }
+ ],
+ "barpolar": [
+ {
+ "marker": {
+ "line": {
+ "color": "#E5ECF6",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "barpolar"
+ }
+ ],
+ "carpet": [
+ {
+ "aaxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "white",
+ "linecolor": "white",
+ "minorgridcolor": "white",
+ "startlinecolor": "#2a3f5f"
+ },
+ "baxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "white",
+ "linecolor": "white",
+ "minorgridcolor": "white",
+ "startlinecolor": "#2a3f5f"
+ },
+ "type": "carpet"
+ }
+ ],
+ "choropleth": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "choropleth"
+ }
+ ],
+ "contour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "contour"
+ }
+ ],
+ "contourcarpet": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "contourcarpet"
+ }
+ ],
+ "heatmap": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmap"
+ }
+ ],
+ "heatmapgl": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmapgl"
+ }
+ ],
+ "histogram": [
+ {
+ "marker": {
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "histogram"
+ }
+ ],
+ "histogram2d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2d"
+ }
+ ],
+ "histogram2dcontour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2dcontour"
+ }
+ ],
+ "mesh3d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "mesh3d"
+ }
+ ],
+ "parcoords": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "parcoords"
+ }
+ ],
+ "pie": [
+ {
+ "automargin": true,
+ "type": "pie"
+ }
+ ],
+ "scatter": [
+ {
+ "fillpattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ },
+ "type": "scatter"
+ }
+ ],
+ "scatter3d": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatter3d"
+ }
+ ],
+ "scattercarpet": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattercarpet"
+ }
+ ],
+ "scattergeo": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergeo"
+ }
+ ],
+ "scattergl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergl"
+ }
+ ],
+ "scattermapbox": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattermapbox"
+ }
+ ],
+ "scatterpolar": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolar"
+ }
+ ],
+ "scatterpolargl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolargl"
+ }
+ ],
+ "scatterternary": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterternary"
+ }
+ ],
+ "surface": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "surface"
+ }
+ ],
+ "table": [
+ {
+ "cells": {
+ "fill": {
+ "color": "#EBF0F8"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "header": {
+ "fill": {
+ "color": "#C8D4E3"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "type": "table"
+ }
+ ]
+ },
+ "layout": {
+ "annotationdefaults": {
+ "arrowcolor": "#2a3f5f",
+ "arrowhead": 0,
+ "arrowwidth": 1
+ },
+ "autotypenumbers": "strict",
+ "coloraxis": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "colorscale": {
+ "diverging": [
+ [
+ 0,
+ "#8e0152"
+ ],
+ [
+ 0.1,
+ "#c51b7d"
+ ],
+ [
+ 0.2,
+ "#de77ae"
+ ],
+ [
+ 0.3,
+ "#f1b6da"
+ ],
+ [
+ 0.4,
+ "#fde0ef"
+ ],
+ [
+ 0.5,
+ "#f7f7f7"
+ ],
+ [
+ 0.6,
+ "#e6f5d0"
+ ],
+ [
+ 0.7,
+ "#b8e186"
+ ],
+ [
+ 0.8,
+ "#7fbc41"
+ ],
+ [
+ 0.9,
+ "#4d9221"
+ ],
+ [
+ 1,
+ "#276419"
+ ]
+ ],
+ "sequential": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "sequentialminus": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ]
+ },
+ "colorway": [
+ "#636efa",
+ "#EF553B",
+ "#00cc96",
+ "#ab63fa",
+ "#FFA15A",
+ "#19d3f3",
+ "#FF6692",
+ "#B6E880",
+ "#FF97FF",
+ "#FECB52"
+ ],
+ "font": {
+ "color": "#2a3f5f"
+ },
+ "geo": {
+ "bgcolor": "white",
+ "lakecolor": "white",
+ "landcolor": "#E5ECF6",
+ "showlakes": true,
+ "showland": true,
+ "subunitcolor": "white"
+ },
+ "hoverlabel": {
+ "align": "left"
+ },
+ "hovermode": "closest",
+ "mapbox": {
+ "style": "light"
+ },
+ "paper_bgcolor": "white",
+ "plot_bgcolor": "#E5ECF6",
+ "polar": {
+ "angularaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "bgcolor": "#E5ECF6",
+ "radialaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ }
+ },
+ "scene": {
+ "xaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ },
+ "yaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ },
+ "zaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ }
+ },
+ "shapedefaults": {
+ "line": {
+ "color": "#2a3f5f"
+ }
+ },
+ "ternary": {
+ "aaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "baxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "bgcolor": "#E5ECF6",
+ "caxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ }
+ },
+ "title": {
+ "x": 0.05
+ },
+ "xaxis": {
+ "automargin": true,
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "white",
+ "zerolinewidth": 2
+ },
+ "yaxis": {
+ "automargin": true,
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "white",
+ "zerolinewidth": 2
+ }
+ }
+ },
+ "title": {
+ "text": "Activations of Live SAE features at L5 S2 position per prompt"
+ },
+ "xaxis": {
+ "anchor": "y",
+ "constrain": "domain",
+ "domain": [
+ 0,
+ 1
+ ],
+ "scaleanchor": "y",
+ "title": {
+ "text": "Feature Id"
+ }
+ },
+ "yaxis": {
+ "anchor": "x",
+ "autorange": "reversed",
+ "constrain": "domain",
+ "domain": [
+ 0,
+ 1
+ ],
+ "title": {
+ "text": "Prompt"
+ }
+ }
+ }
+ },
+ "text/html": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "layer, s2_pos = 5, 10\n",
+ "saes = [hook_name_to_sae[utils.get_act_name('z', layer)]]\n",
+ "_, cache = model.run_with_cache_with_saes(tokens, saes=saes)\n",
+ "sae_acts = cache[utils.get_act_name('z', layer) + \".hook_sae_acts_post\"][:, s2_pos, :]\n",
+ "live_feature_mask = sae_acts > 0\n",
+ "live_feature_union = live_feature_mask.any(dim=0)\n",
+ "\n",
+ "imshow(\n",
+ " sae_acts[:, live_feature_union],\n",
+ " title = \"Activations of Live SAE features at L5 S2 position per prompt\",\n",
+ " xaxis=\"Feature Id\", yaxis=\"Prompt\",\n",
+ " x=list(map(str, live_feature_union.nonzero().flatten().tolist())),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We could then interpret some of the commonly activating features, like 7515, using [neuronpedia](https://www.neuronpedia.org/gpt2-small/5-att-kk/7515)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Run with Hooks (with SAEs)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "HookedSAETransformer also allows you to intervene on SAE activations with `model.run_with_hooks_with_saes(tokens, saes=saes, fwd_hooks=fwd_hooks)`. This works exactly like the standard TransformerLens `run_with_hooks`, with the added benefit that we can now intervene on SAE activations from the HookedSAEs that we splice in. Along the same lines as `run_with_saes` and `run_with_cache_with_saes`, this will only temporarily add SAEs before returning the model to its original state. \n",
+ "\n",
+ "I expect this to be useful when doing circuit analysis with SAEs. To demonstrate, let's zero ablate individual layer 5 attention SAE features to localize causally important features."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 141/141 [00:04<00:00, 28.85it/s]\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.plotly.v1+json": {
+ "config": {
+ "plotlyServerURL": "https://plot.ly"
+ },
+ "data": [
+ {
+ "coloraxis": "coloraxis",
+ "hovertemplate": "Feature Idx: %{x}<br>Prompt Idx: %{y}<br>color: %{z}<extra></extra>",
+ "name": "0",
+ "type": "heatmap",
+ "x": [
+ "46",
+ "345",
+ "702",
+ "1372",
+ "1755",
+ "1965",
+ "2457",
+ "2496",
+ "2646",
+ "2999",
+ "3047",
+ "4569",
+ "5132",
+ "5203",
+ "5508",
+ "5940",
+ "6144",
+ "6371",
+ "6515",
+ "6558",
+ "6812",
+ "7092",
+ "7515",
+ "7907",
+ "8063",
+ "8623",
+ "8737",
+ "8768",
+ "9096",
+ "9102",
+ "9186",
+ "9463",
+ "9746",
+ "9913",
+ "10581",
+ "10894",
+ "12109",
+ "12485",
+ "12764",
+ "12866",
+ "13063",
+ "13624",
+ "13707",
+ "13777",
+ "14844",
+ "15050",
+ "15170",
+ "15696",
+ "16178",
+ "16892",
+ "17156",
+ "17259",
+ "17497",
+ "17854",
+ "18043",
+ "18210",
+ "18318",
+ "18385",
+ "18440",
+ "18920",
+ "19183",
+ "19263",
+ "19442",
+ "19524",
+ "19573",
+ "20838",
+ "21151",
+ "21657",
+ "22108",
+ "23578",
+ "24091",
+ "24217",
+ "25792",
+ "26373",
+ "26410",
+ "27535",
+ "27787",
+ "27811",
+ "27960",
+ "28061",
+ "28241",
+ "28242",
+ "28254",
+ "28349",
+ "28977",
+ "29027",
+ "29482",
+ "29603",
+ "29700",
+ "29822",
+ "32177",
+ "32920",
+ "33320",
+ "33730",
+ "33966",
+ "34177",
+ "34334",
+ "34947",
+ "35403",
+ "35425",
+ "35579",
+ "35665",
+ "35815",
+ "36109",
+ "36172",
+ "36451",
+ "36767",
+ "36917",
+ "38570",
+ "39962",
+ "40409",
+ "40418",
+ "40661",
+ "41162",
+ "41185",
+ "41552",
+ "42024",
+ "42161",
+ "42437",
+ "42577",
+ "42882",
+ "42931",
+ "43035",
+ "43414",
+ "43643",
+ "43662",
+ "44203",
+ "44256",
+ "44452",
+ "44652",
+ "45179",
+ "45814",
+ "45984",
+ "46880",
+ "47117",
+ "47170",
+ "47231",
+ "47313",
+ "47680",
+ "48063",
+ "48703"
+ ],
+ "xaxis": "x",
+ "yaxis": "y",
+ "z": [
+ [
+ 0.006268501281738281,
+ 0,
+ 0,
+ 0.0016260147094726562,
+ 0.0011568069458007812,
+ 0,
+ 0,
+ -0.000400543212890625,
+ 0,
+ -0.024961471557617188,
+ -0.062079429626464844,
+ 0,
+ 0.00041866302490234375,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.017510414123535156,
+ -0.0021581649780273438,
+ -0.0012054443359375,
+ -0.006356239318847656,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.025524139404296875,
+ 0,
+ 0,
+ -0.0037746429443359375,
+ 0.0004291534423828125,
+ -0.000194549560546875,
+ 0.002796173095703125,
+ 0.0001850128173828125,
+ -0.056549072265625,
+ -0.0029163360595703125,
+ -0.004790306091308594,
+ 0,
+ 0.0005321502685546875,
+ 0,
+ 0.00049591064453125,
+ -0.0008335113525390625,
+ 0,
+ -0.00299072265625,
+ -0.00185394287109375,
+ 0,
+ 0,
+ 0.011702537536621094,
+ 0,
+ 0,
+ 0,
+ -0.003353118896484375,
+ 0,
+ 0,
+ 0,
+ 0.00048828125,
+ -0.000213623046875,
+ 0,
+ -0.0062160491943359375,
+ -0.007611274719238281,
+ 0,
+ 0.06644821166992188,
+ -0.025884628295898438,
+ 0,
+ -0.0001964569091796875,
+ 0,
+ 0,
+ 0.03233909606933594,
+ -0.05103874206542969,
+ 0.0003414154052734375,
+ -0.0000057220458984375,
+ -0.0027713775634765625,
+ 0,
+ 0,
+ 0,
+ -0.02438068389892578,
+ 0.027306556701660156,
+ 0,
+ -0.0036411285400390625,
+ 0.018335342407226562,
+ 0.010270118713378906,
+ 0.0120849609375,
+ 0.0013589859008789062,
+ 0,
+ 0,
+ -0.0033817291259765625,
+ 0,
+ 0,
+ 0,
+ -0.014057159423828125,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.008485794067382812,
+ 0,
+ 0.021463394165039062,
+ 0,
+ -0.002582550048828125,
+ 0,
+ 0.012966156005859375,
+ 0,
+ 0,
+ 0,
+ -0.0077991485595703125,
+ 0.002948760986328125,
+ 0.0069675445556640625,
+ 0,
+ 0,
+ 0.0058879852294921875,
+ -0.050632476806640625,
+ 0.001888275146484375,
+ 0,
+ 0,
+ -0.0005016326904296875,
+ 0,
+ 0,
+ 0,
+ -0.5087032318115234,
+ -0.0006818771362304688,
+ 0.0017566680908203125,
+ 0,
+ -0.02089214324951172,
+ -0.0000286102294921875,
+ 0,
+ 0,
+ -0.000446319580078125,
+ 0.0008115768432617188,
+ 0,
+ 0.017795562744140625,
+ 0,
+ -0.008462905883789062
+ ],
+ [
+ 0,
+ 0,
+ 0.0042266845703125,
+ 0,
+ 0,
+ 0,
+ -0.00130462646484375,
+ -0.01946258544921875,
+ 0,
+ 0.03999900817871094,
+ 0.013164520263671875,
+ 0,
+ 0,
+ -0.000522613525390625,
+ -0.0028820037841796875,
+ -0.003643035888671875,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.24383163452148438,
+ 0,
+ 0,
+ -0.0009517669677734375,
+ 0,
+ 0.05923271179199219,
+ 0.00897979736328125,
+ 0,
+ 0,
+ -0.00617218017578125,
+ 0,
+ 0.011938095092773438,
+ 0.005764007568359375,
+ 0.08927345275878906,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.027820587158203125,
+ 0,
+ 0,
+ 0.021488189697265625,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.016414642333984375,
+ 0,
+ -0.012666702270507812,
+ 0.002353668212890625,
+ 0,
+ 0,
+ 0,
+ 0.10541152954101562,
+ 0,
+ 0.010334014892578125,
+ 0,
+ 0,
+ 0,
+ 0.0012111663818359375,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.047576904296875,
+ 0,
+ 0,
+ -0.006137847900390625,
+ 0.04940223693847656,
+ 0.014007568359375,
+ 0.030317306518554688,
+ 0,
+ -0.0012969970703125,
+ -0.12521743774414062,
+ 0.0023975372314453125,
+ 0.04903602600097656,
+ 0,
+ 0,
+ 0.019681930541992188,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.07957077026367188,
+ -0.00966644287109375,
+ 0,
+ 0.011016845703125,
+ 0.05775642395019531,
+ 0,
+ 0,
+ 0.00060272216796875,
+ 0,
+ 0,
+ 0.00067138671875,
+ 0,
+ 0,
+ 0,
+ -0.0041980743408203125,
+ 0,
+ 0,
+ 0.020341873168945312,
+ 0,
+ -0.02782440185546875,
+ 0,
+ 0,
+ 0.001705169677734375,
+ 0.0035266876220703125,
+ 0.0060558319091796875,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.0004119873046875,
+ 0,
+ 0,
+ 0.10181617736816406,
+ 0,
+ 0,
+ 0,
+ 0.0001964569091796875,
+ 0.009687423706054688,
+ 0,
+ 0,
+ 0.10214805603027344,
+ 0.03883934020996094,
+ 0.028743743896484375,
+ 0,
+ -0.009389877319335938,
+ -0.0005168914794921875,
+ -0.0241851806640625,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.0089263916015625
+ ],
+ [
+ 0.013156890869140625,
+ 0,
+ 0,
+ 0.00737762451171875,
+ 0,
+ 0,
+ 0,
+ -0.011926651000976562,
+ 0,
+ -0.1016092300415039,
+ -0.2541160583496094,
+ 0,
+ 0.0026063919067382812,
+ 0,
+ 0,
+ 0,
+ 0.011356353759765625,
+ 0,
+ -0.0003261566162109375,
+ 0,
+ 0,
+ 0.000354766845703125,
+ 0.018985748291015625,
+ -0.0010251998901367188,
+ 0,
+ -0.0016918182373046875,
+ 0.00087738037109375,
+ -0.03418159484863281,
+ -0.022599220275878906,
+ -0.031129837036132812,
+ -0.0039033889770507812,
+ 0,
+ 0.002773284912109375,
+ 0,
+ -0.0497589111328125,
+ 0.0000972747802734375,
+ 0.00002002716064453125,
+ 0,
+ -0.000766754150390625,
+ 0.000133514404296875,
+ 0,
+ 0.00109100341796875,
+ 0.00045013427734375,
+ -0.15281009674072266,
+ -0.0027723312377929688,
+ -0.008421897888183594,
+ 0,
+ 0.024028778076171875,
+ 0,
+ 0.0008792877197265625,
+ -0.0008392333984375,
+ 0,
+ -0.014632225036621094,
+ 0,
+ -0.0009860992431640625,
+ -0.0236358642578125,
+ 0.021772384643554688,
+ 0,
+ 0,
+ 0,
+ -0.016798019409179688,
+ 0,
+ 0,
+ -0.0022678375244140625,
+ 0,
+ -0.0038995742797851562,
+ 0.006114959716796875,
+ -0.05572509765625,
+ -0.008089065551757812,
+ 0,
+ 0.21244430541992188,
+ -0.06043434143066406,
+ 0,
+ 0.0001010894775390625,
+ 0.00023651123046875,
+ 0,
+ 0.062018394470214844,
+ -0.08936023712158203,
+ 0,
+ -0.005387306213378906,
+ -0.001903533935546875,
+ 0,
+ 0,
+ 0,
+ -0.08661651611328125,
+ 0.020143508911132812,
+ 0,
+ -0.01000213623046875,
+ 0.008556365966796875,
+ -0.0023040771484375,
+ 0.0063114166259765625,
+ 0,
+ 0,
+ 0,
+ -0.01030731201171875,
+ 0,
+ 0,
+ 0,
+ -0.037540435791015625,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.018768310546875,
+ 0,
+ 0.06715202331542969,
+ 0,
+ -0.01861572265625,
+ 0,
+ 0.02222919464111328,
+ -0.0029458999633789062,
+ -0.0005445480346679688,
+ -0.001338958740234375,
+ -0.0246734619140625,
+ 0,
+ 0.0014019012451171875,
+ 0,
+ 0,
+ 0,
+ -0.34259986877441406,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.002704620361328125,
+ -0.0001850128173828125,
+ -0.9704685211181641,
+ 0,
+ -0.01996612548828125,
+ 0,
+ -0.0199432373046875,
+ 0,
+ 0,
+ 0.025028228759765625,
+ 0,
+ 0,
+ 0,
+ 0.05844879150390625,
+ -0.00006961822509765625,
+ -0.002410888671875
+ ],
+ [
+ 0,
+ 0,
+ -0.001018524169921875,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.0172882080078125,
+ 0.05738639831542969,
+ 0.12810707092285156,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.0056362152099609375,
+ 0,
+ 0,
+ 0,
+ 0.009425163269042969,
+ 0,
+ 0,
+ -0.2314128875732422,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.057198524475097656,
+ 0,
+ 0,
+ 0,
+ 0.13471412658691406,
+ 0,
+ 0.08182525634765625,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.006465911865234375,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.0039052963256835938,
+ 0,
+ -0.0010318756103515625,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.062198638916015625,
+ 0.0000057220458984375,
+ -0.001708984375,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.03947257995605469,
+ 0.1576099395751953,
+ 0,
+ 0,
+ 0.00009822845458984375,
+ -0.25530242919921875,
+ 0,
+ 0.061611175537109375,
+ 0,
+ 0,
+ 0.0061016082763671875,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.079315185546875,
+ 0,
+ 0,
+ 0.04389762878417969,
+ 0.06207084655761719,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.0064945220947265625,
+ -0.009065628051757812,
+ 0,
+ 0,
+ 0.0025882720947265625,
+ 0,
+ 0.0033740997314453125,
+ 0,
+ 0,
+ 0,
+ 0.014276504516601562,
+ -0.011219978332519531,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.023397445678710938,
+ 0,
+ 0,
+ 0,
+ 0.0096435546875,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.007327079772949219,
+ 0,
+ 0.00238037109375,
+ 0,
+ -0.04556846618652344,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0
+ ],
+ [
+ -0.0007219314575195312,
+ 0,
+ 0,
+ -0.001102447509765625,
+ 0,
+ 0,
+ 0,
+ -0.00047397613525390625,
+ 0,
+ -0.02031421661376953,
+ -0.18840694427490234,
+ 0,
+ 0.0009374618530273438,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.0014810562133789062,
+ 0,
+ 0,
+ 0,
+ -0.01897907257080078,
+ -0.012393951416015625,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.007961273193359375,
+ 0,
+ 0.006266593933105469,
+ 0.022070884704589844,
+ 0,
+ 0,
+ -0.00022220611572265625,
+ 0,
+ -0.08554744720458984,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.00211334228515625,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.0006618499755859375,
+ 0,
+ 0,
+ 0,
+ 0.00042629241943359375,
+ 0,
+ 0,
+ -0.0023794174194335938,
+ 0,
+ 0,
+ 0,
+ -0.08295249938964844,
+ 0,
+ 0,
+ 0.02340221405029297,
+ 0.05393028259277344,
+ 0,
+ 0.0030164718627929688,
+ 0,
+ 0,
+ 0.02137470245361328,
+ -0.0648040771484375,
+ 0,
+ 0,
+ -0.0007104873657226562,
+ 0,
+ 0,
+ 0,
+ -0.02891063690185547,
+ 0,
+ 0,
+ -0.0024862289428710938,
+ -0.007077217102050781,
+ -0.004982948303222656,
+ 0.004157066345214844,
+ 0,
+ 0,
+ 0,
+ -0.0009584426879882812,
+ 0,
+ 0,
+ -0.0016260147094726562,
+ -0.03653144836425781,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.004261970520019531,
+ 0,
+ 0.1517467498779297,
+ 0,
+ -0.0017957687377929688,
+ 0,
+ 0.01949596405029297,
+ 0,
+ 0,
+ 0,
+ -0.024643898010253906,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.12193775177001953,
+ 0,
+ 0.01824474334716797,
+ 0.006918907165527344,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.5964584350585938,
+ 0,
+ -0.004886627197265625,
+ -0.0028219223022460938,
+ -0.013730049133300781,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.000370025634765625,
+ 0.11502552032470703,
+ 0,
+ 0
+ ],
+ [
+ 0,
+ 0,
+ 0.0020799636840820312,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.02874469757080078,
+ 0.0672769546508789,
+ 0.31006431579589844,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.014065742492675781,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.42875194549560547,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.037166595458984375,
+ 0,
+ 0,
+ 0.00395965576171875,
+ -0.09044742584228516,
+ 0,
+ 0,
+ 0,
+ 0.16284751892089844,
+ 0,
+ 0.2745513916015625,
+ 0,
+ 0,
+ 0.0013599395751953125,
+ 0,
+ 0,
+ -0.016633033752441406,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.002765655517578125,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.06857013702392578,
+ 0,
+ 0.0030755996704101562,
+ 0,
+ 0.005713462829589844,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.010555267333984375,
+ 0.35628509521484375,
+ 0,
+ 0,
+ 0,
+ -0.3705453872680664,
+ 0,
+ 0.1321268081665039,
+ 0,
+ 0,
+ 0.01171875,
+ 0,
+ 0.006653785705566406,
+ 0,
+ 0,
+ 0,
+ -0.04768085479736328,
+ 0,
+ 0,
+ 0.05365467071533203,
+ 0.10848140716552734,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.0019636154174804688,
+ 0,
+ 0,
+ 0,
+ -0.0038604736328125,
+ -0.00696563720703125,
+ 0,
+ 0,
+ 0,
+ 0.004207611083984375,
+ -0.009866714477539062,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.041828155517578125,
+ 0,
+ 0,
+ 0,
+ 0.03432941436767578,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.02262592315673828,
+ 0.1012563705444336,
+ 0,
+ 0.0032415390014648438,
+ 0,
+ -0.028539657592773438,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.019530296325683594
+ ],
+ [
+ 0.0072574615478515625,
+ 0,
+ 0,
+ 0.0045604705810546875,
+ 0,
+ -0.002410888671875,
+ 0,
+ 0.000942230224609375,
+ 0,
+ -0.028242111206054688,
+ -0.06697559356689453,
+ -0.002197265625,
+ 0,
+ 0,
+ 0,
+ 0.01448822021484375,
+ 0.00038909912109375,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.0072345733642578125,
+ 0.0015048980712890625,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.026609420776367188,
+ -0.007898330688476562,
+ 0.006641387939453125,
+ -0.012470245361328125,
+ 0,
+ 0,
+ 0,
+ -0.0054531097412109375,
+ 0.06533622741699219,
+ 0,
+ 0.00041484832763671875,
+ 0,
+ 0,
+ 0,
+ -0.002368927001953125,
+ 0,
+ 0.04226112365722656,
+ 0,
+ -0.0031299591064453125,
+ -0.0000457763671875,
+ 0.000308990478515625,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.0275726318359375,
+ 0,
+ 0,
+ -0.004794120788574219,
+ 0.01718902587890625,
+ 0,
+ -0.001049041748046875,
+ 0,
+ -0.007875442504882812,
+ 0,
+ -0.00032806396484375,
+ 0,
+ 0.002880096435546875,
+ -0.0073566436767578125,
+ 0,
+ -0.012141227722167969,
+ -0.002796173095703125,
+ 0,
+ 0.0904073715209961,
+ -0.020002365112304688,
+ 0,
+ 0.0006046295166015625,
+ 0.0000095367431640625,
+ 0,
+ 0.09020233154296875,
+ -0.024329185485839844,
+ 0,
+ -0.0007257461547851562,
+ 0.0022792816162109375,
+ 0,
+ 0,
+ 0.0024671554565429688,
+ -0.031095504760742188,
+ 0.029073715209960938,
+ 0,
+ 0,
+ 0.017263412475585938,
+ 0.009774208068847656,
+ 0.01905059814453125,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.007511138916015625,
+ 0,
+ 0,
+ -0.01740264892578125,
+ -0.012363433837890625,
+ -0.007237434387207031,
+ 0.00046825408935546875,
+ 0.015039443969726562,
+ 0,
+ -0.001247406005859375,
+ 0,
+ 0.04442596435546875,
+ 0,
+ 0,
+ 0,
+ 0.0020885467529296875,
+ 0.0009975433349609375,
+ 0.0068645477294921875,
+ 0,
+ 0.0009918212890625,
+ 0.007763862609863281,
+ -0.10830020904541016,
+ 0.002170562744140625,
+ 0.0041522979736328125,
+ 0.0009832382202148438,
+ -0.0055789947509765625,
+ -0.0020475387573242188,
+ 0,
+ 0,
+ -0.46219825744628906,
+ -0.0004138946533203125,
+ 0.022248268127441406,
+ 0,
+ -0.023275375366210938,
+ 0,
+ 0,
+ -0.00007152557373046875,
+ -0.0017099380493164062,
+ 0,
+ 0,
+ 0.028047561645507812,
+ 0,
+ -0.006505012512207031
+ ],
+ [
+ 0.0026121139526367188,
+ 0.0023622512817382812,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.04861927032470703,
+ 0.04393959045410156,
+ 0.24942588806152344,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.0894918441772461,
+ 0.011738777160644531,
+ 0.0023365020751953125,
+ 0,
+ 0.03142070770263672,
+ 0.007035255432128906,
+ 0.013895988464355469,
+ -0.38878440856933594,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.3524456024169922,
+ 0,
+ 0.04943275451660156,
+ 0.07975196838378906,
+ -0.13926124572753906,
+ 0,
+ 0.007584571838378906,
+ 0,
+ 0.10158729553222656,
+ 0,
+ 0.048768043518066406,
+ 0,
+ 0,
+ 0,
+ -0.010777473449707031,
+ 0,
+ -0.02371692657470703,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.0021333694458007812,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.14519309997558594,
+ 0,
+ -0.023756027221679688,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.038219451904296875,
+ 0,
+ 0,
+ 0,
+ -0.07305049896240234,
+ 0.1724720001220703,
+ 0.035521507263183594,
+ 0,
+ 0.026566505432128906,
+ -0.2165508270263672,
+ -0.010828971862792969,
+ 0,
+ 0,
+ 0,
+ 0.06682586669921875,
+ 0.0020055770874023438,
+ 0,
+ 0.05693340301513672,
+ 0,
+ 0,
+ -0.1571969985961914,
+ 0,
+ 0,
+ 0.0275726318359375,
+ 0.09813213348388672,
+ 0,
+ -0.0074253082275390625,
+ 0,
+ 0,
+ -0.00006008148193359375,
+ 0,
+ 0,
+ 0,
+ 0.007464408874511719,
+ -0.011278152465820312,
+ 0,
+ 0,
+ 0.008585929870605469,
+ -0.02161121368408203,
+ -0.05259227752685547,
+ 0.15187358856201172,
+ 0.009034156799316406,
+ 0,
+ 0.01724529266357422,
+ 0.02186107635498047,
+ 0,
+ 0,
+ 0.023595809936523438,
+ 0.0018739700317382812,
+ 0.0014142990112304688,
+ 0.0001888275146484375,
+ 0,
+ 0,
+ 0.14745807647705078,
+ 0,
+ 0,
+ 0,
+ 0.022150039672851562,
+ 0.04754352569580078,
+ 0,
+ 0,
+ 0,
+ 0.12122058868408203,
+ 0.037743568420410156,
+ 0,
+ -0.022559165954589844,
+ 0,
+ -0.07815361022949219,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.1304798126220703
+ ]
+ ]
+ }
+ ],
+ "layout": {
+ "coloraxis": {
+ "cmid": 0,
+ "colorscale": [
+ [
+ 0,
+ "rgb(103,0,31)"
+ ],
+ [
+ 0.1,
+ "rgb(178,24,43)"
+ ],
+ [
+ 0.2,
+ "rgb(214,96,77)"
+ ],
+ [
+ 0.3,
+ "rgb(244,165,130)"
+ ],
+ [
+ 0.4,
+ "rgb(253,219,199)"
+ ],
+ [
+ 0.5,
+ "rgb(247,247,247)"
+ ],
+ [
+ 0.6,
+ "rgb(209,229,240)"
+ ],
+ [
+ 0.7,
+ "rgb(146,197,222)"
+ ],
+ [
+ 0.8,
+ "rgb(67,147,195)"
+ ],
+ [
+ 0.9,
+ "rgb(33,102,172)"
+ ],
+ [
+ 1,
+ "rgb(5,48,97)"
+ ]
+ ]
+ },
+ "template": {
+ "data": {
+ "bar": [
+ {
+ "error_x": {
+ "color": "#2a3f5f"
+ },
+ "error_y": {
+ "color": "#2a3f5f"
+ },
+ "marker": {
+ "line": {
+ "color": "#E5ECF6",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "bar"
+ }
+ ],
+ "barpolar": [
+ {
+ "marker": {
+ "line": {
+ "color": "#E5ECF6",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "barpolar"
+ }
+ ],
+ "carpet": [
+ {
+ "aaxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "white",
+ "linecolor": "white",
+ "minorgridcolor": "white",
+ "startlinecolor": "#2a3f5f"
+ },
+ "baxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "white",
+ "linecolor": "white",
+ "minorgridcolor": "white",
+ "startlinecolor": "#2a3f5f"
+ },
+ "type": "carpet"
+ }
+ ],
+ "choropleth": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "choropleth"
+ }
+ ],
+ "contour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "contour"
+ }
+ ],
+ "contourcarpet": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "contourcarpet"
+ }
+ ],
+ "heatmap": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmap"
+ }
+ ],
+ "heatmapgl": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmapgl"
+ }
+ ],
+ "histogram": [
+ {
+ "marker": {
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "histogram"
+ }
+ ],
+ "histogram2d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2d"
+ }
+ ],
+ "histogram2dcontour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2dcontour"
+ }
+ ],
+ "mesh3d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "mesh3d"
+ }
+ ],
+ "parcoords": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "parcoords"
+ }
+ ],
+ "pie": [
+ {
+ "automargin": true,
+ "type": "pie"
+ }
+ ],
+ "scatter": [
+ {
+ "fillpattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ },
+ "type": "scatter"
+ }
+ ],
+ "scatter3d": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatter3d"
+ }
+ ],
+ "scattercarpet": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattercarpet"
+ }
+ ],
+ "scattergeo": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergeo"
+ }
+ ],
+ "scattergl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergl"
+ }
+ ],
+ "scattermapbox": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattermapbox"
+ }
+ ],
+ "scatterpolar": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolar"
+ }
+ ],
+ "scatterpolargl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolargl"
+ }
+ ],
+ "scatterternary": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterternary"
+ }
+ ],
+ "surface": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "surface"
+ }
+ ],
+ "table": [
+ {
+ "cells": {
+ "fill": {
+ "color": "#EBF0F8"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "header": {
+ "fill": {
+ "color": "#C8D4E3"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "type": "table"
+ }
+ ]
+ },
+ "layout": {
+ "annotationdefaults": {
+ "arrowcolor": "#2a3f5f",
+ "arrowhead": 0,
+ "arrowwidth": 1
+ },
+ "autotypenumbers": "strict",
+ "coloraxis": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "colorscale": {
+ "diverging": [
+ [
+ 0,
+ "#8e0152"
+ ],
+ [
+ 0.1,
+ "#c51b7d"
+ ],
+ [
+ 0.2,
+ "#de77ae"
+ ],
+ [
+ 0.3,
+ "#f1b6da"
+ ],
+ [
+ 0.4,
+ "#fde0ef"
+ ],
+ [
+ 0.5,
+ "#f7f7f7"
+ ],
+ [
+ 0.6,
+ "#e6f5d0"
+ ],
+ [
+ 0.7,
+ "#b8e186"
+ ],
+ [
+ 0.8,
+ "#7fbc41"
+ ],
+ [
+ 0.9,
+ "#4d9221"
+ ],
+ [
+ 1,
+ "#276419"
+ ]
+ ],
+ "sequential": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "sequentialminus": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ]
+ },
+ "colorway": [
+ "#636efa",
+ "#EF553B",
+ "#00cc96",
+ "#ab63fa",
+ "#FFA15A",
+ "#19d3f3",
+ "#FF6692",
+ "#B6E880",
+ "#FF97FF",
+ "#FECB52"
+ ],
+ "font": {
+ "color": "#2a3f5f"
+ },
+ "geo": {
+ "bgcolor": "white",
+ "lakecolor": "white",
+ "landcolor": "#E5ECF6",
+ "showlakes": true,
+ "showland": true,
+ "subunitcolor": "white"
+ },
+ "hoverlabel": {
+ "align": "left"
+ },
+ "hovermode": "closest",
+ "mapbox": {
+ "style": "light"
+ },
+ "paper_bgcolor": "white",
+ "plot_bgcolor": "#E5ECF6",
+ "polar": {
+ "angularaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "bgcolor": "#E5ECF6",
+ "radialaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ }
+ },
+ "scene": {
+ "xaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ },
+ "yaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ },
+ "zaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ }
+ },
+ "shapedefaults": {
+ "line": {
+ "color": "#2a3f5f"
+ }
+ },
+ "ternary": {
+ "aaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "baxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "bgcolor": "#E5ECF6",
+ "caxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ }
+ },
+ "title": {
+ "x": 0.05
+ },
+ "xaxis": {
+ "automargin": true,
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "white",
+ "zerolinewidth": 2
+ },
+ "yaxis": {
+ "automargin": true,
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "white",
+ "zerolinewidth": 2
+ }
+ }
+ },
+ "title": {
+ "text": "Change in logit diff when ablating L5 SAE features for all prompts at pos 10"
+ },
+ "xaxis": {
+ "anchor": "y",
+ "constrain": "domain",
+ "domain": [
+ 0,
+ 1
+ ],
+ "scaleanchor": "y",
+ "title": {
+ "text": "Feature Idx"
+ }
+ },
+ "yaxis": {
+ "anchor": "x",
+ "autorange": "reversed",
+ "constrain": "domain",
+ "domain": [
+ 0,
+ 1
+ ],
+ "title": {
+ "text": "Prompt Idx"
+ }
+ }
+ }
+ },
+ "text/html": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "def ablate_sae_feature(sae_acts, hook, pos, feature_id):\n",
+ " if pos is None:\n",
+ " sae_acts[:, :, feature_id] = 0.\n",
+ " else:\n",
+ " sae_acts[:, pos, feature_id] = 0.\n",
+ " return sae_acts\n",
+ "\n",
+ "layer = 5\n",
+ "sae = hook_name_to_sae[utils.get_act_name('z', layer)]\n",
+ "\n",
+ "logits_with_saes = model.run_with_saes(tokens, saes=sae)\n",
+ "clean_sae_baseline_per_prompt = logits_to_ave_logit_diff(logits_with_saes, answer_tokens, per_prompt=True)\n",
+ "\n",
+ "all_live_features = torch.arange(sae.cfg.d_sae)[live_feature_union.cpu()]\n",
+ "\n",
+ "causal_effects = torch.zeros((len(prompts), all_live_features.shape[0]))\n",
+ "fid_to_idx = {fid.item(): idx for idx, fid in enumerate(all_live_features)}\n",
+ "\n",
+ "\n",
+ "abl_layer, abl_pos = 5, 10\n",
+ "for feature_id in tqdm.tqdm(all_live_features):\n",
+ " feature_id = feature_id.item()\n",
+ " abl_feature_logits = model.run_with_hooks_with_saes(\n",
+ " tokens,\n",
+ " saes=sae,\n",
+ " fwd_hooks=[(utils.get_act_name('z', abl_layer) + \".hook_sae_acts_post\", partial(ablate_sae_feature, pos=abl_pos, feature_id=feature_id))]\n",
+ " ) # [batch, seq, vocab]\n",
+ " \n",
+ " abl_feature_logit_diff = logits_to_ave_logit_diff(abl_feature_logits, answer_tokens, per_prompt=True) # [batch]\n",
+ " causal_effects[:, fid_to_idx[feature_id]] = abl_feature_logit_diff - clean_sae_baseline_per_prompt\n",
+ "\n",
+ "\n",
+ "imshow(causal_effects, title=f\"Change in logit diff when ablating L{abl_layer} SAE features for all prompts at pos {abl_pos}\", xaxis=\"Feature Idx\", yaxis=\"Prompt Idx\", x=list(map(str, all_live_features.tolist())))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Although it's not super clean, we see a few features stand out, where ablating them causes a nontrivial drop in logit diff on multiple prompts: 7515 and 27535 for BABA prompts, with 44256 for ABBA prompts."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Add SAEs"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "While the `run_with_saes` family of methods are great for evaluating SAEs and exploratory analysis, you may want to permanently attach SAEs to your model. You can attach SAEs to any activation with `model.add_sae(sae)`, where sae is a HookedSAE. \n",
+ "\n",
+ "When you add an SAE, it gets stored in `model.acts_to_saes`, a dictionary that maps the activation name to the HookedSAE that is attached. The main benefit of permanently adding SAEs is that we can now just run the model like a normal HookedTransformer (with `forward`, `run_with_cache`, `run_with_hooks`), but some activations will be replaced with the reconstructed activations from the corresponding SAEs.\n",
+ "\n",
+ "I expect this to be most useful when you've already identified a good set of SAEs that you want to use for interpretability, and don't feel like passing in a massive list of saes for every forward pass."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Attached SAEs before add_sae {}\n",
+ "Attached SAEs after add_sae {'blocks.5.attn.hook_z': HookedSAE(\n",
+ " (hook_sae_input): HookPoint()\n",
+ " (hook_sae_acts_pre): HookPoint()\n",
+ " (hook_sae_acts_post): HookPoint()\n",
+ " (hook_sae_recons): HookPoint()\n",
+ " (hook_sae_error): HookPoint()\n",
+ " (hook_sae_output): HookPoint()\n",
+ ")}\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"Attached SAEs before add_sae\", model.acts_to_saes)\n",
+ "layer = 5\n",
+ "sae = hook_name_to_sae[utils.get_act_name('z', layer)]\n",
+ "model.add_sae(sae)\n",
+ "print(\"Attached SAEs after add_sae\", model.acts_to_saes)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now we can just call the standard HookedTransformer forward, and the sae that we added will automatically be spliced in."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Average logit diff with SAEs: 3.6155965328216553\n"
+ ]
+ }
+ ],
+ "source": [
+ "logits_with_saes = model(tokens)\n",
+ "assert not torch.allclose(original_logits, logits_with_saes, atol=1e-4)\n",
+ "\n",
+ "average_logit_diff_with_saes = logits_to_ave_logit_diff(logits_with_saes, answer_tokens)\n",
+ "print(f\"Average logit diff with SAEs: {average_logit_diff_with_saes}\")\n",
+ "per_prompt_diff_with_saes = logits_to_ave_logit_diff(logits_with_saes, answer_tokens, per_prompt=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Run with cache"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Similarly, we can also use `logits, cache = model.run_with_cache(tokens)` directly to cache SAE activations:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.plotly.v1+json": {
+ "config": {
+ "plotlyServerURL": "https://plot.ly"
+ },
+ "data": [
+ {
+ "coloraxis": "coloraxis",
+ "hovertemplate": "Feature Id: %{x}
Prompt: %{y}
color: %{z}",
+ "name": "0",
+ "type": "heatmap",
+ "x": [
+ "46",
+ "345",
+ "702",
+ "1372",
+ "1755",
+ "1965",
+ "2457",
+ "2496",
+ "2646",
+ "2999",
+ "3047",
+ "4569",
+ "5132",
+ "5203",
+ "5508",
+ "5940",
+ "6144",
+ "6371",
+ "6515",
+ "6558",
+ "6812",
+ "7092",
+ "7515",
+ "7907",
+ "8063",
+ "8623",
+ "8737",
+ "8768",
+ "9096",
+ "9102",
+ "9186",
+ "9463",
+ "9746",
+ "9913",
+ "10581",
+ "10894",
+ "12109",
+ "12485",
+ "12764",
+ "12866",
+ "13063",
+ "13624",
+ "13707",
+ "13777",
+ "14844",
+ "15050",
+ "15170",
+ "15696",
+ "16178",
+ "16892",
+ "17156",
+ "17259",
+ "17497",
+ "17854",
+ "18043",
+ "18210",
+ "18318",
+ "18385",
+ "18440",
+ "18920",
+ "19183",
+ "19263",
+ "19442",
+ "19524",
+ "19573",
+ "20838",
+ "21151",
+ "21657",
+ "22108",
+ "23578",
+ "24091",
+ "24217",
+ "25792",
+ "26373",
+ "26410",
+ "27535",
+ "27787",
+ "27811",
+ "27960",
+ "28061",
+ "28241",
+ "28242",
+ "28254",
+ "28349",
+ "28977",
+ "29027",
+ "29482",
+ "29603",
+ "29700",
+ "29822",
+ "32177",
+ "32920",
+ "33320",
+ "33730",
+ "33966",
+ "34177",
+ "34334",
+ "34947",
+ "35403",
+ "35425",
+ "35579",
+ "35665",
+ "35815",
+ "36109",
+ "36172",
+ "36451",
+ "36767",
+ "36917",
+ "38570",
+ "39962",
+ "40409",
+ "40418",
+ "40661",
+ "41162",
+ "41185",
+ "41552",
+ "42024",
+ "42161",
+ "42437",
+ "42577",
+ "42882",
+ "42931",
+ "43035",
+ "43414",
+ "43643",
+ "43662",
+ "44203",
+ "44256",
+ "44452",
+ "44652",
+ "45179",
+ "45814",
+ "45984",
+ "46880",
+ "47117",
+ "47170",
+ "47231",
+ "47313",
+ "47680",
+ "48063",
+ "48703"
+ ],
+ "xaxis": "x",
+ "yaxis": "y",
+ "z": [
+ [
+ 0.23392018675804138,
+ 0,
+ 0,
+ 0.04335343837738037,
+ 0.44275617599487305,
+ 0,
+ 0,
+ 0.07259953022003174,
+ 0,
+ 0.6985604763031006,
+ 1.262436866760254,
+ 0,
+ 0.04656928777694702,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.45666736364364624,
+ 0.10434150695800781,
+ 0.30980953574180603,
+ 0.3319076895713806,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1.7836596965789795,
+ 0,
+ 0,
+ 0.142583429813385,
+ 0.046830952167510986,
+ 0.3180348575115204,
+ 0.2927079200744629,
+ 0.12267106771469116,
+ 2.5688514709472656,
+ 0.2917236089706421,
+ 0.12333670258522034,
+ 0,
+ 0.1778419017791748,
+ 0,
+ 0.023626387119293213,
+ 0.02943563461303711,
+ 0,
+ 0.048882365226745605,
+ 0.13625454902648926,
+ 0,
+ 0,
+ 0.2634885013103485,
+ 0,
+ 0,
+ 0,
+ 0.21662655472755432,
+ 0,
+ 0,
+ 0,
+ 0.06997489929199219,
+ 0.006345987319946289,
+ 0,
+ 0.16112494468688965,
+ 0.4190089702606201,
+ 0,
+ 2.3819468021392822,
+ 1.0431660413742065,
+ 0,
+ 0.08364987373352051,
+ 0,
+ 0,
+ 0.3451769948005676,
+ 0.7391350865364075,
+ 0.4456520080566406,
+ 0.0019606351852416992,
+ 0.39914217591285706,
+ 0,
+ 0,
+ 0,
+ 0.29958274960517883,
+ 0.44243645668029785,
+ 0,
+ 0.1259920299053192,
+ 0.8349504470825195,
+ 0.37993764877319336,
+ 0.2633737325668335,
+ 0.08324140310287476,
+ 0,
+ 0,
+ 0.10421907901763916,
+ 0,
+ 0,
+ 0,
+ 0.36972635984420776,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.5578295588493347,
+ 0,
+ 0.9233021140098572,
+ 0,
+ 0.10010790824890137,
+ 0,
+ 0.45082613825798035,
+ 0,
+ 0,
+ 0,
+ 0.21043556928634644,
+ 0.12981292605400085,
+ 0.11557984352111816,
+ 0,
+ 0,
+ 0.17571094632148743,
+ 0.2823787331581116,
+ 0.1122598648071289,
+ 0,
+ 0,
+ 0.012049257755279541,
+ 0,
+ 0,
+ 0,
+ 2.417463541030884,
+ 0.0547795295715332,
+ 0.05216425657272339,
+ 0,
+ 0.6592545509338379,
+ 0.003663182258605957,
+ 0,
+ 0,
+ 0.04937589168548584,
+ 0.025814831256866455,
+ 0,
+ 0.8019273281097412,
+ 0,
+ 0.10218703746795654
+ ],
+ [
+ 0,
+ 0,
+ 0.3230956792831421,
+ 0,
+ 0,
+ 0,
+ 0.026041746139526367,
+ 0.31818556785583496,
+ 0,
+ 0.4900796413421631,
+ 0.04911249876022339,
+ 0,
+ 0,
+ 0.07309412956237793,
+ 0.08089971542358398,
+ 0.17180073261260986,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 2.3956947326660156,
+ 0,
+ 0,
+ 0.15781426429748535,
+ 0,
+ 0.5073252320289612,
+ 0.21765804290771484,
+ 0,
+ 0,
+ 1.618570327758789,
+ 0,
+ 0.22485831379890442,
+ 0.0830467939376831,
+ 0.7055595517158508,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.23371747136116028,
+ 0,
+ 0,
+ 0.6983060240745544,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.30831730365753174,
+ 0,
+ 0.417669415473938,
+ 0.05292201042175293,
+ 0,
+ 0,
+ 0,
+ 1.3391070365905762,
+ 0,
+ 0.41352108120918274,
+ 0,
+ 0,
+ 0,
+ 0.037178993225097656,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.2702980041503906,
+ 0,
+ 0,
+ 0.18745100498199463,
+ 1.3330132961273193,
+ 0.5793700814247131,
+ 0.33893001079559326,
+ 0,
+ 0.11196631193161011,
+ 1.720167636871338,
+ 0.17581266164779663,
+ 0.42567259073257446,
+ 0,
+ 0,
+ 0.23682871460914612,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1.8280882835388184,
+ 0.1617840826511383,
+ 0,
+ 0.13557660579681396,
+ 0.5832244157791138,
+ 0,
+ 0,
+ 0.03256487846374512,
+ 0,
+ 0,
+ 0.03892314434051514,
+ 0,
+ 0,
+ 0,
+ 0.30978846549987793,
+ 0,
+ 0,
+ 0.36915141344070435,
+ 0,
+ 0.5477294325828552,
+ 0,
+ 0,
+ 0.06339260935783386,
+ 0.1851767599582672,
+ 0.5839155912399292,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.12337607145309448,
+ 0,
+ 0,
+ 1.0378936529159546,
+ 0,
+ 0,
+ 0,
+ 0.01616498827934265,
+ 0.20259439945220947,
+ 0,
+ 0,
+ 0.3087460398674011,
+ 0.618510365486145,
+ 0.24435847997665405,
+ 0,
+ 0.4668591022491455,
+ 0.1788468360900879,
+ 0.200361967086792,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.7064645290374756
+ ],
+ [
+ 0.2921750843524933,
+ 0,
+ 0,
+ 0.2805737257003784,
+ 0,
+ 0,
+ 0,
+ 0.3694216012954712,
+ 0,
+ 1.1156601905822754,
+ 1.2807728052139282,
+ 0,
+ 0.09175515174865723,
+ 0,
+ 0,
+ 0,
+ 0.10458803176879883,
+ 0,
+ 0.021218180656433105,
+ 0,
+ 0,
+ 0.01699376106262207,
+ 0.09601330757141113,
+ 0.054788172245025635,
+ 0,
+ 0.030488133430480957,
+ 0.021512210369110107,
+ 0.2717320919036865,
+ 0.29357004165649414,
+ 0.6420693397521973,
+ 0.05249035358428955,
+ 0,
+ 0.06201601028442383,
+ 0,
+ 0.4122554659843445,
+ 1.821354866027832,
+ 0.01981794834136963,
+ 0,
+ 0.14063221216201782,
+ 0.05093127489089966,
+ 0,
+ 0.32148706912994385,
+ 0.15257668495178223,
+ 2.418062686920166,
+ 0.17348229885101318,
+ 0.08421656489372253,
+ 0,
+ 0.4551248550415039,
+ 0,
+ 0.015430927276611328,
+ 0.24434363842010498,
+ 0,
+ 0.06232607364654541,
+ 0,
+ 0.04422914981842041,
+ 0.8720088005065918,
+ 0.3721686899662018,
+ 0,
+ 0,
+ 0,
+ 0.340120404958725,
+ 0,
+ 0,
+ 0.07813769578933716,
+ 0,
+ 0.0882720947265625,
+ 0.19706517457962036,
+ 0.4056885242462158,
+ 0.19529414176940918,
+ 0,
+ 2.928431510925293,
+ 1.1402223110198975,
+ 0,
+ 0.026796698570251465,
+ 0.0033188462257385254,
+ 0,
+ 0.3370524048805237,
+ 0.47657889127731323,
+ 0,
+ 0.10358679294586182,
+ 0.27619925141334534,
+ 0,
+ 0,
+ 0,
+ 0.40909066796302795,
+ 0.2599871754646301,
+ 0,
+ 0.275011271238327,
+ 0.5349323749542236,
+ 0.07697033882141113,
+ 0.17431437969207764,
+ 0,
+ 0,
+ 0,
+ 0.09000074863433838,
+ 0,
+ 0,
+ 0,
+ 0.276567280292511,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.5655339360237122,
+ 0,
+ 0.8971189856529236,
+ 0,
+ 0.5199201107025146,
+ 0,
+ 0.6301102638244629,
+ 0.013657361268997192,
+ 0.04469645023345947,
+ 0.038062095642089844,
+ 0.4305816888809204,
+ 0,
+ 0.04173767566680908,
+ 0,
+ 0,
+ 0,
+ 0.8985729217529297,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.08318889141082764,
+ 0.006362795829772949,
+ 2.069222927093506,
+ 0,
+ 0.7068352103233337,
+ 0,
+ 0.8527798652648926,
+ 0,
+ 0,
+ 0.4707651138305664,
+ 0,
+ 0,
+ 0,
+ 0.7790955305099487,
+ 0.021227538585662842,
+ 0.01846003532409668
+ ],
+ [
+ 0,
+ 0,
+ 0.2200499176979065,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.2433047890663147,
+ 0.2504638135433197,
+ 0.712148904800415,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.1410943865776062,
+ 0,
+ 0,
+ 0,
+ 0.11292147636413574,
+ 0,
+ 0,
+ 2.360842704772949,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1.2830760478973389,
+ 0,
+ 0,
+ 0,
+ 0.6308119893074036,
+ 0,
+ 0.4040885865688324,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.5223236680030823,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.23784160614013672,
+ 0,
+ 0.04762387275695801,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.5758676528930664,
+ 0.01025208830833435,
+ 0.24556085467338562,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1.1104614734649658,
+ 1.079118251800537,
+ 0,
+ 0,
+ 0.14462929964065552,
+ 1.9186956882476807,
+ 0,
+ 0.30735498666763306,
+ 0,
+ 0,
+ 0.07669633626937866,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1.3975048065185547,
+ 0,
+ 0,
+ 0.3461639881134033,
+ 0.5062156915664673,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.19610454142093658,
+ 0.218009352684021,
+ 0,
+ 0,
+ 0.07953745126724243,
+ 0,
+ 0.1416093111038208,
+ 0,
+ 0,
+ 0,
+ 0.18305465579032898,
+ 0.10310900211334229,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.45315277576446533,
+ 0,
+ 0,
+ 0,
+ 0.09076884388923645,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.04246491193771362,
+ 0,
+ 0.1807355284690857,
+ 0,
+ 0.3002055883407593,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0
+ ],
+ [
+ 0.02005404233932495,
+ 0,
+ 0,
+ 0.07601284980773926,
+ 0,
+ 0,
+ 0,
+ 0.012166053056716919,
+ 0,
+ 1.0662918090820312,
+ 1.4810535907745361,
+ 0,
+ 0.014786958694458008,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.1491186022758484,
+ 0,
+ 0,
+ 0,
+ 0.38226866722106934,
+ 0.43110355734825134,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1.6819074153900146,
+ 0,
+ 0.7939910888671875,
+ 0.28643298149108887,
+ 0,
+ 0,
+ 0.011532962322235107,
+ 0,
+ 1.2869157791137695,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.16446048021316528,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.03375712037086487,
+ 0,
+ 0,
+ 0,
+ 0.1915181577205658,
+ 0,
+ 0,
+ 0.10225892066955566,
+ 0,
+ 0,
+ 0,
+ 0.7338485717773438,
+ 0,
+ 0,
+ 1.3715617656707764,
+ 1.6115869283676147,
+ 0,
+ 0.7128411531448364,
+ 0,
+ 0,
+ 0.2161598801612854,
+ 0.5098914504051208,
+ 0,
+ 0,
+ 0.04084053635597229,
+ 0,
+ 0,
+ 0,
+ 0.17978456616401672,
+ 0,
+ 0,
+ 0.1365671455860138,
+ 0.27122950553894043,
+ 0.2945059537887573,
+ 0.2824629545211792,
+ 0,
+ 0,
+ 0,
+ 0.0464092493057251,
+ 0,
+ 0,
+ 0.04672741889953613,
+ 0.6179839968681335,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.045598745346069336,
+ 0,
+ 1.0172381401062012,
+ 0,
+ 0.07242608070373535,
+ 0,
+ 0.5165215730667114,
+ 0,
+ 0,
+ 0,
+ 0.5004003047943115,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.3409433960914612,
+ 0,
+ 0.1579979658126831,
+ 0.09901612997055054,
+ 0,
+ 0,
+ 0,
+ 0,
+ 2.413944721221924,
+ 0,
+ 0.20971286296844482,
+ 0.07062971591949463,
+ 0.26070594787597656,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.020640969276428223,
+ 1.0534553527832031,
+ 0,
+ 0
+ ],
+ [
+ 0,
+ 0,
+ 0.046907246112823486,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.20885008573532104,
+ 0.25957152247428894,
+ 1.0767037868499756,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.23976856470108032,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 2.762990951538086,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.29466086626052856,
+ 0,
+ 0,
+ 0.09433537721633911,
+ 1.2446393966674805,
+ 0,
+ 0,
+ 0,
+ 0.6668079495429993,
+ 0,
+ 0.7482341527938843,
+ 0,
+ 0,
+ 0.005075186491012573,
+ 0,
+ 0,
+ 0.4049275517463684,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.09314888715744019,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.4028928279876709,
+ 0,
+ 0.3687801659107208,
+ 0,
+ 0.10555410385131836,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1.066054105758667,
+ 1.4596349000930786,
+ 0,
+ 0,
+ 0,
+ 2.3358588218688965,
+ 0,
+ 0.5390753149986267,
+ 0,
+ 0,
+ 0.12931063771247864,
+ 0,
+ 0.10619288682937622,
+ 0,
+ 0,
+ 0,
+ 0.41271400451660156,
+ 0,
+ 0,
+ 0.23865878582000732,
+ 0.7501264810562134,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.2947666645050049,
+ 0,
+ 0,
+ 0,
+ 0.05958199501037598,
+ 0.20450782775878906,
+ 0,
+ 0,
+ 0,
+ 0.13838836550712585,
+ 0.13835513591766357,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.45820748805999756,
+ 0,
+ 0,
+ 0,
+ 0.19962045550346375,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.20416772365570068,
+ 0.46223968267440796,
+ 0,
+ 0.22815394401550293,
+ 0,
+ 0.1125795841217041,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.3023688793182373
+ ],
+ [
+ 0.28365251421928406,
+ 0,
+ 0,
+ 0.41595208644866943,
+ 0,
+ 0.15376341342926025,
+ 0,
+ 0.22517156600952148,
+ 0,
+ 0.7871096134185791,
+ 1.3084614276885986,
+ 0.2012956142425537,
+ 0,
+ 0,
+ 0,
+ 0.2532406449317932,
+ 0.009012699127197266,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.7235959768295288,
+ 0.021468758583068848,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.8338297009468079,
+ 0.3022422790527344,
+ 0.6702529191970825,
+ 0.5416026711463928,
+ 0,
+ 0,
+ 0,
+ 0.2034381628036499,
+ 1.9052581787109375,
+ 0,
+ 0.23752644658088684,
+ 0,
+ 0,
+ 0,
+ 0.8470145463943481,
+ 0,
+ 2.820002555847168,
+ 0,
+ 0.16275432705879211,
+ 0.06714236736297607,
+ 0.12017238140106201,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.486280620098114,
+ 0,
+ 0,
+ 0.3096342086791992,
+ 0.3064201772212982,
+ 0,
+ 0.09773910045623779,
+ 0,
+ 0.4613642394542694,
+ 0,
+ 0.021892428398132324,
+ 0,
+ 0.18887782096862793,
+ 0.18538141250610352,
+ 0,
+ 0.42975664138793945,
+ 0.9873132705688477,
+ 0,
+ 2.163774013519287,
+ 1.2928048372268677,
+ 0,
+ 0.2320784330368042,
+ 0.0062233805656433105,
+ 0,
+ 1.2478563785552979,
+ 0.5479208827018738,
+ 0,
+ 0.06501156091690063,
+ 0.3741762936115265,
+ 0,
+ 0,
+ 0.31712013483047485,
+ 0.5228050947189331,
+ 0.3981531858444214,
+ 0,
+ 0,
+ 0.4854400157928467,
+ 0.3341655731201172,
+ 0.39207732677459717,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.3316766023635864,
+ 0,
+ 0,
+ 0.33435362577438354,
+ 0.1380615234375,
+ 0.7183249592781067,
+ 0.041296958923339844,
+ 0.7634149193763733,
+ 0,
+ 0.4028007984161377,
+ 0,
+ 0.6915435791015625,
+ 0,
+ 0,
+ 0,
+ 0.3831353187561035,
+ 0.05798754096031189,
+ 0.15244710445404053,
+ 0,
+ 0.03230410814285278,
+ 0.2039397656917572,
+ 0.6142292022705078,
+ 0.15542924404144287,
+ 0.07628917694091797,
+ 0.0812273919582367,
+ 0.15177401900291443,
+ 0.10224854946136475,
+ 0,
+ 0,
+ 2.8106069564819336,
+ 0.3994237184524536,
+ 0.6397127509117126,
+ 0,
+ 0.8949670791625977,
+ 0,
+ 0,
+ 0.18832790851593018,
+ 0.1450880765914917,
+ 0,
+ 0,
+ 0.6900937557220459,
+ 0,
+ 0.14745783805847168
+ ],
+ [
+ 0.12055802345275879,
+ 0.023864269256591797,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.3327372670173645,
+ 0.1789897382259369,
+ 1.1445300579071045,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.4361664652824402,
+ 0.09996795654296875,
+ 0.10051405429840088,
+ 0,
+ 0.4030296802520752,
+ 0.06672021746635437,
+ 0.6339577436447144,
+ 3.3947582244873047,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.9711236357688904,
+ 0,
+ 0.38066884875297546,
+ 0.4158353805541992,
+ 1.5344438552856445,
+ 0,
+ 0.19816407561302185,
+ 0,
+ 0.6646860241889954,
+ 0,
+ 0.16733816266059875,
+ 0,
+ 0,
+ 0,
+ 0.322623074054718,
+ 0,
+ 0.7314171195030212,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.043955981731414795,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.9436180591583252,
+ 0,
+ 0.29259607195854187,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.1570979356765747,
+ 0,
+ 0,
+ 0,
+ 1.1782727241516113,
+ 1.2431498765945435,
+ 0.32878363132476807,
+ 0,
+ 0.419150173664093,
+ 2.3304405212402344,
+ 0.8566346764564514,
+ 0,
+ 0,
+ 0,
+ 0.3841046392917633,
+ 0.10476112365722656,
+ 0,
+ 0.18140661716461182,
+ 0,
+ 0,
+ 0.6665420532226562,
+ 0,
+ 0,
+ 0.22877633571624756,
+ 0.9225524663925171,
+ 0,
+ 0.15886402130126953,
+ 0,
+ 0,
+ 0.02094721794128418,
+ 0,
+ 0,
+ 0,
+ 0.3046541213989258,
+ 0.2845715284347534,
+ 0,
+ 0,
+ 0.4244043231010437,
+ 0.164473295211792,
+ 0.30073386430740356,
+ 0.7123112678527832,
+ 0.1730642318725586,
+ 0,
+ 0.4041661322116852,
+ 0.39166414737701416,
+ 0,
+ 0,
+ 0.2103893756866455,
+ 0.007811635732650757,
+ 0.010994672775268555,
+ 0.03914850950241089,
+ 0,
+ 0,
+ 0.8430832624435425,
+ 0,
+ 0,
+ 0,
+ 0.15830591320991516,
+ 0.29398930072784424,
+ 0,
+ 0,
+ 0,
+ 0.5994948148727417,
+ 0.1704254150390625,
+ 0,
+ 0.4673898220062256,
+ 0,
+ 0.3204514980316162,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.8447363376617432
+ ]
+ ]
+ }
+ ],
+ "layout": {
+ "coloraxis": {
+ "cmid": 0,
+ "colorscale": [
+ [
+ 0,
+ "rgb(103,0,31)"
+ ],
+ [
+ 0.1,
+ "rgb(178,24,43)"
+ ],
+ [
+ 0.2,
+ "rgb(214,96,77)"
+ ],
+ [
+ 0.3,
+ "rgb(244,165,130)"
+ ],
+ [
+ 0.4,
+ "rgb(253,219,199)"
+ ],
+ [
+ 0.5,
+ "rgb(247,247,247)"
+ ],
+ [
+ 0.6,
+ "rgb(209,229,240)"
+ ],
+ [
+ 0.7,
+ "rgb(146,197,222)"
+ ],
+ [
+ 0.8,
+ "rgb(67,147,195)"
+ ],
+ [
+ 0.9,
+ "rgb(33,102,172)"
+ ],
+ [
+ 1,
+ "rgb(5,48,97)"
+ ]
+ ]
+ },
+ "template": {
+ "data": {
+ "bar": [
+ {
+ "error_x": {
+ "color": "#2a3f5f"
+ },
+ "error_y": {
+ "color": "#2a3f5f"
+ },
+ "marker": {
+ "line": {
+ "color": "#E5ECF6",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "bar"
+ }
+ ],
+ "barpolar": [
+ {
+ "marker": {
+ "line": {
+ "color": "#E5ECF6",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "barpolar"
+ }
+ ],
+ "carpet": [
+ {
+ "aaxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "white",
+ "linecolor": "white",
+ "minorgridcolor": "white",
+ "startlinecolor": "#2a3f5f"
+ },
+ "baxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "white",
+ "linecolor": "white",
+ "minorgridcolor": "white",
+ "startlinecolor": "#2a3f5f"
+ },
+ "type": "carpet"
+ }
+ ],
+ "choropleth": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "choropleth"
+ }
+ ],
+ "contour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "contour"
+ }
+ ],
+ "contourcarpet": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "contourcarpet"
+ }
+ ],
+ "heatmap": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmap"
+ }
+ ],
+ "heatmapgl": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmapgl"
+ }
+ ],
+ "histogram": [
+ {
+ "marker": {
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "histogram"
+ }
+ ],
+ "histogram2d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2d"
+ }
+ ],
+ "histogram2dcontour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2dcontour"
+ }
+ ],
+ "mesh3d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "mesh3d"
+ }
+ ],
+ "parcoords": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "parcoords"
+ }
+ ],
+ "pie": [
+ {
+ "automargin": true,
+ "type": "pie"
+ }
+ ],
+ "scatter": [
+ {
+ "fillpattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ },
+ "type": "scatter"
+ }
+ ],
+ "scatter3d": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatter3d"
+ }
+ ],
+ "scattercarpet": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattercarpet"
+ }
+ ],
+ "scattergeo": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergeo"
+ }
+ ],
+ "scattergl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergl"
+ }
+ ],
+ "scattermapbox": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattermapbox"
+ }
+ ],
+ "scatterpolar": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolar"
+ }
+ ],
+ "scatterpolargl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolargl"
+ }
+ ],
+ "scatterternary": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterternary"
+ }
+ ],
+ "surface": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "surface"
+ }
+ ],
+ "table": [
+ {
+ "cells": {
+ "fill": {
+ "color": "#EBF0F8"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "header": {
+ "fill": {
+ "color": "#C8D4E3"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "type": "table"
+ }
+ ]
+ },
+ "layout": {
+ "annotationdefaults": {
+ "arrowcolor": "#2a3f5f",
+ "arrowhead": 0,
+ "arrowwidth": 1
+ },
+ "autotypenumbers": "strict",
+ "coloraxis": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "colorscale": {
+ "diverging": [
+ [
+ 0,
+ "#8e0152"
+ ],
+ [
+ 0.1,
+ "#c51b7d"
+ ],
+ [
+ 0.2,
+ "#de77ae"
+ ],
+ [
+ 0.3,
+ "#f1b6da"
+ ],
+ [
+ 0.4,
+ "#fde0ef"
+ ],
+ [
+ 0.5,
+ "#f7f7f7"
+ ],
+ [
+ 0.6,
+ "#e6f5d0"
+ ],
+ [
+ 0.7,
+ "#b8e186"
+ ],
+ [
+ 0.8,
+ "#7fbc41"
+ ],
+ [
+ 0.9,
+ "#4d9221"
+ ],
+ [
+ 1,
+ "#276419"
+ ]
+ ],
+ "sequential": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "sequentialminus": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ]
+ },
+ "colorway": [
+ "#636efa",
+ "#EF553B",
+ "#00cc96",
+ "#ab63fa",
+ "#FFA15A",
+ "#19d3f3",
+ "#FF6692",
+ "#B6E880",
+ "#FF97FF",
+ "#FECB52"
+ ],
+ "font": {
+ "color": "#2a3f5f"
+ },
+ "geo": {
+ "bgcolor": "white",
+ "lakecolor": "white",
+ "landcolor": "#E5ECF6",
+ "showlakes": true,
+ "showland": true,
+ "subunitcolor": "white"
+ },
+ "hoverlabel": {
+ "align": "left"
+ },
+ "hovermode": "closest",
+ "mapbox": {
+ "style": "light"
+ },
+ "paper_bgcolor": "white",
+ "plot_bgcolor": "#E5ECF6",
+ "polar": {
+ "angularaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "bgcolor": "#E5ECF6",
+ "radialaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ }
+ },
+ "scene": {
+ "xaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ },
+ "yaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ },
+ "zaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ }
+ },
+ "shapedefaults": {
+ "line": {
+ "color": "#2a3f5f"
+ }
+ },
+ "ternary": {
+ "aaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "baxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "bgcolor": "#E5ECF6",
+ "caxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ }
+ },
+ "title": {
+ "x": 0.05
+ },
+ "xaxis": {
+ "automargin": true,
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "white",
+ "zerolinewidth": 2
+ },
+ "yaxis": {
+ "automargin": true,
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "white",
+ "zerolinewidth": 2
+ }
+ }
+ },
+ "title": {
+ "text": "Activations of Live SAE features at L5 S2 position per prompt"
+ },
+ "xaxis": {
+ "anchor": "y",
+ "constrain": "domain",
+ "domain": [
+ 0,
+ 1
+ ],
+ "scaleanchor": "y",
+ "title": {
+ "text": "Feature Id"
+ }
+ },
+ "yaxis": {
+ "anchor": "x",
+ "autorange": "reversed",
+ "constrain": "domain",
+ "domain": [
+ 0,
+ 1
+ ],
+ "title": {
+ "text": "Prompt"
+ }
+ }
+ }
+ },
+ "text/html": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "layer = 5\n",
+ "_, cache = model.run_with_cache(tokens)\n",
+ "s2_pos = 10\n",
+ "sae_acts = cache[utils.get_act_name('z', layer) + \".hook_sae_acts_post\"][:, s2_pos, :]\n",
+ "\n",
+ "live_feature_mask = sae_acts > 0\n",
+ "live_feature_union = live_feature_mask.any(dim=0)\n",
+ "\n",
+ "imshow(\n",
+ " sae_acts[:, live_feature_union],\n",
+ " title = \"Activations of Live SAE features at L5 S2 position per prompt\",\n",
+ " xaxis=\"Feature Id\", yaxis=\"Prompt\",\n",
+ " x=list(map(str, live_feature_union.nonzero().flatten().tolist())),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Run with hooks"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Finally we can also use `run_with_hooks` and intervene on the added SAE's activations. To show a more complicated intervention, we'll try path patching the feature from the S-inhibition head's value vectors."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model.set_use_split_qkv_input(True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 141/141 [00:05<00:00, 26.94it/s]\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.plotly.v1+json": {
+ "config": {
+ "plotlyServerURL": "https://plot.ly"
+ },
+ "data": [
+ {
+ "coloraxis": "coloraxis",
+ "hovertemplate": "Feature Id: %{x}<br>Prompt Idx: %{y}<br>color: %{z}<extra></extra>",
+ "name": "0",
+ "type": "heatmap",
+ "x": [
+ "46",
+ "345",
+ "702",
+ "1372",
+ "1755",
+ "1965",
+ "2457",
+ "2496",
+ "2646",
+ "2999",
+ "3047",
+ "4569",
+ "5132",
+ "5203",
+ "5508",
+ "5940",
+ "6144",
+ "6371",
+ "6515",
+ "6558",
+ "6812",
+ "7092",
+ "7515",
+ "7907",
+ "8063",
+ "8623",
+ "8737",
+ "8768",
+ "9096",
+ "9102",
+ "9186",
+ "9463",
+ "9746",
+ "9913",
+ "10581",
+ "10894",
+ "12109",
+ "12485",
+ "12764",
+ "12866",
+ "13063",
+ "13624",
+ "13707",
+ "13777",
+ "14844",
+ "15050",
+ "15170",
+ "15696",
+ "16178",
+ "16892",
+ "17156",
+ "17259",
+ "17497",
+ "17854",
+ "18043",
+ "18210",
+ "18318",
+ "18385",
+ "18440",
+ "18920",
+ "19183",
+ "19263",
+ "19442",
+ "19524",
+ "19573",
+ "20838",
+ "21151",
+ "21657",
+ "22108",
+ "23578",
+ "24091",
+ "24217",
+ "25792",
+ "26373",
+ "26410",
+ "27535",
+ "27787",
+ "27811",
+ "27960",
+ "28061",
+ "28241",
+ "28242",
+ "28254",
+ "28349",
+ "28977",
+ "29027",
+ "29482",
+ "29603",
+ "29700",
+ "29822",
+ "32177",
+ "32920",
+ "33320",
+ "33730",
+ "33966",
+ "34177",
+ "34334",
+ "34947",
+ "35403",
+ "35425",
+ "35579",
+ "35665",
+ "35815",
+ "36109",
+ "36172",
+ "36451",
+ "36767",
+ "36917",
+ "38570",
+ "39962",
+ "40409",
+ "40418",
+ "40661",
+ "41162",
+ "41185",
+ "41552",
+ "42024",
+ "42161",
+ "42437",
+ "42577",
+ "42882",
+ "42931",
+ "43035",
+ "43414",
+ "43643",
+ "43662",
+ "44203",
+ "44256",
+ "44452",
+ "44652",
+ "45179",
+ "45814",
+ "45984",
+ "46880",
+ "47117",
+ "47170",
+ "47231",
+ "47313",
+ "47680",
+ "48063",
+ "48703"
+ ],
+ "xaxis": "x",
+ "yaxis": "y",
+ "z": [
+ [
+ 0.0005645751953125,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ 0.000339508056640625,
+ -0.003261566162109375,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ 0.00069427490234375,
+ 0.0000057220458984375,
+ 0.0016155242919921875,
+ -0.09088897705078125,
+ 0.0000057220458984375,
+ 0.00011444091796875,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ -0.009515762329101562,
+ -0.0022525787353515625,
+ 0.0031604766845703125,
+ -0.0020704269409179688,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ -0.013577461242675781,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ -0.0017032623291015625,
+ 0.0002880096435546875,
+ -0.00020503997802734375,
+ -0.0016231536865234375,
+ 0.00037860870361328125,
+ -0.0098114013671875,
+ -0.002185821533203125,
+ -0.0008878707885742188,
+ 0.0000057220458984375,
+ 0.0002346038818359375,
+ 0.0000057220458984375,
+ -0.000354766845703125,
+ 0.00036334991455078125,
+ 0.0000057220458984375,
+ -0.000988006591796875,
+ -0.00044918060302734375,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ 0.005593299865722656,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ -0.005214691162109375,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ -0.000789642333984375,
+ 0.00010585784912109375,
+ 0.0000057220458984375,
+ -0.0059051513671875,
+ 0.0011091232299804688,
+ 0.0000057220458984375,
+ 0.026823997497558594,
+ 0.019052505493164062,
+ 0.0000057220458984375,
+ 0.0000152587890625,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ 0.0033597946166992188,
+ -0.020666122436523438,
+ -0.0041141510009765625,
+ -0.000011444091796875,
+ 0.00130462646484375,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ -0.01567840576171875,
+ 0.006500244140625,
+ 0.0000057220458984375,
+ 0.002086639404296875,
+ 0.00576019287109375,
+ 0.004245758056640625,
+ 0.006832122802734375,
+ 0.0006284713745117188,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ -0.0009737014770507812,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ -0.0040988922119140625,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ -0.003326416015625,
+ 0.0000057220458984375,
+ 0.020755767822265625,
+ 0.0000057220458984375,
+ -0.0008373260498046875,
+ 0.0000057220458984375,
+ 0.007825851440429688,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ -0.002574920654296875,
+ 0.00151824951171875,
+ -0.00008678436279296875,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ 0.001171112060546875,
+ -0.02040863037109375,
+ -0.0014247894287109375,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ 0.00003814697265625,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ -0.3322334289550781,
+ 0.000579833984375,
+ 0.001293182373046875,
+ 0.0000057220458984375,
+ 0.0066661834716796875,
+ 0.0000171661376953125,
+ 0.0000057220458984375,
+ 0.0000057220458984375,
+ 0.0005435943603515625,
+ 0.00032806396484375,
+ 0.0000057220458984375,
+ 0.023120880126953125,
+ 0.0000057220458984375,
+ -0.0017566680908203125
+ ],
+ [
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.0040073394775390625,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.0022735595703125,
+ -0.0012683868408203125,
+ 0.00000762939453125,
+ 0.017993927001953125,
+ 0.011075973510742188,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ -0.001407623291015625,
+ -0.000270843505859375,
+ -0.010431289672851562,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ -0.6347770690917969,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.0005435943603515625,
+ 0.00000762939453125,
+ 0.09274864196777344,
+ 0.008495330810546875,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ -0.08464431762695312,
+ 0.00000762939453125,
+ 0.028835296630859375,
+ 0.01250457763671875,
+ 0.029806137084960938,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.012714385986328125,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ -0.0004444122314453125,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.003757476806640625,
+ 0.00000762939453125,
+ 0.0025272369384765625,
+ 0.0013427734375,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.07260704040527344,
+ 0.00000762939453125,
+ 0.01149749755859375,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ -0.000213623046875,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ -0.016370773315429688,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ -0.00792694091796875,
+ 0.03365135192871094,
+ -0.004932403564453125,
+ 0.005069732666015625,
+ 0.00000762939453125,
+ 0.0031223297119140625,
+ -0.5932121276855469,
+ -0.0007534027099609375,
+ 0.05148506164550781,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.014024734497070312,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ -0.11317634582519531,
+ -0.0026416778564453125,
+ 0.00000762939453125,
+ -0.006038665771484375,
+ 0.00672149658203125,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.000064849853515625,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.0005397796630859375,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ -0.0024967193603515625,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.016933441162109375,
+ 0.00000762939453125,
+ -0.0049343109130859375,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ -0.00244140625,
+ -0.00624847412109375,
+ 0.018770217895507812,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ -0.001132965087890625,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.1962738037109375,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ -0.0005283355712890625,
+ 0.0070934295654296875,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.10946464538574219,
+ 0.05410957336425781,
+ -0.0026397705078125,
+ 0.00000762939453125,
+ 0.005107879638671875,
+ 0.006359100341796875,
+ -0.04090118408203125,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.00000762939453125,
+ 0.06792449951171875
+ ],
+ [
+ 0.0032672882080078125,
+ 0.00000667572021484375,
+ 0.00000667572021484375,
+ 0.0026044845581054688,
+ 0.00000667572021484375,
+ 0.00000667572021484375,
+ 0.00000667572021484375,
+ -0.0013751983642578125,
+ 0.00000667572021484375,
+ 0.018096923828125,
+ -0.29747962951660156,
+ 0.00000667572021484375,
+ 0.00159454345703125,
+ 0.00000667572021484375,
+ 0.00000667572021484375,
+ 0.00000667572021484375,
+ -0.00185394287109375,
+ 0.00000667572021484375,
+ 0.000064849853515625,
+ 0.00000667572021484375,
+ 0.00000667572021484375,
+ 0.0004253387451171875,
+ 0.02138805389404297,
+ 0.000370025634765625,
+ 0.00000667572021484375,
+ -0.0002880096435546875,
+ 0.000560760498046875,
+ -0.03230476379394531,
+ -0.02060699462890625,
+ 0.020964622497558594,
+ -0.0022487640380859375,
+ 0.00000667572021484375,
+ 0.001964569091796875,
+ 0.00000667572021484375,
+ -0.07773113250732422,
+ -0.042862892150878906,
+ 0.00027751922607421875,
+ 0.00000667572021484375,
+ -0.0020580291748046875,
+ 0.001407623291015625,
+ 0.00000667572021484375,
+ -0.0008306503295898438,
+ 0.00371551513671875,
+ -0.08299636840820312,
+ -0.0030012130737304688,
+ -0.0021905899047851562,
+ 0.00000667572021484375,
+ 0.011617660522460938,
+ 0.00000667572021484375,
+ -0.0000152587890625,
+ 0.005359649658203125,
+ 0.00000667572021484375,
+ -0.0042018890380859375,
+ 0.00000667572021484375,
+ 0.0008802413940429688,
+ -0.049579620361328125,
+ 0.010822296142578125,
+ 0.00000667572021484375,
+ 0.00000667572021484375,
+ 0.00000667572021484375,
+ -0.014369964599609375,
+ 0.00000667572021484375,
+ 0.00000667572021484375,
+ -0.0016632080078125,
+ 0.00000667572021484375,
+ 0.0035800933837890625,
+ 0.024021148681640625,
+ -0.04512596130371094,
+ -0.0006885528564453125,
+ 0.00000667572021484375,
+ 0.013338088989257812,
+ 0.06371307373046875,
+ 0.00000667572021484375,
+ 0.000629425048828125,
+ 0.00002002716064453125,
+ 0.00000667572021484375,
+ 0.015112876892089844,
+ -0.05301094055175781,
+ 0.00000667572021484375,
+ -0.0011320114135742188,
+ 0.0012521743774414062,
+ 0.00000667572021484375,
+ 0.00000667572021484375,
+ 0.00000667572021484375,
+ -0.038700103759765625,
+ -0.0035238265991210938,
+ 0.00000667572021484375,
+ 0.00608062744140625,
+ -0.011157035827636719,
+ 0.004566192626953125,
+ 0.0062274932861328125,
+ 0.00000667572021484375,
+ 0.00000667572021484375,
+ 0.00000667572021484375,
+ -0.0015010833740234375,
+ 0.00000667572021484375,
+ 0.00000667572021484375,
+ 0.00000667572021484375,
+ -0.010572433471679688,
+ 0.00000667572021484375,
+ 0.00000667572021484375,
+ 0.00000667572021484375,
+ 0.00000667572021484375,
+ -0.016614913940429688,
+ 0.00000667572021484375,
+ 0.030905723571777344,
+ 0.00000667572021484375,
+ -0.015107154846191406,
+ 0.00000667572021484375,
+ 0.012714385986328125,
+ -0.0009021759033203125,
+ -0.00067138671875,
+ 0.0006847381591796875,
+ -0.005970954895019531,
+ 0.00000667572021484375,
+ 0.000392913818359375,
+ 0.00000667572021484375,
+ 0.00000667572021484375,
+ 0.00000667572021484375,
+ -0.20943737030029297,
+ 0.00000667572021484375,
+ 0.00000667572021484375,
+ 0.00000667572021484375,
+ 0.00000667572021484375,
+ 0.00000667572021484375,
+ 0.0024538040161132812,
+ -0.00016117095947265625,
+ -0.6926145553588867,
+ 0.00000667572021484375,
+ -0.006705284118652344,
+ 0.00000667572021484375,
+ 0.013433456420898438,
+ 0.00000667572021484375,
+ 0.00000667572021484375,
+ 0.0039653778076171875,
+ 0.00000667572021484375,
+ 0.00000667572021484375,
+ 0.00000667572021484375,
+ 0.05192756652832031,
+ -0.00046539306640625,
+ -0.0010156631469726562
+ ],
+ [
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ 0.0073337554931640625,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ 0.0017852783203125,
+ 0.021762847900390625,
+ 0.023838043212890625,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.0093231201171875,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ 0.00185394287109375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.7318296432495117,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.06693649291992188,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ 0.04135417938232422,
+ -0.00000667572021484375,
+ 0.0012073516845703125,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.0023708343505859375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ 0.00597381591796875,
+ -0.00000667572021484375,
+ 0.0001049041748046875,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ 0.04203224182128906,
+ -0.000133514404296875,
+ 0.0032367706298828125,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ 0.053966522216796875,
+ -0.017469406127929688,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ 0.0032787322998046875,
+ -0.8294486999511719,
+ -0.00000667572021484375,
+ 0.042545318603515625,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ 0.006573677062988281,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.1314229965209961,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.022655487060546875,
+ 0.0008211135864257812,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.019756317138671875,
+ -0.0028676986694335938,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ 0.0034084320068359375,
+ -0.00000667572021484375,
+ 0.0000171661376953125,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.0022497177124023438,
+ 0.00191497802734375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ 0.09851455688476562,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.003956794738769531,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ 0.0011348724365234375,
+ -0.00000667572021484375,
+ 0.0007839202880859375,
+ -0.00000667572021484375,
+ -0.0783843994140625,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375,
+ -0.00000667572021484375
+ ],
+ [
+ -0.00021839141845703125,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ -0.00017833709716796875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ -0.00004863739013671875,
+ 0.00000286102294921875,
+ 0.0024118423461914062,
+ -0.1688375473022461,
+ 0.00000286102294921875,
+ 0.0005617141723632812,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ -0.0027265548706054688,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ -0.009179115295410156,
+ 0.011872291564941406,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.006281852722167969,
+ 0.00000286102294921875,
+ 0.011416435241699219,
+ 0.014454841613769531,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ -0.00018596649169921875,
+ 0.00000286102294921875,
+ 0.012002944946289062,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ -0.0023813247680664062,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.000225067138671875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ -0.0033779144287109375,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ -0.0017099380493164062,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ -0.05732154846191406,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.016089439392089844,
+ 0.07070255279541016,
+ 0.00000286102294921875,
+ 0.014483451843261719,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.0017747879028320312,
+ -0.024786949157714844,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00012302398681640625,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ -0.0092620849609375,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00185394287109375,
+ -0.00025177001953125,
+ 0.008860588073730469,
+ 0.006030082702636719,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00017833709716796875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ -0.001644134521484375,
+ 0.0026140213012695312,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ -0.0013418197631835938,
+ 0.00000286102294921875,
+ 0.037514686584472656,
+ 0.00000286102294921875,
+ -0.00038433074951171875,
+ 0.00000286102294921875,
+ 0.01964282989501953,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.005845069885253906,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ -0.04890918731689453,
+ 0.00000286102294921875,
+ 0.008494377136230469,
+ -0.00026988983154296875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ -0.37475109100341797,
+ 0.00000286102294921875,
+ 0.004479408264160156,
+ -0.0015649795532226562,
+ 0.00385284423828125,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00000286102294921875,
+ 0.00030803680419921875,
+ 0.06992149353027344,
+ 0.00000286102294921875,
+ 0.00000286102294921875
+ ],
+ [
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0018415451049804688,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0016222000122070312,
+ 0.023705482482910156,
+ 0.07090950012207031,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ -0.021169662475585938,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ -1.3031587600708008,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.08781909942626953,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.016541481018066406,
+ -0.10686969757080078,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.04713726043701172,
+ 0.0000019073486328125,
+ 0.002704620361328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.00046062469482421875,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ -0.01665210723876953,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0031337738037109375,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0426177978515625,
+ 0.0000019073486328125,
+ 0.018036842346191406,
+ 0.0000019073486328125,
+ -0.001964569091796875,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0788869857788086,
+ -0.03188610076904297,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ -1.4886322021484375,
+ 0.0000019073486328125,
+ 0.0885171890258789,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.01448822021484375,
+ 0.0000019073486328125,
+ -0.0066547393798828125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ -0.045001983642578125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ -0.017113685607910156,
+ 0.010157585144042969,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ -0.0030698776245117188,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ -0.0001583099365234375,
+ -0.004227638244628906,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ -0.008556365966796875,
+ 0.007357597351074219,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.13220977783203125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ -0.013454437255859375,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.033707618713378906,
+ 0.006083488464355469,
+ 0.0000019073486328125,
+ 0.0014142990112304688,
+ 0.0000019073486328125,
+ -0.04172039031982422,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.0000019073486328125,
+ 0.03891944885253906
+ ],
+ [
+ 0.0013017654418945312,
+ 0.00000858306884765625,
+ 0.00000858306884765625,
+ 0.0019664764404296875,
+ 0.00000858306884765625,
+ 0.0035953521728515625,
+ 0.00000858306884765625,
+ 0.0006504058837890625,
+ 0.00000858306884765625,
+ 0.0031061172485351562,
+ -0.07722282409667969,
+ -0.0011444091796875,
+ 0.00000858306884765625,
+ 0.00000858306884765625,
+ 0.00000858306884765625,
+ 0.0056209564208984375,
+ 0.00003147125244140625,
+ 0.00000858306884765625,
+ 0.00000858306884765625,
+ 0.00000858306884765625,
+ 0.00000858306884765625,
+ 0.015537261962890625,
+ 0.001983642578125,
+ 0.00000858306884765625,
+ 0.00000858306884765625,
+ 0.00000858306884765625,
+ 0.00000858306884765625,
+ -0.025580406188964844,
+ -0.005356788635253906,
+ 0.016262054443359375,
+ -0.005573272705078125,
+ 0.00000858306884765625,
+ 0.00000858306884765625,
+ 0.00000858306884765625,
+ -0.0113525390625,
+ 0.013624191284179688,
+ 0.00000858306884765625,
+ 0.000110626220703125,
+ 0.00000858306884765625,
+ 0.00000858306884765625,
+ 0.00000858306884765625,
+ 0.000293731689453125,
+ 0.00000858306884765625,
+ 0.026404380798339844,
+ 0.00000858306884765625,
+ 0.0005817413330078125,
+ 0.00007343292236328125,
+ 0.0010223388671875,
+ 0.00000858306884765625,
+ 0.00000858306884765625,
+ 0.00000858306884765625,
+ 0.00000858306884765625,
+ -0.009862899780273438,
+ 0.00000858306884765625,
+ 0.00000858306884765625,
+ -0.006870269775390625,
+ 0.00435638427734375,
+ 0.00000858306884765625,
+ 0.000232696533203125,
+ 0.00000858306884765625,
+ -0.00616455078125,
+ 0.00000858306884765625,
+ 0.00033283233642578125,
+ 0.00000858306884765625,
+ -0.0016880035400390625,
+ 0.00286102294921875,
+ 0.00000858306884765625,
+ -0.01665496826171875,
+ 0.008039474487304688,
+ 0.00000858306884765625,
+ 0.03484916687011719,
+ 0.018899917602539062,
+ 0.00000858306884765625,
+ 0.00034809112548828125,
+ -0.0000095367431640625,
+ 0.00000858306884765625,
+ 0.022369384765625,
+ -0.00615692138671875,
+ 0.00000858306884765625,
+ -0.00008392333984375,
+ -0.0018634796142578125,
+ 0.00000858306884765625,
+ 0.00000858306884765625,
+ -0.0065765380859375,
+ -0.00798797607421875,
+ 0.007740974426269531,
+ 0.00000858306884765625,
+ 0.00000858306884765625,
+ 0.0047855377197265625,
+ 0.00484466552734375,
+ 0.006256103515625,
+ 0.00000858306884765625,
+ 0.00000858306884765625,
+ 0.00000858306884765625,
+ 0.00000858306884765625,
+ 0.00000858306884765625,
+ 0.00000858306884765625,
+ 0.00000858306884765625,
+ -0.005949974060058594,
+ 0.00000858306884765625,
+ 0.00000858306884765625,
+ -0.0056934356689453125,
+ -0.0057353973388671875,
+ -0.005535125732421875,
+ 0.00028228759765625,
+ 0.0137786865234375,
+ 0.00000858306884765625,
+ 0.0026874542236328125,
+ 0.00000858306884765625,
+ 0.01714324951171875,
+ 0.00000858306884765625,
+ 0.00000858306884765625,
+ 0.00000858306884765625,
+ 0.00165557861328125,
+ 0.0006313323974609375,
+ -0.00090789794921875,
+ 0.00000858306884765625,
+ 0.00016021728515625,
+ 0.00311279296875,
+ -0.04284191131591797,
+ -0.00058746337890625,
+ 0.0028972625732421875,
+ -0.001148223876953125,
+ 0.0013751983642578125,
+ -0.0005426406860351562,
+ 0.00000858306884765625,
+ 0.00000858306884765625,
+ -0.29439735412597656,
+ 0.0019617080688476562,
+ 0.018915176391601562,
+ 0.00000858306884765625,
+ 0.009466171264648438,
+ 0.00000858306884765625,
+ 0.00000858306884765625,
+ 0.0011997222900390625,
+ 0.001117706298828125,
+ 0.00000858306884765625,
+ 0.00000858306884765625,
+ 0.02146148681640625,
+ 0.00000858306884765625,
+ -0.0012531280517578125
+ ],
+ [
+ 0.005249977111816406,
+ 0.0015926361083984375,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ 0.0021581649780273438,
+ 0.01883697509765625,
+ 0.0733022689819336,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.04300212860107422,
+ 0.0030794143676757812,
+ -0.0017910003662109375,
+ -0.00000286102294921875,
+ 0.016645431518554688,
+ -0.021103858947753906,
+ 0.013091087341308594,
+ -1.6041021347045898,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ 0.3691072463989258,
+ -0.00000286102294921875,
+ -0.01113128662109375,
+ 0.09581279754638672,
+ -0.11300373077392578,
+ -0.00000286102294921875,
+ 0.047149658203125,
+ -0.00000286102294921875,
+ 0.053336143493652344,
+ -0.00000286102294921875,
+ 0.004380226135253906,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.008252143859863281,
+ -0.00000286102294921875,
+ -0.018776893615722656,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ 0.0016984939575195312,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ 0.10451030731201172,
+ -0.00000286102294921875,
+ 0.010519981384277344,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.014172554016113281,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ 0.07860374450683594,
+ -0.047211647033691406,
+ 0.010329246520996094,
+ -0.00000286102294921875,
+ 0.02579212188720703,
+ -1.5303049087524414,
+ -0.020979881286621094,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ 0.05430316925048828,
+ 0.006442070007324219,
+ -0.00000286102294921875,
+ 0.035637855529785156,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.0784912109375,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.020351409912109375,
+ 0.02591228485107422,
+ -0.00000286102294921875,
+ 0.0030069351196289062,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ 0.0012063980102539062,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.04748249053955078,
+ -0.00510406494140625,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ 0.03345203399658203,
+ -0.0017213821411132812,
+ -0.008072853088378906,
+ 0.014155387878417969,
+ -0.003909111022949219,
+ -0.00000286102294921875,
+ -0.02114105224609375,
+ 0.021615028381347656,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ 0.011925697326660156,
+ 0.0005092620849609375,
+ 0.000263214111328125,
+ -0.00007343292236328125,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ 0.2987813949584961,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.011395454406738281,
+ 0.01917552947998047,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ 0.12213993072509766,
+ 0.0026998519897460938,
+ -0.00000286102294921875,
+ 0.009751319885253906,
+ -0.00000286102294921875,
+ -0.12412357330322266,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ -0.00000286102294921875,
+ 0.15506553649902344
+ ]
+ ]
+ }
+ ],
+ "layout": {
+ "coloraxis": {
+ "cmid": 0,
+ "colorscale": [
+ [
+ 0,
+ "rgb(103,0,31)"
+ ],
+ [
+ 0.1,
+ "rgb(178,24,43)"
+ ],
+ [
+ 0.2,
+ "rgb(214,96,77)"
+ ],
+ [
+ 0.3,
+ "rgb(244,165,130)"
+ ],
+ [
+ 0.4,
+ "rgb(253,219,199)"
+ ],
+ [
+ 0.5,
+ "rgb(247,247,247)"
+ ],
+ [
+ 0.6,
+ "rgb(209,229,240)"
+ ],
+ [
+ 0.7,
+ "rgb(146,197,222)"
+ ],
+ [
+ 0.8,
+ "rgb(67,147,195)"
+ ],
+ [
+ 0.9,
+ "rgb(33,102,172)"
+ ],
+ [
+ 1,
+ "rgb(5,48,97)"
+ ]
+ ]
+ },
+ "template": {
+ "data": {
+ "bar": [
+ {
+ "error_x": {
+ "color": "#2a3f5f"
+ },
+ "error_y": {
+ "color": "#2a3f5f"
+ },
+ "marker": {
+ "line": {
+ "color": "#E5ECF6",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "bar"
+ }
+ ],
+ "barpolar": [
+ {
+ "marker": {
+ "line": {
+ "color": "#E5ECF6",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "barpolar"
+ }
+ ],
+ "carpet": [
+ {
+ "aaxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "white",
+ "linecolor": "white",
+ "minorgridcolor": "white",
+ "startlinecolor": "#2a3f5f"
+ },
+ "baxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "white",
+ "linecolor": "white",
+ "minorgridcolor": "white",
+ "startlinecolor": "#2a3f5f"
+ },
+ "type": "carpet"
+ }
+ ],
+ "choropleth": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "choropleth"
+ }
+ ],
+ "contour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "contour"
+ }
+ ],
+ "contourcarpet": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "contourcarpet"
+ }
+ ],
+ "heatmap": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmap"
+ }
+ ],
+ "heatmapgl": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmapgl"
+ }
+ ],
+ "histogram": [
+ {
+ "marker": {
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "histogram"
+ }
+ ],
+ "histogram2d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2d"
+ }
+ ],
+ "histogram2dcontour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2dcontour"
+ }
+ ],
+ "mesh3d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "mesh3d"
+ }
+ ],
+ "parcoords": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "parcoords"
+ }
+ ],
+ "pie": [
+ {
+ "automargin": true,
+ "type": "pie"
+ }
+ ],
+ "scatter": [
+ {
+ "fillpattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ },
+ "type": "scatter"
+ }
+ ],
+ "scatter3d": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatter3d"
+ }
+ ],
+ "scattercarpet": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattercarpet"
+ }
+ ],
+ "scattergeo": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergeo"
+ }
+ ],
+ "scattergl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergl"
+ }
+ ],
+ "scattermapbox": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattermapbox"
+ }
+ ],
+ "scatterpolar": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolar"
+ }
+ ],
+ "scatterpolargl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolargl"
+ }
+ ],
+ "scatterternary": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterternary"
+ }
+ ],
+ "surface": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "surface"
+ }
+ ],
+ "table": [
+ {
+ "cells": {
+ "fill": {
+ "color": "#EBF0F8"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "header": {
+ "fill": {
+ "color": "#C8D4E3"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "type": "table"
+ }
+ ]
+ },
+ "layout": {
+ "annotationdefaults": {
+ "arrowcolor": "#2a3f5f",
+ "arrowhead": 0,
+ "arrowwidth": 1
+ },
+ "autotypenumbers": "strict",
+ "coloraxis": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "colorscale": {
+ "diverging": [
+ [
+ 0,
+ "#8e0152"
+ ],
+ [
+ 0.1,
+ "#c51b7d"
+ ],
+ [
+ 0.2,
+ "#de77ae"
+ ],
+ [
+ 0.3,
+ "#f1b6da"
+ ],
+ [
+ 0.4,
+ "#fde0ef"
+ ],
+ [
+ 0.5,
+ "#f7f7f7"
+ ],
+ [
+ 0.6,
+ "#e6f5d0"
+ ],
+ [
+ 0.7,
+ "#b8e186"
+ ],
+ [
+ 0.8,
+ "#7fbc41"
+ ],
+ [
+ 0.9,
+ "#4d9221"
+ ],
+ [
+ 1,
+ "#276419"
+ ]
+ ],
+ "sequential": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "sequentialminus": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ]
+ },
+ "colorway": [
+ "#636efa",
+ "#EF553B",
+ "#00cc96",
+ "#ab63fa",
+ "#FFA15A",
+ "#19d3f3",
+ "#FF6692",
+ "#B6E880",
+ "#FF97FF",
+ "#FECB52"
+ ],
+ "font": {
+ "color": "#2a3f5f"
+ },
+ "geo": {
+ "bgcolor": "white",
+ "lakecolor": "white",
+ "landcolor": "#E5ECF6",
+ "showlakes": true,
+ "showland": true,
+ "subunitcolor": "white"
+ },
+ "hoverlabel": {
+ "align": "left"
+ },
+ "hovermode": "closest",
+ "mapbox": {
+ "style": "light"
+ },
+ "paper_bgcolor": "white",
+ "plot_bgcolor": "#E5ECF6",
+ "polar": {
+ "angularaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "bgcolor": "#E5ECF6",
+ "radialaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ }
+ },
+ "scene": {
+ "xaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ },
+ "yaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ },
+ "zaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ }
+ },
+ "shapedefaults": {
+ "line": {
+ "color": "#2a3f5f"
+ }
+ },
+ "ternary": {
+ "aaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "baxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "bgcolor": "#E5ECF6",
+ "caxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ }
+ },
+ "title": {
+ "x": 0.05
+ },
+ "xaxis": {
+ "automargin": true,
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "white",
+ "zerolinewidth": 2
+ },
+ "yaxis": {
+ "automargin": true,
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "white",
+ "zerolinewidth": 2
+ }
+ }
+ },
+ "title": {
+ "text": "Change in logit diff when path patching features from S_inhibition heads values per prompts"
+ },
+ "xaxis": {
+ "anchor": "y",
+ "constrain": "domain",
+ "domain": [
+ 0,
+ 1
+ ],
+ "scaleanchor": "y",
+ "title": {
+ "text": "Feature Id"
+ }
+ },
+ "yaxis": {
+ "anchor": "x",
+ "autorange": "reversed",
+ "constrain": "domain",
+ "domain": [
+ 0,
+ 1
+ ],
+ "title": {
+ "text": "Prompt Idx"
+ }
+ }
+ }
+ },
+ "text/html": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "def path_patch_v_input(v_input, hook, feature_dirs, pos, head_index):\n",
+ " v_input[:, pos, head_index, :] = v_input[:, pos, head_index, :] - feature_dirs\n",
+ " return v_input\n",
+ "\n",
+ "\n",
+ "s_inhib_heads = [(7, 3), (7, 9), (8,6), (8,10)]\n",
+ "\n",
+ "results = torch.zeros(tokens.shape[0], all_live_features.shape[0])\n",
+ "\n",
+ "W_O_cat = einops.rearrange(\n",
+ " model.W_O,\n",
+ " \"n_layers n_heads d_head d_model -> n_layers (n_heads d_head) d_model\"\n",
+ ")\n",
+ "\n",
+ "for feature_id in tqdm.tqdm(all_live_features):\n",
+ " feature_id = feature_id.item()\n",
+ " feature_acts = cache[utils.get_act_name('z', abl_layer) + \".hook_sae_acts_post\"][:, abl_pos, feature_id] # [batch]\n",
+ " feature_dirs = (feature_acts.unsqueeze(-1) * sae.W_dec[feature_id]) @ W_O_cat[abl_layer]\n",
+ " hook_fns = [\n",
+ " (utils.get_act_name('v_input', layer), partial(path_patch_v_input, feature_dirs=feature_dirs, pos=abl_pos, head_index=head)) for (layer, head) in s_inhib_heads\n",
+ " ]\n",
+ " path_patched_logits = model.run_with_hooks(\n",
+ " tokens,\n",
+ " return_type=\"logits\",\n",
+ " fwd_hooks=hook_fns\n",
+ " )\n",
+ "\n",
+ " path_patched_logit_diff = logits_to_ave_logit_diff(path_patched_logits, answer_tokens, per_prompt=True)\n",
+ " results[:, fid_to_idx[feature_id]] = path_patched_logit_diff - clean_sae_baseline_per_prompt\n",
+ "\n",
+ "imshow(\n",
+ " results, \n",
+ " title=f\"Change in logit diff when path patching features from S_inhibition heads values per prompts\",\n",
+ " xaxis=\"Feature Id\", yaxis=\"Prompt Idx\", x=list(map(str, all_live_features.tolist()))\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Reset SAEs"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "One major footgun is forgetting about an SAE that you previously attached with `add_sae`. Similar to TransformerLens `reset_hooks`, you can always reset SAEs you've added with `model.reset_saes()`. You can also pass in a list of activation names to only reset a subset of attached SAEs."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Attached SAEs before reset_saes: {'blocks.5.attn.hook_z': HookedSAE(\n",
+ " (hook_sae_input): HookPoint()\n",
+ " (hook_sae_acts_pre): HookPoint()\n",
+ " (hook_sae_acts_post): HookPoint()\n",
+ " (hook_sae_recons): HookPoint()\n",
+ " (hook_sae_error): HookPoint()\n",
+ " (hook_sae_output): HookPoint()\n",
+ ")}\n",
+ "Attached SAEs after reset_saes: {}\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"Attached SAEs before reset_saes:\", model.acts_to_saes)\n",
+ "model.reset_saes()\n",
+ "print(\"Attached SAEs after reset_saes:\", model.acts_to_saes)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Note that the HookedSAETransformer API is generally designed to closely match TransformerLens hooks API."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Error Nodes"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Recent exciting work from [Marks et al.](https://arxiv.org/abs/2403.19647v2) demonstrated the use of \"error nodes\" in SAE circuit analysis. The idea is that for some input activation x, SAE(x) = x_reconstruct is an approximation of x, but we can define an error_term such that x = x_reconstruct + error_term.\n",
+ "\n",
+ "This seems useful: instead of replacing x with x_reconstruct, which might break everything and make our circuit analysis janky, we can just re-write x as a function of the SAE features, bias, and error term, which gives us access to all of the SAE features but without breaking performance. \n",
+ "\n",
+ "Additionally, we can compare interventions on SAE features to the same intervention on the error term to get a better sense of how much the SAE features have actually captured.\n",
+ "\n",
+ "To use error terms with HookedSAEs, you can set `hooked_sae.cfg.use_error_term = True`, or initialize it to True in the config. Note HookedSAEConfig sets this to False by default."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Attached SAEs after adding l5_sae_with_error: {'blocks.5.attn.hook_z': HookedSAE(\n",
+ " (hook_sae_input): HookPoint()\n",
+ " (hook_sae_acts_pre): HookPoint()\n",
+ " (hook_sae_acts_post): HookPoint()\n",
+ " (hook_sae_recons): HookPoint()\n",
+ " (hook_sae_error): HookPoint()\n",
+ " (hook_sae_output): HookPoint()\n",
+ ")}\n"
+ ]
+ }
+ ],
+ "source": [
+ "import copy\n",
+ "l5_sae = hook_name_to_sae[utils.get_act_name('z', 5)]\n",
+ "l5_sae_with_error = copy.deepcopy(l5_sae)\n",
+ "l5_sae_with_error.cfg.use_error_term=True\n",
+ "model.add_sae(l5_sae_with_error)\n",
+ "print(\"Attached SAEs after adding l5_sae_with_error:\", model.acts_to_saes)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now the output of each attached SAE will be SAE(x) + error_term = x. We can sanity check this by confirming that running with SAEs produces the same logits without SAEs."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "logits_with_saes = model(tokens)\n",
+ "logit_diff_with_saes = logits_to_ave_logit_diff(logits_with_saes, answer_tokens)\n",
+ "\n",
+ "assert torch.allclose(logits_with_saes, original_logits, atol=1e-4)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now we can compare ablations of each feature to ablating the error node. We'll start by ablating each feature on each prompt, and then the error nodes. We'll append the effects from ablating error nodes to the rightmost column on the heatmap:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|āāāāāāāāāā| 141/141 [00:04<00:00, 32.33it/s]\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.plotly.v1+json": {
+ "config": {
+ "plotlyServerURL": "https://plot.ly"
+ },
+ "data": [
+ {
+ "coloraxis": "coloraxis",
+ "hovertemplate": "Feature Idx: %{x}
Prompt Idx: %{y}
color: %{z}",
+ "name": "0",
+ "type": "heatmap",
+ "x": [
+ "46",
+ "345",
+ "702",
+ "1372",
+ "1755",
+ "1965",
+ "2457",
+ "2496",
+ "2646",
+ "2999",
+ "3047",
+ "4569",
+ "5132",
+ "5203",
+ "5508",
+ "5940",
+ "6144",
+ "6371",
+ "6515",
+ "6558",
+ "6812",
+ "7092",
+ "7515",
+ "7907",
+ "8063",
+ "8623",
+ "8737",
+ "8768",
+ "9096",
+ "9102",
+ "9186",
+ "9463",
+ "9746",
+ "9913",
+ "10581",
+ "10894",
+ "12109",
+ "12485",
+ "12764",
+ "12866",
+ "13063",
+ "13624",
+ "13707",
+ "13777",
+ "14844",
+ "15050",
+ "15170",
+ "15696",
+ "16178",
+ "16892",
+ "17156",
+ "17259",
+ "17497",
+ "17854",
+ "18043",
+ "18210",
+ "18318",
+ "18385",
+ "18440",
+ "18920",
+ "19183",
+ "19263",
+ "19442",
+ "19524",
+ "19573",
+ "20838",
+ "21151",
+ "21657",
+ "22108",
+ "23578",
+ "24091",
+ "24217",
+ "25792",
+ "26373",
+ "26410",
+ "27535",
+ "27787",
+ "27811",
+ "27960",
+ "28061",
+ "28241",
+ "28242",
+ "28254",
+ "28349",
+ "28977",
+ "29027",
+ "29482",
+ "29603",
+ "29700",
+ "29822",
+ "32177",
+ "32920",
+ "33320",
+ "33730",
+ "33966",
+ "34177",
+ "34334",
+ "34947",
+ "35403",
+ "35425",
+ "35579",
+ "35665",
+ "35815",
+ "36109",
+ "36172",
+ "36451",
+ "36767",
+ "36917",
+ "38570",
+ "39962",
+ "40409",
+ "40418",
+ "40661",
+ "41162",
+ "41185",
+ "41552",
+ "42024",
+ "42161",
+ "42437",
+ "42577",
+ "42882",
+ "42931",
+ "43035",
+ "43414",
+ "43643",
+ "43662",
+ "44203",
+ "44256",
+ "44452",
+ "44652",
+ "45179",
+ "45814",
+ "45984",
+ "46880",
+ "47117",
+ "47170",
+ "47231",
+ "47313",
+ "47680",
+ "48063",
+ "48703",
+ "error"
+ ],
+ "xaxis": "x",
+ "yaxis": "y",
+ "z": [
+ [
+ 0.0012617111206054688,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ 0.0016908645629882812,
+ -0.0002231597900390625,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -0.00029659271240234375,
+ -9.5367431640625e-7,
+ -0.03279590606689453,
+ -0.07254886627197266,
+ -9.5367431640625e-7,
+ 0.00013065338134765625,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -0.014922142028808594,
+ -0.0044403076171875,
+ 0.0007047653198242188,
+ -0.00428009033203125,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -0.039069175720214844,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -0.007334709167480469,
+ 0.00033092498779296875,
+ -0.0017004013061523438,
+ 0.0026845932006835938,
+ 0.00043010711669921875,
+ -0.11128997802734375,
+ -0.0038976669311523438,
+ -0.006033897399902344,
+ -9.5367431640625e-7,
+ -0.00027751922607421875,
+ -9.5367431640625e-7,
+ 0.0006570816040039062,
+ -0.0004291534423828125,
+ -9.5367431640625e-7,
+ -0.0035734176635742188,
+ -0.0033063888549804688,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ 0.0033960342407226562,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -0.0030546188354492188,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -0.0000972747802734375,
+ -0.0001811981201171875,
+ -9.5367431640625e-7,
+ -0.004569053649902344,
+ -0.013583183288574219,
+ -9.5367431640625e-7,
+ 0.02047252655029297,
+ -0.02572154998779297,
+ -9.5367431640625e-7,
+ -0.0006608963012695312,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ 0.02255725860595703,
+ -0.05519580841064453,
+ -0.0033473968505859375,
+ -0.0000057220458984375,
+ -0.0026073455810546875,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -0.02097320556640625,
+ 0.008440971374511719,
+ -9.5367431640625e-7,
+ -0.004597663879394531,
+ 0.00159454345703125,
+ 0.0001544952392578125,
+ 0.005199432373046875,
+ 0.0007762908935546875,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -0.0032625198364257812,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -0.015192985534667969,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -0.018138885498046875,
+ -9.5367431640625e-7,
+ 0.010298728942871094,
+ -9.5367431640625e-7,
+ -0.0031423568725585938,
+ -9.5367431640625e-7,
+ 0.004242897033691406,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -0.010041236877441406,
+ 0.0010347366333007812,
+ 0.006011962890625,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ 0.00301361083984375,
+ -0.04584026336669922,
+ 0.0002079010009765625,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -0.0002574920654296875,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -0.45942211151123047,
+ -0.0008325576782226562,
+ 0.00041484832763671875,
+ -9.5367431640625e-7,
+ -0.023777008056640625,
+ 0.0000514984130859375,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -0.00030422210693359375,
+ 0.0006666183471679688,
+ -9.5367431640625e-7,
+ 0.004633903503417969,
+ -9.5367431640625e-7,
+ -0.008234977722167969,
+ -0.07327461242675781
+ ],
+ [
+ 0.000003814697265625,
+ 0.000003814697265625,
+ -0.00208282470703125,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ -0.0012912750244140625,
+ -0.01760101318359375,
+ 0.000003814697265625,
+ 0.057277679443359375,
+ 0.013429641723632812,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ -0.0000457763671875,
+ -0.0027828216552734375,
+ -0.0055084228515625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ -0.2744255065917969,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.0021514892578125,
+ 0.000003814697265625,
+ 0.06994247436523438,
+ 0.0048542022705078125,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ -0.0567169189453125,
+ 0.000003814697265625,
+ 0.012315750122070312,
+ 0.0066585540771484375,
+ 0.07937240600585938,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.028867721557617188,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.0074901580810546875,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.009624481201171875,
+ 0.000003814697265625,
+ -0.009510040283203125,
+ 0.0032100677490234375,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.10918617248535156,
+ 0.000003814697265625,
+ 0.026102066040039062,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000946044921875,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ -0.041675567626953125,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ -0.0066776275634765625,
+ 0.03926849365234375,
+ 0.03615379333496094,
+ 0.027612686157226562,
+ 0.000003814697265625,
+ -0.0004673004150390625,
+ -0.1435985565185547,
+ -0.00030517578125,
+ 0.059326171875,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.020435333251953125,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ -0.11923980712890625,
+ -0.009393692016601562,
+ 0.000003814697265625,
+ 0.011783599853515625,
+ 0.06122589111328125,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.0002918243408203125,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.001491546630859375,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ -0.0050716400146484375,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.025064468383789062,
+ 0.000003814697265625,
+ -0.0467529296875,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.0014934539794921875,
+ 0.00043487548828125,
+ 0.028188705444335938,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.001995086669921875,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.13014602661132812,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.0005893707275390625,
+ 0.012182235717773438,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.11103057861328125,
+ 0.042850494384765625,
+ 0.030099868774414062,
+ 0.000003814697265625,
+ -0.0047321319580078125,
+ 0.0000133514404296875,
+ -0.0320587158203125,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.031030654907226562,
+ -0.002635955810546875
+ ],
+ [
+ 0.007018089294433594,
+ 0,
+ 0,
+ 0.0028057098388671875,
+ 0,
+ 0,
+ 0,
+ -0.010999679565429688,
+ 0,
+ -0.1419973373413086,
+ -0.24188613891601562,
+ 0,
+ 0.0003147125244140625,
+ 0,
+ 0,
+ 0,
+ 0.009432792663574219,
+ 0,
+ -0.000125885009765625,
+ 0,
+ 0,
+ 0.00017070770263671875,
+ 0.011651992797851562,
+ -0.00225830078125,
+ 0,
+ -0.0014581680297851562,
+ 0.00020122528076171875,
+ -0.030771255493164062,
+ -0.03744316101074219,
+ -0.034499168395996094,
+ -0.00374603271484375,
+ 0,
+ 0.0011348724365234375,
+ 0,
+ -0.0302276611328125,
+ -0.08229637145996094,
+ -0.00048160552978515625,
+ 0,
+ -0.00640869140625,
+ 0.0001277923583984375,
+ 0,
+ -0.0008974075317382812,
+ 0.00022983551025390625,
+ -0.2322559356689453,
+ -0.0050449371337890625,
+ -0.010677337646484375,
+ 0,
+ 0.014942169189453125,
+ 0,
+ 0.0008764266967773438,
+ 0.00417327880859375,
+ 0,
+ -0.015301704406738281,
+ 0,
+ -0.0008974075317382812,
+ -0.04426097869873047,
+ 0.005242347717285156,
+ 0,
+ 0,
+ 0,
+ -0.009447097778320312,
+ 0,
+ 0,
+ -0.0011806488037109375,
+ 0,
+ -0.0045909881591796875,
+ 0.015285491943359375,
+ -0.034976959228515625,
+ -0.013401985168457031,
+ 0,
+ 0.1357421875,
+ -0.09111690521240234,
+ 0,
+ 0.00013065338134765625,
+ 0.0002460479736328125,
+ 0,
+ 0.04656982421875,
+ -0.09346866607666016,
+ 0,
+ -0.005030632019042969,
+ 0.0001125335693359375,
+ 0,
+ 0,
+ 0,
+ -0.07491683959960938,
+ 0.006598472595214844,
+ 0,
+ -0.014060020446777344,
+ -0.008306503295898438,
+ -0.0054874420166015625,
+ -0.0004930496215820312,
+ 0,
+ 0,
+ 0,
+ -0.008953094482421875,
+ 0,
+ 0,
+ 0,
+ -0.03713417053222656,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.028200149536132812,
+ 0,
+ 0.036255836486816406,
+ 0,
+ -0.03178215026855469,
+ 0,
+ -0.012192726135253906,
+ -0.002147674560546875,
+ -0.0005474090576171875,
+ -0.0021409988403320312,
+ -0.030725479125976562,
+ 0,
+ 0.0008029937744140625,
+ 0,
+ 0,
+ 0,
+ -0.29135894775390625,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.0027914047241210938,
+ -0.00022125244140625,
+ -0.8653240203857422,
+ 0,
+ -0.05593109130859375,
+ 0,
+ -0.04123210906982422,
+ 0,
+ 0,
+ 0.015351295471191406,
+ 0,
+ 0,
+ 0,
+ 0.018423080444335938,
+ -0.0000476837158203125,
+ -0.0023584365844726562,
+ -0.3282146453857422
+ ],
+ [
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.0001983642578125,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ -0.010341644287109375,
+ 0.07198715209960938,
+ 0.14725303649902344,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.0002918243408203125,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.011704444885253906,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ -0.3150959014892578,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ -0.039947509765625,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.15607547760009766,
+ 9.5367431640625e-7,
+ 0.09917640686035156,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.019521713256835938,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.012205123901367188,
+ 9.5367431640625e-7,
+ -0.0005893707275390625,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.07062149047851562,
+ 0.000492095947265625,
+ 0.014776229858398438,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.0557098388671875,
+ 0.15409469604492188,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ -0.0007076263427734375,
+ -0.24256324768066406,
+ 9.5367431640625e-7,
+ 0.0858917236328125,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.007343292236328125,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ -0.11646080017089844,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.05528736114501953,
+ 0.0847921371459961,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.00428009033203125,
+ -0.0056171417236328125,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.0066967010498046875,
+ 9.5367431640625e-7,
+ -0.006005287170410156,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.01735687255859375,
+ -0.0037336349487304688,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.09533309936523438,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.009324073791503906,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.007989883422851562,
+ 9.5367431640625e-7,
+ 0.0064525604248046875,
+ 9.5367431640625e-7,
+ -0.06574440002441406,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.5591859817504883
+ ],
+ [
+ -0.0009012222290039062,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.0006313323974609375,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.000461578369140625,
+ 0.00001239776611328125,
+ -0.055993080139160156,
+ -0.24974536895751953,
+ 0.00001239776611328125,
+ 0.0011262893676757812,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.0025796890258789062,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.030013084411621094,
+ -0.012925148010253906,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.0253448486328125,
+ 0.00001239776611328125,
+ 0.0012464523315429688,
+ 0.021536827087402344,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.00009822845458984375,
+ 0.00001239776611328125,
+ -0.09924793243408203,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.006188392639160156,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.0010576248168945312,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.0008172988891601562,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.0020704269409179688,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.09985160827636719,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.036945343017578125,
+ 0.025011062622070312,
+ 0.00001239776611328125,
+ 0.004599571228027344,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.027939796447753906,
+ -0.07974910736083984,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.00038242340087890625,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.035175323486328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.0047245025634765625,
+ -0.008166313171386719,
+ -0.008578300476074219,
+ 0.0018529891967773438,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.0016679763793945312,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.0028676986694335938,
+ -0.04880046844482422,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.0053462982177734375,
+ 0.00001239776611328125,
+ 0.1658468246459961,
+ 0.00001239776611328125,
+ -0.0024824142456054688,
+ 0.00001239776611328125,
+ 0.025139808654785156,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.027915000915527344,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.14544200897216797,
+ 0.00001239776611328125,
+ 0.020270347595214844,
+ 0.007473945617675781,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.8424196243286133,
+ 0.00001239776611328125,
+ -0.007409095764160156,
+ -0.00318145751953125,
+ -0.015982627868652344,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00034046173095703125,
+ 0.10727787017822266,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.5388059616088867
+ ],
+ [
+ -0.00000762939453125,
+ -0.00000762939453125,
+ 0.0019397735595703125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.022940635681152344,
+ 0.07428932189941406,
+ 0.29994869232177734,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.016974449157714844,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.4772310256958008,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ 0.05463600158691406,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ 0.004734992980957031,
+ -0.12352275848388672,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ 0.15236186981201172,
+ -0.00000762939453125,
+ 0.27855396270751953,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ 0.001430511474609375,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.016387939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ 0.0008668899536132812,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ 0.06814861297607422,
+ -0.00000762939453125,
+ 0.00351715087890625,
+ -0.00000762939453125,
+ 0.0061588287353515625,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ 0.02480602264404297,
+ 0.31668567657470703,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.4413900375366211,
+ -0.00000762939453125,
+ 0.1517791748046875,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ 0.010898590087890625,
+ -0.00000762939453125,
+ 0.006583213806152344,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.04946422576904297,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ 0.040429115295410156,
+ 0.1020956039428711,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.0008649826049804688,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.0054836273193359375,
+ -0.010519981384277344,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ 0.004588127136230469,
+ -0.006558418273925781,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ 0.07750797271728516,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ 0.03235149383544922,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ 0.02773571014404297,
+ 0.08978557586669922,
+ -0.00000762939453125,
+ 0.008780479431152344,
+ -0.00000762939453125,
+ -0.0327301025390625,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ 0.035370826721191406,
+ 0.19881343841552734
+ ],
+ [
+ 0.0041904449462890625,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ 0.008458137512207031,
+ -0.0000057220458984375,
+ -0.0042858123779296875,
+ -0.0000057220458984375,
+ 0.002468109130859375,
+ -0.0000057220458984375,
+ -0.03716564178466797,
+ -0.10456657409667969,
+ -0.0047702789306640625,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ 0.017671585083007812,
+ 0.0004062652587890625,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.00655364990234375,
+ 0.001873016357421875,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.04653644561767578,
+ -0.01836395263671875,
+ 0.014448165893554688,
+ -0.0209197998046875,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.006618499755859375,
+ 0.02408599853515625,
+ -0.0000057220458984375,
+ 0.0012884140014648438,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0050144195556640625,
+ -0.0000057220458984375,
+ 0.036708831787109375,
+ -0.0000057220458984375,
+ -0.0056591033935546875,
+ -0.0004215240478515625,
+ 0.0014057159423828125,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.033232688903808594,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.008130073547363281,
+ 0.016930580139160156,
+ -0.0000057220458984375,
+ -0.0012025833129882812,
+ -0.0000057220458984375,
+ 0.000545501708984375,
+ -0.0000057220458984375,
+ -0.0004673004150390625,
+ -0.0000057220458984375,
+ 0.0038089752197265625,
+ -0.008646011352539062,
+ -0.0000057220458984375,
+ -0.008909225463867188,
+ -0.011255264282226562,
+ -0.0000057220458984375,
+ 0.0925750732421875,
+ -0.0064563751220703125,
+ -0.0000057220458984375,
+ 0.0011615753173828125,
+ 0.00002956390380859375,
+ -0.0000057220458984375,
+ 0.07063961029052734,
+ -0.030902862548828125,
+ -0.0000057220458984375,
+ -0.0010814666748046875,
+ 0.00038909912109375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0013380050659179688,
+ -0.022397994995117188,
+ 0.027740478515625,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ 0.01797771453857422,
+ 0.009552955627441406,
+ 0.01857471466064453,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0055103302001953125,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.01697063446044922,
+ -0.0159149169921875,
+ -0.011240959167480469,
+ 0.000301361083984375,
+ 0.020501136779785156,
+ -0.0000057220458984375,
+ -0.0006427764892578125,
+ -0.0000057220458984375,
+ 0.04800224304199219,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.00001811981201171875,
+ 0.0005950927734375,
+ 0.00732421875,
+ -0.0000057220458984375,
+ 0.001216888427734375,
+ 0.00897216796875,
+ -0.1255035400390625,
+ 0.001003265380859375,
+ 0.006274223327636719,
+ 0.0026502609252929688,
+ -0.00449371337890625,
+ -0.0023517608642578125,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.6521244049072266,
+ -0.009072303771972656,
+ 0.013387680053710938,
+ -0.0000057220458984375,
+ -0.022745132446289062,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ 0.000606536865234375,
+ -0.0011501312255859375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ 0.023046493530273438,
+ -0.0000057220458984375,
+ -0.008263587951660156,
+ -0.11597061157226562
+ ],
+ [
+ -0.0037221908569335938,
+ 0.00225830078125,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.05941295623779297,
+ 0.04140281677246094,
+ 0.24284648895263672,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.11462688446044922,
+ 0.012240409851074219,
+ 0.0012884140014648438,
+ -0.00001049041748046875,
+ 0.01781749725341797,
+ 0.005211830139160156,
+ -0.0016298294067382812,
+ -0.2994966506958008,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ 0.26962947845458984,
+ -0.00001049041748046875,
+ 0.050202369689941406,
+ 0.04053211212158203,
+ -0.30355358123779297,
+ -0.00001049041748046875,
+ -0.0013666152954101562,
+ -0.00001049041748046875,
+ 0.06442928314208984,
+ -0.00001049041748046875,
+ 0.04406547546386719,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00763702392578125,
+ -0.00001049041748046875,
+ -0.03402233123779297,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.005751609802246094,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ 0.13191986083984375,
+ -0.00001049041748046875,
+ -0.031653404235839844,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.02394390106201172,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.1073293685913086,
+ 0.20270729064941406,
+ 0.02746295928955078,
+ -0.00001049041748046875,
+ 0.020377159118652344,
+ -0.31055259704589844,
+ -0.043480873107910156,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ 0.04507160186767578,
+ -0.0014734268188476562,
+ -0.00001049041748046875,
+ 0.048813819885253906,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.1407604217529297,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ 0.013548851013183594,
+ 0.016210556030273438,
+ -0.00001049041748046875,
+ -0.011261940002441406,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00029277801513671875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ 0.008993148803710938,
+ -0.020813941955566406,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.008435249328613281,
+ -0.021961212158203125,
+ -0.04410362243652344,
+ 0.1307668685913086,
+ 0.005297660827636719,
+ -0.00001049041748046875,
+ 0.006031990051269531,
+ 0.016150474548339844,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ 0.01802349090576172,
+ 0.0018205642700195312,
+ 0.0016574859619140625,
+ 0.0005712509155273438,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ 0.02598094940185547,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ 0.02737903594970703,
+ 0.039580345153808594,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ 0.09876728057861328,
+ 0.035803794860839844,
+ -0.00001049041748046875,
+ -0.027251243591308594,
+ -0.00001049041748046875,
+ -0.07061004638671875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ 0.08719158172607422,
+ -0.37606334686279297
+ ]
+ ]
+ }
+ ],
+ "layout": {
+ "coloraxis": {
+ "cmid": 0,
+ "colorscale": [
+ [
+ 0,
+ "rgb(103,0,31)"
+ ],
+ [
+ 0.1,
+ "rgb(178,24,43)"
+ ],
+ [
+ 0.2,
+ "rgb(214,96,77)"
+ ],
+ [
+ 0.3,
+ "rgb(244,165,130)"
+ ],
+ [
+ 0.4,
+ "rgb(253,219,199)"
+ ],
+ [
+ 0.5,
+ "rgb(247,247,247)"
+ ],
+ [
+ 0.6,
+ "rgb(209,229,240)"
+ ],
+ [
+ 0.7,
+ "rgb(146,197,222)"
+ ],
+ [
+ 0.8,
+ "rgb(67,147,195)"
+ ],
+ [
+ 0.9,
+ "rgb(33,102,172)"
+ ],
+ [
+ 1,
+ "rgb(5,48,97)"
+ ]
+ ]
+ },
+ "template": {
+ "data": {
+ "bar": [
+ {
+ "error_x": {
+ "color": "#2a3f5f"
+ },
+ "error_y": {
+ "color": "#2a3f5f"
+ },
+ "marker": {
+ "line": {
+ "color": "#E5ECF6",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "bar"
+ }
+ ],
+ "barpolar": [
+ {
+ "marker": {
+ "line": {
+ "color": "#E5ECF6",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "barpolar"
+ }
+ ],
+ "carpet": [
+ {
+ "aaxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "white",
+ "linecolor": "white",
+ "minorgridcolor": "white",
+ "startlinecolor": "#2a3f5f"
+ },
+ "baxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "white",
+ "linecolor": "white",
+ "minorgridcolor": "white",
+ "startlinecolor": "#2a3f5f"
+ },
+ "type": "carpet"
+ }
+ ],
+ "choropleth": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "choropleth"
+ }
+ ],
+ "contour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "contour"
+ }
+ ],
+ "contourcarpet": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "contourcarpet"
+ }
+ ],
+ "heatmap": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmap"
+ }
+ ],
+ "heatmapgl": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmapgl"
+ }
+ ],
+ "histogram": [
+ {
+ "marker": {
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "histogram"
+ }
+ ],
+ "histogram2d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2d"
+ }
+ ],
+ "histogram2dcontour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2dcontour"
+ }
+ ],
+ "mesh3d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "mesh3d"
+ }
+ ],
+ "parcoords": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "parcoords"
+ }
+ ],
+ "pie": [
+ {
+ "automargin": true,
+ "type": "pie"
+ }
+ ],
+ "scatter": [
+ {
+ "fillpattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ },
+ "type": "scatter"
+ }
+ ],
+ "scatter3d": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatter3d"
+ }
+ ],
+ "scattercarpet": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattercarpet"
+ }
+ ],
+ "scattergeo": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergeo"
+ }
+ ],
+ "scattergl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergl"
+ }
+ ],
+ "scattermapbox": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattermapbox"
+ }
+ ],
+ "scatterpolar": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolar"
+ }
+ ],
+ "scatterpolargl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolargl"
+ }
+ ],
+ "scatterternary": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterternary"
+ }
+ ],
+ "surface": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "surface"
+ }
+ ],
+ "table": [
+ {
+ "cells": {
+ "fill": {
+ "color": "#EBF0F8"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "header": {
+ "fill": {
+ "color": "#C8D4E3"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "type": "table"
+ }
+ ]
+ },
+ "layout": {
+ "annotationdefaults": {
+ "arrowcolor": "#2a3f5f",
+ "arrowhead": 0,
+ "arrowwidth": 1
+ },
+ "autotypenumbers": "strict",
+ "coloraxis": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "colorscale": {
+ "diverging": [
+ [
+ 0,
+ "#8e0152"
+ ],
+ [
+ 0.1,
+ "#c51b7d"
+ ],
+ [
+ 0.2,
+ "#de77ae"
+ ],
+ [
+ 0.3,
+ "#f1b6da"
+ ],
+ [
+ 0.4,
+ "#fde0ef"
+ ],
+ [
+ 0.5,
+ "#f7f7f7"
+ ],
+ [
+ 0.6,
+ "#e6f5d0"
+ ],
+ [
+ 0.7,
+ "#b8e186"
+ ],
+ [
+ 0.8,
+ "#7fbc41"
+ ],
+ [
+ 0.9,
+ "#4d9221"
+ ],
+ [
+ 1,
+ "#276419"
+ ]
+ ],
+ "sequential": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "sequentialminus": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ]
+ },
+ "colorway": [
+ "#636efa",
+ "#EF553B",
+ "#00cc96",
+ "#ab63fa",
+ "#FFA15A",
+ "#19d3f3",
+ "#FF6692",
+ "#B6E880",
+ "#FF97FF",
+ "#FECB52"
+ ],
+ "font": {
+ "color": "#2a3f5f"
+ },
+ "geo": {
+ "bgcolor": "white",
+ "lakecolor": "white",
+ "landcolor": "#E5ECF6",
+ "showlakes": true,
+ "showland": true,
+ "subunitcolor": "white"
+ },
+ "hoverlabel": {
+ "align": "left"
+ },
+ "hovermode": "closest",
+ "mapbox": {
+ "style": "light"
+ },
+ "paper_bgcolor": "white",
+ "plot_bgcolor": "#E5ECF6",
+ "polar": {
+ "angularaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "bgcolor": "#E5ECF6",
+ "radialaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ }
+ },
+ "scene": {
+ "xaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ },
+ "yaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ },
+ "zaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ }
+ },
+ "shapedefaults": {
+ "line": {
+ "color": "#2a3f5f"
+ }
+ },
+ "ternary": {
+ "aaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "baxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "bgcolor": "#E5ECF6",
+ "caxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ }
+ },
+ "title": {
+ "x": 0.05
+ },
+ "xaxis": {
+ "automargin": true,
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "white",
+ "zerolinewidth": 2
+ },
+ "yaxis": {
+ "automargin": true,
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "white",
+ "zerolinewidth": 2
+ }
+ }
+ },
+ "title": {
+ "text": "Change in logit diff when ablating L5 SAE features for all prompts at pos 10"
+ },
+ "xaxis": {
+ "anchor": "y",
+ "constrain": "domain",
+ "domain": [
+ 0,
+ 1
+ ],
+ "scaleanchor": "y",
+ "title": {
+ "text": "Feature Idx"
+ }
+ },
+ "yaxis": {
+ "anchor": "x",
+ "autorange": "reversed",
+ "constrain": "domain",
+ "domain": [
+ 0,
+ 1
+ ],
+ "title": {
+ "text": "Prompt Idx"
+ }
+ }
+ }
+ },
+ "text/html": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "def ablate_sae_feature(sae_acts, hook, pos, feature_id):\n",
+ " if pos is None:\n",
+ " sae_acts[:, :, feature_id] = 0.\n",
+ " else:\n",
+ " sae_acts[:, pos, feature_id] = 0.\n",
+ " return sae_acts\n",
+ "\n",
+ "layer = 5\n",
+ "hooked_encoder = model.acts_to_saes[utils.get_act_name('z', layer)]\n",
+ "all_live_features = torch.arange(hooked_encoder.cfg.d_sae)[live_feature_union.cpu()]\n",
+ "\n",
+ "causal_effects = torch.zeros((len(prompts), all_live_features.shape[0]))\n",
+ "fid_to_idx = {fid.item(): idx for idx, fid in enumerate(all_live_features)}\n",
+ "\n",
+ "\n",
+ "abl_layer, abl_pos = 5, 10\n",
+ "for feature_id in tqdm.tqdm(all_live_features):\n",
+ " feature_id = feature_id.item()\n",
+ " abl_feature_logits = model.run_with_hooks(\n",
+ " tokens,\n",
+ " return_type=\"logits\",\n",
+ " fwd_hooks=[(utils.get_act_name('z', abl_layer) + \".hook_sae_acts_post\", partial(ablate_sae_feature, pos=abl_pos, feature_id=feature_id))]\n",
+ " ) # [batch, seq, vocab]\n",
+ " \n",
+ " abl_feature_logit_diff = logits_to_ave_logit_diff(abl_feature_logits, answer_tokens, per_prompt=True) # [batch]\n",
+ " causal_effects[:, fid_to_idx[feature_id]] = abl_feature_logit_diff - original_per_prompt_logit_diff\n",
+ "\n",
+ "def able_sae_error(sae_error, hook, pos):\n",
+ " if pos is None:\n",
+ " sae_error = 0.\n",
+ " else:\n",
+ " sae_error[:, pos, ...] = 0.\n",
+ " return sae_error\n",
+ "\n",
+ "\n",
+ "abl_error_logits = model.run_with_hooks(\n",
+ " tokens,\n",
+ " return_type=\"logits\",\n",
+ " fwd_hooks=[(utils.get_act_name('z', abl_layer) + \".hook_sae_error\", partial(able_sae_error, pos=abl_pos))]\n",
+ ") # [batch, seq, vocab]\n",
+ "\n",
+ "abl_error_logit_diff = logits_to_ave_logit_diff(abl_error_logits, answer_tokens, per_prompt=True) # [batch]\n",
+ "error_abl_effect = abl_error_logit_diff - original_per_prompt_logit_diff\n",
+ "\n",
+ "\n",
+ "causal_effects_with_error = torch.cat([causal_effects, error_abl_effect.unsqueeze(-1).cpu()], dim=-1)\n",
+ "imshow(causal_effects_with_error, title=f\"Change in logit diff when ablating L{abl_layer} SAE features for all prompts at pos {abl_pos}\", xaxis=\"Feature Idx\", yaxis=\"Prompt Idx\", x=list(map(str, all_live_features.tolist()))+[\"error\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can see that on some prompts, ablating the error term (right most column) does have a non trivial effect on the logit diff, although I don't see a clear pattern. It seems useful to include this term when doing causal interventions to get a better sense of how much the SAE features are actually explaining. "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Attribution patching "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "Both [Anthropic](https://transformer-circuits.pub/2024/march-update/index.html#feature-heads) and [Marks et al](https://arxiv.org/abs/2403.19647v2). also demonstrated the use of gradient based attribution techniques as a substitute for activation patching on SAE features. The key idea is that patching / ablations (as we did above) can be slow, as it requires a new forward pass for each patch. This seems especially problematic when dealing with SAEs with tens of thousands of features per activation. They find that gradient based attribution techniques like [attribution patching](https://www.neelnanda.io/mechanistic-interpretability/attribution-patching) are good approximations, allowing for more efficient and scalable circuit analysis with SAEs.\n",
+ "\n",
+ "With `HookedSAETransformer`, added SAEs are automatically spliced into the computational graph, allowing us to implement this easily. Let's implement attribution patching for every L5 SAE feature to find causally relevant SAE features with just one forward and one backward pass."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 20,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "torch.set_grad_enabled(True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "tensor(-7.6294e-06, device='cuda:0', grad_fn=)\n",
+ "Clean Value: -7.62939453125e-06\n",
+ "Clean Activations Cached: 1\n",
+ "Clean Gradients Cached: 1\n"
+ ]
+ }
+ ],
+ "source": [
+ "from transformer_lens import ActivationCache\n",
+ "filter_sae_acts = lambda name: (\"hook_sae_acts_post\" in name)\n",
+ "def get_cache_fwd_and_bwd(model, tokens, metric):\n",
+ " model.reset_hooks()\n",
+ " cache = {}\n",
+ " def forward_cache_hook(act, hook):\n",
+ " cache[hook.name] = act.detach()\n",
+ " model.add_hook(filter_sae_acts, forward_cache_hook, \"fwd\")\n",
+ "\n",
+ " grad_cache = {}\n",
+ " def backward_cache_hook(act, hook):\n",
+ " grad_cache[hook.name] = act.detach()\n",
+ " model.add_hook(filter_sae_acts, backward_cache_hook, \"bwd\")\n",
+ "\n",
+ " value = metric(model(tokens))\n",
+ " print(value)\n",
+ " value.backward()\n",
+ " model.reset_hooks()\n",
+ " return value.item(), ActivationCache(cache, model), ActivationCache(grad_cache, model)\n",
+ "\n",
+ "\n",
+ "BASELINE = original_per_prompt_logit_diff\n",
+ "def ioi_metric(logits, answer_tokens=answer_tokens):\n",
+ " return (logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=True) - BASELINE).sum()\n",
+ "\n",
+ "clean_tokens = tokens.clone()\n",
+ "clean_value, clean_cache, clean_grad_cache = get_cache_fwd_and_bwd(model, clean_tokens, ioi_metric)\n",
+ "print(\"Clean Value:\", clean_value)\n",
+ "print(\"Clean Activations Cached:\", len(clean_cache))\n",
+ "print(\"Clean Gradients Cached:\", len(clean_grad_cache))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.plotly.v1+json": {
+ "config": {
+ "plotlyServerURL": "https://plot.ly"
+ },
+ "data": [
+ {
+ "coloraxis": "coloraxis",
+ "hovertemplate": "Feature Idx: %{x}
Prompt Idx: %{y}
color: %{z}",
+ "name": "0",
+ "type": "heatmap",
+ "x": [
+ "46",
+ "345",
+ "702",
+ "1372",
+ "1755",
+ "1965",
+ "2457",
+ "2496",
+ "2646",
+ "2999",
+ "3047",
+ "4569",
+ "5132",
+ "5203",
+ "5508",
+ "5940",
+ "6144",
+ "6371",
+ "6515",
+ "6558",
+ "6812",
+ "7092",
+ "7515",
+ "7907",
+ "8063",
+ "8623",
+ "8737",
+ "8768",
+ "9096",
+ "9102",
+ "9186",
+ "9463",
+ "9746",
+ "9913",
+ "10581",
+ "10894",
+ "12109",
+ "12485",
+ "12764",
+ "12866",
+ "13063",
+ "13624",
+ "13707",
+ "13777",
+ "14844",
+ "15050",
+ "15170",
+ "15696",
+ "16178",
+ "16892",
+ "17156",
+ "17259",
+ "17497",
+ "17854",
+ "18043",
+ "18210",
+ "18318",
+ "18385",
+ "18440",
+ "18920",
+ "19183",
+ "19263",
+ "19442",
+ "19524",
+ "19573",
+ "20838",
+ "21151",
+ "21657",
+ "22108",
+ "23578",
+ "24091",
+ "24217",
+ "25792",
+ "26373",
+ "26410",
+ "27535",
+ "27787",
+ "27811",
+ "27960",
+ "28061",
+ "28241",
+ "28242",
+ "28254",
+ "28349",
+ "28977",
+ "29027",
+ "29482",
+ "29603",
+ "29700",
+ "29822",
+ "32177",
+ "32920",
+ "33320",
+ "33730",
+ "33966",
+ "34177",
+ "34334",
+ "34947",
+ "35403",
+ "35425",
+ "35579",
+ "35665",
+ "35815",
+ "36109",
+ "36172",
+ "36451",
+ "36767",
+ "36917",
+ "38570",
+ "39962",
+ "40409",
+ "40418",
+ "40661",
+ "41162",
+ "41185",
+ "41552",
+ "42024",
+ "42161",
+ "42437",
+ "42577",
+ "42882",
+ "42931",
+ "43035",
+ "43414",
+ "43643",
+ "43662",
+ "44203",
+ "44256",
+ "44452",
+ "44652",
+ "45179",
+ "45814",
+ "45984",
+ "46880",
+ "47117",
+ "47170",
+ "47231",
+ "47313",
+ "47680",
+ "48063",
+ "48703"
+ ],
+ "xaxis": "x",
+ "yaxis": "y",
+ "z": [
+ [
+ 0.001567811705172062,
+ 0,
+ 0,
+ 0.001697835512459278,
+ 0.00011560246639419347,
+ 0,
+ 0,
+ -0.0002851475146599114,
+ 0,
+ -0.030827227979898453,
+ -0.06409652531147003,
+ 0,
+ 0.00015167289529927075,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.013627146370708942,
+ -0.004393726587295532,
+ 0.0015328703448176384,
+ -0.0038613511715084314,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.02049136720597744,
+ 0,
+ 0,
+ -0.007114107254892588,
+ 0.0003477374848444015,
+ -0.001384311355650425,
+ 0.003183899214491248,
+ 0.0004558839718811214,
+ -0.059277813881635666,
+ -0.0035793157294392586,
+ -0.00589390005916357,
+ 0,
+ -0.0001910730206873268,
+ 0,
+ 0.0006608504336327314,
+ -0.0004212319909129292,
+ 0,
+ -0.003545185085386038,
+ -0.00327106611803174,
+ 0,
+ 0,
+ 0.0040074847638607025,
+ 0,
+ 0,
+ 0,
+ -0.0026069351006299257,
+ 0,
+ 0,
+ 0,
+ -0.00008433026232523844,
+ -0.00018646706303115934,
+ 0,
+ -0.00439279293641448,
+ -0.013254894874989986,
+ 0,
+ 0.050094299018383026,
+ -0.021308520808815956,
+ 0,
+ -0.0006410681526176631,
+ 0,
+ 0,
+ 0.02329532988369465,
+ -0.05166983604431152,
+ -0.002982117934152484,
+ -0.000014124364497547504,
+ -0.0020334068685770035,
+ 0,
+ 0,
+ 0,
+ -0.02020590752363205,
+ 0.00998645182698965,
+ 0,
+ -0.004585121292620897,
+ 0.005916096270084381,
+ 0.0018219061894342303,
+ 0.005700498353689909,
+ 0.0008085825829766691,
+ 0,
+ 0,
+ -0.0032405084930360317,
+ 0,
+ 0,
+ 0,
+ -0.014961971901357174,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.016915086656808853,
+ 0,
+ 0.016825370490550995,
+ 0,
+ -0.00311169121414423,
+ 0,
+ 0.005266942549496889,
+ 0,
+ 0,
+ 0,
+ -0.009660078212618828,
+ 0.0010975055629387498,
+ 0.006078756880015135,
+ 0,
+ 0,
+ 0.003166533075273037,
+ -0.044512320309877396,
+ 0.0002630578528624028,
+ 0,
+ 0,
+ -0.00025422731414437294,
+ 0,
+ 0,
+ 0,
+ -0.3718416392803192,
+ -0.0008081833366304636,
+ 0.00043700754758901894,
+ 0,
+ -0.023154418915510178,
+ 0.00004691413778346032,
+ 0,
+ 0,
+ -0.0002914638607762754,
+ 0.0006733346963301301,
+ 0,
+ 0.008972969837486744,
+ 0,
+ -0.008168808184564114
+ ],
+ [
+ 0,
+ 0,
+ -0.0006953283445909619,
+ 0,
+ 0,
+ 0,
+ -0.001286927843466401,
+ -0.017273705452680588,
+ 0,
+ 0.05898163467645645,
+ 0.013462062925100327,
+ 0,
+ 0,
+ -0.00003325308352941647,
+ -0.0027551515959203243,
+ -0.004652985371649265,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.21421866118907928,
+ 0,
+ 0,
+ 0.002191215055063367,
+ 0,
+ 0.07645706832408905,
+ 0.0052618952468037605,
+ 0,
+ 0,
+ -0.020269982516765594,
+ 0,
+ 0.013446477241814137,
+ 0.0068704248405992985,
+ 0.08710267394781113,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.028982989490032196,
+ 0,
+ 0,
+ 0.014961526729166508,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.011233230121433735,
+ 0,
+ -0.009112805128097534,
+ 0.003226917004212737,
+ 0,
+ 0,
+ 0,
+ 0.112985759973526,
+ 0,
+ 0.028253009542822838,
+ 0,
+ 0,
+ 0,
+ 0.0009787877788767219,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.03986968472599983,
+ 0,
+ 0,
+ -0.006135094445198774,
+ 0.04977395758032799,
+ 0.0397123359143734,
+ 0.027974072843790054,
+ 0,
+ -0.00044811973930336535,
+ -0.10083132237195969,
+ 0.000008234118467953522,
+ 0.06165996566414833,
+ 0,
+ 0,
+ 0.021058127284049988,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.08074336498975754,
+ -0.009298793971538544,
+ 0,
+ 0.012482613325119019,
+ 0.06513619422912598,
+ 0,
+ 0,
+ 0.00029019018984399736,
+ 0,
+ 0,
+ 0.0014882637187838554,
+ 0,
+ 0,
+ 0,
+ -0.004803473129868507,
+ 0,
+ 0,
+ 0.025678949430584908,
+ 0,
+ -0.04240157827734947,
+ 0,
+ 0,
+ 0.0015190609265118837,
+ 0.0006482255994342268,
+ 0.03654245659708977,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.0020186977926641703,
+ 0,
+ 0,
+ 0.17831696569919586,
+ 0,
+ 0,
+ 0,
+ 0.0005887048901058733,
+ 0.012331255711615086,
+ 0,
+ 0,
+ 0.11619613319635391,
+ 0.04687207192182541,
+ 0.03033648431301117,
+ 0,
+ -0.004195880610495806,
+ 0.00006391256465576589,
+ -0.03162289038300514,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.03672636300325394
+ ],
+ [
+ 0.00788492988795042,
+ 0,
+ 0,
+ 0.003685369621962309,
+ 0,
+ 0,
+ 0,
+ -0.010384900495409966,
+ 0,
+ -0.1327948272228241,
+ -0.22788244485855103,
+ 0,
+ 0.0003893508983310312,
+ 0,
+ 0,
+ 0,
+ 0.009530982933938503,
+ 0,
+ -0.0001286355109186843,
+ 0,
+ 0,
+ 0.0001596187794348225,
+ 0.011789986863732338,
+ -0.0022452236153185368,
+ 0,
+ -0.0014552043285220861,
+ 0.0002036036894423887,
+ -0.03003234602510929,
+ -0.036742936819791794,
+ -0.028862446546554565,
+ -0.003727517556399107,
+ 0,
+ 0.0011460097739472985,
+ 0,
+ -0.027142589911818504,
+ -0.054151974618434906,
+ -0.0004727205669041723,
+ 0,
+ -0.006094601005315781,
+ 0.00013960858632344753,
+ 0,
+ -0.0003665595140773803,
+ 0.00028091753483749926,
+ -0.17846877872943878,
+ -0.004990901332348585,
+ -0.010615025646984577,
+ 0,
+ 0.015916047617793083,
+ 0,
+ 0.0008773574372753501,
+ 0.004459311719983816,
+ 0,
+ -0.015235064551234245,
+ 0,
+ -0.0008741968194954097,
+ -0.04074608162045479,
+ 0.007227533031255007,
+ 0,
+ 0,
+ 0,
+ -0.007763775996863842,
+ 0,
+ 0,
+ -0.0011336231837049127,
+ 0,
+ -0.004542750306427479,
+ 0.016146792098879814,
+ -0.032868705689907074,
+ -0.013282506726682186,
+ 0,
+ 0.1884474903345108,
+ -0.07819699496030807,
+ 0,
+ 0.00013099861098453403,
+ 0.00024322106037288904,
+ 0,
+ 0.04764547944068909,
+ -0.09056885540485382,
+ 0,
+ -0.005007788073271513,
+ 0.000487087934743613,
+ 0,
+ 0,
+ 0,
+ -0.07196655869483948,
+ 0.007451012264937162,
+ 0,
+ -0.013892672955989838,
+ -0.005596193019300699,
+ -0.005349555052816868,
+ -0.00015437132969964296,
+ 0,
+ 0,
+ 0,
+ -0.00894666463136673,
+ 0,
+ 0,
+ 0,
+ -0.036862581968307495,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.026162482798099518,
+ 0,
+ 0.046491872519254684,
+ 0,
+ -0.030160455033183098,
+ 0,
+ -0.009029642678797245,
+ -0.0021479984279721975,
+ -0.0005375721957534552,
+ -0.002135993679985404,
+ -0.027962258085608482,
+ 0,
+ 0.0008057129452936351,
+ 0,
+ 0,
+ 0,
+ -0.26795026659965515,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.0027670287527143955,
+ -0.0002252299600513652,
+ -0.7548060417175293,
+ 0,
+ -0.05009680241346359,
+ 0,
+ -0.03914204612374306,
+ 0,
+ 0,
+ 0.016279445961117744,
+ 0,
+ 0,
+ 0,
+ 0.025662390515208244,
+ -0.000049459828005637974,
+ -0.0023572721984237432
+ ],
+ [
+ 0,
+ 0,
+ 0.0009027881897054613,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.01007400918751955,
+ 0.07334298640489578,
+ 0.15174342691898346,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.0007311829249374568,
+ 0,
+ 0,
+ 0,
+ 0.011839455924928188,
+ 0,
+ 0,
+ -0.2282165139913559,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.017542533576488495,
+ 0,
+ 0,
+ 0,
+ 0.1636323779821396,
+ 0,
+ 0.10289037227630615,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.024433566257357597,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.013018166646361351,
+ 0,
+ -0.0005916667287237942,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.07111621648073196,
+ 0.0004984873230569065,
+ 0.015917964279651642,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.06262800097465515,
+ 0.17253385484218597,
+ 0,
+ 0,
+ -0.0007970984443090856,
+ -0.1451263427734375,
+ 0,
+ 0.08718064427375793,
+ 0,
+ 0,
+ 0.007446629460901022,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.09546831995248795,
+ 0,
+ 0,
+ 0.06110787391662598,
+ 0.08931172639131546,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.005256101489067078,
+ -0.00553735950961709,
+ 0,
+ 0,
+ 0.006732907146215439,
+ 0,
+ -0.005547903478145599,
+ 0,
+ 0,
+ 0,
+ 0.01766844280064106,
+ -0.0034187675919383764,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.1122211441397667,
+ 0,
+ 0,
+ 0,
+ 0.009442206472158432,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.00800288561731577,
+ 0,
+ 0.006613056641072035,
+ 0,
+ -0.06462590396404266,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0
+ ],
+ [
+ -0.0009047402418218553,
+ 0,
+ 0,
+ -0.0005877931835129857,
+ 0,
+ 0,
+ 0,
+ -0.0004729636711999774,
+ 0,
+ -0.05036322772502899,
+ -0.24687804281711578,
+ 0,
+ 0.001115482416935265,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.0024291854351758957,
+ 0,
+ 0,
+ 0,
+ -0.029154174029827118,
+ -0.011211197823286057,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.0075091151520609856,
+ 0,
+ 0.0037634933833032846,
+ 0.022711526602506638,
+ 0,
+ 0,
+ -0.00011145337339257821,
+ 0,
+ -0.08350298553705215,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.0063380529172718525,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.0010615212377160788,
+ 0,
+ 0,
+ 0,
+ 0.001314864493906498,
+ 0,
+ 0,
+ -0.0020079570822417736,
+ 0,
+ 0,
+ 0,
+ -0.095857173204422,
+ 0,
+ 0,
+ 0.04977884888648987,
+ 0.04924672096967697,
+ 0,
+ 0.00675918348133564,
+ 0,
+ 0,
+ 0.02823697216808796,
+ -0.07869893312454224,
+ 0,
+ 0,
+ -0.00039145027403719723,
+ 0,
+ 0,
+ 0,
+ -0.03502006456255913,
+ 0,
+ 0,
+ -0.004709419794380665,
+ -0.007543480955064297,
+ -0.007213911972939968,
+ 0.0026987697929143906,
+ 0,
+ 0,
+ 0,
+ -0.0016787010245025158,
+ 0,
+ 0,
+ -0.002866228111088276,
+ -0.04759479686617851,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.005348640959709883,
+ 0,
+ 0.17661413550376892,
+ 0,
+ -0.0024743194226175547,
+ 0,
+ 0.0269751138985157,
+ 0,
+ 0,
+ 0,
+ -0.025461290031671524,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.14607883989810944,
+ 0,
+ 0.020490022376179695,
+ 0.007573024369776249,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.8939738869667053,
+ 0,
+ -0.006900197826325893,
+ -0.0031849159859120846,
+ -0.015817783772945404,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.00032859406201168895,
+ 0.11629504710435867,
+ 0,
+ 0
+ ],
+ [
+ 0,
+ 0,
+ 0.0020032059401273727,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.02256190776824951,
+ 0.07616151124238968,
+ 0.3106333911418915,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.014044971205294132,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.3483165502548218,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.05930393189191818,
+ 0,
+ 0,
+ 0.004992437083274126,
+ -0.08404884487390518,
+ 0,
+ 0,
+ 0,
+ 0.16281214356422424,
+ 0,
+ 0.28443410992622375,
+ 0,
+ 0,
+ 0.0014393558958545327,
+ 0,
+ 0,
+ -0.009063852950930595,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.001169737195596099,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.06898342072963715,
+ 0,
+ 0.007991905324161053,
+ 0,
+ 0.006260615773499012,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.037955716252326965,
+ 0.3505173921585083,
+ 0,
+ 0,
+ 0,
+ -0.338177889585495,
+ 0,
+ 0.158599853515625,
+ 0,
+ 0,
+ 0.01131439208984375,
+ 0,
+ 0.006751265376806259,
+ 0,
+ 0,
+ 0,
+ -0.04573351889848709,
+ 0,
+ 0,
+ 0.04386100172996521,
+ 0.11277603358030319,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.0003205372195225209,
+ 0,
+ 0,
+ 0,
+ -0.005409737583249807,
+ -0.009204162284731865,
+ 0,
+ 0,
+ 0,
+ 0.004804544150829315,
+ -0.005810749251395464,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.09645535796880722,
+ 0,
+ 0,
+ 0,
+ 0.032931435853242874,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.028524864464998245,
+ 0.09402520954608917,
+ 0,
+ 0.008998546749353409,
+ 0,
+ -0.03251685947179794,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.037343256175518036
+ ],
+ [
+ 0.0045840502716600895,
+ 0,
+ 0,
+ 0.009021798148751259,
+ 0,
+ -0.004217533860355616,
+ 0,
+ 0.0025705555453896523,
+ 0,
+ -0.035309672355651855,
+ -0.09942735731601715,
+ -0.004700342193245888,
+ 0,
+ 0,
+ 0,
+ 0.018288278952240944,
+ 0.0004021169152110815,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.005593586713075638,
+ 0.0018821493722498417,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.04561242088675499,
+ -0.01815006509423256,
+ 0.016583485528826714,
+ -0.020843051373958588,
+ 0,
+ 0,
+ 0,
+ -0.006372869946062565,
+ 0.04272369295358658,
+ 0,
+ 0.0013309348141774535,
+ 0,
+ 0,
+ 0,
+ -0.0031638317741453648,
+ 0,
+ 0.08714215457439423,
+ 0,
+ -0.005442100111395121,
+ -0.00039313771412707865,
+ 0.0014464370906352997,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.03132649511098862,
+ 0,
+ 0,
+ -0.007972904480993748,
+ 0.01753396727144718,
+ 0,
+ -0.0011563192820176482,
+ 0,
+ 0.0017362519865855575,
+ 0,
+ -0.0004587600124068558,
+ 0,
+ 0.0038881096988916397,
+ -0.008516360074281693,
+ 0,
+ -0.008183307014405727,
+ -0.010095844976603985,
+ 0,
+ 0.10722006857395172,
+ -0.002898464212194085,
+ 0,
+ 0.0012827662285417318,
+ 0.00004252225699019618,
+ 0,
+ 0.07567721605300903,
+ -0.030121177434921265,
+ 0,
+ -0.0010666534071788192,
+ 0.0006539365276694298,
+ 0,
+ 0,
+ -0.0011567147448658943,
+ -0.021622339263558388,
+ 0.028687214478850365,
+ 0,
+ 0,
+ 0.018764594569802284,
+ 0.010613140650093555,
+ 0.019510075449943542,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.005288010463118553,
+ 0,
+ 0,
+ -0.016743114218115807,
+ -0.015873711556196213,
+ -0.009877816773951054,
+ 0.0003150522243231535,
+ 0.023689158260822296,
+ 0,
+ -0.00033418016391806304,
+ 0,
+ 0.04904749244451523,
+ 0,
+ 0,
+ 0,
+ 0.0006500506424345076,
+ 0.000622213410679251,
+ 0.00738720316439867,
+ 0,
+ 0.0012243357487022877,
+ 0.009066173806786537,
+ -0.12073952704668045,
+ 0.0010678119724616408,
+ 0.006296947598457336,
+ 0.002682592486962676,
+ -0.00444818427786231,
+ -0.0023324599023908377,
+ 0,
+ 0,
+ -0.5609893798828125,
+ -0.008780602365732193,
+ 0.015986066311597824,
+ 0,
+ -0.02213476411998272,
+ 0,
+ 0,
+ 0.0006705078994855285,
+ -0.0011221399763599038,
+ 0,
+ 0,
+ 0.025299811735749245,
+ 0,
+ -0.008218510076403618
+ ],
+ [
+ -0.0034782839938998222,
+ 0.0022423912305384874,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.05859537422657013,
+ 0.0421387143433094,
+ 0.26256099343299866,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.10330676287412643,
+ 0.012355834245681763,
+ 0.0013472040882334113,
+ 0,
+ 0.019914263859391212,
+ 0.005261276848614216,
+ 0.001149827498011291,
+ -0.03320133313536644,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.32198745012283325,
+ 0,
+ 0.05401667580008507,
+ 0.04610951617360115,
+ -0.2326284795999527,
+ 0,
+ 0.0000856258993735537,
+ 0,
+ 0.074106365442276,
+ 0,
+ 0.044469863176345825,
+ 0,
+ 0,
+ 0,
+ -0.006453251000493765,
+ 0,
+ -0.018431225791573524,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.005704954732209444,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.13457728922367096,
+ 0,
+ -0.029186677187681198,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.022995056584477425,
+ 0,
+ 0,
+ 0,
+ -0.09004921466112137,
+ 0.24257110059261322,
+ 0.02852930873632431,
+ 0,
+ 0.021270141005516052,
+ -0.13564155995845795,
+ -0.03098711557686329,
+ 0,
+ 0,
+ 0,
+ 0.0486220121383667,
+ -0.001395023544318974,
+ 0,
+ 0.04929636791348457,
+ 0,
+ 0,
+ -0.13068373501300812,
+ 0,
+ 0,
+ 0.016955919563770294,
+ 0.03848254308104515,
+ 0,
+ -0.011160435155034065,
+ 0,
+ 0,
+ -0.0002991429646499455,
+ 0,
+ 0,
+ 0,
+ 0.01138608530163765,
+ -0.020150866359472275,
+ 0,
+ 0,
+ -0.007353566121309996,
+ -0.021389631554484367,
+ -0.042083244770765305,
+ 0.13586723804473877,
+ 0.005315479822456837,
+ 0,
+ 0.008157049305737019,
+ 0.022239860147237778,
+ 0,
+ 0,
+ 0.01896926946938038,
+ 0.0018052944215014577,
+ 0.0016496418975293636,
+ 0.0005593635141849518,
+ 0,
+ 0,
+ 0.07655386626720428,
+ 0,
+ 0,
+ 0,
+ 0.02781328558921814,
+ 0.04012482985854149,
+ 0,
+ 0,
+ 0,
+ 0.10631410032510757,
+ 0.03608629107475281,
+ 0,
+ -0.02651066705584526,
+ 0,
+ -0.0690990686416626,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.1022648885846138
+ ]
+ ]
+ }
+ ],
+ "layout": {
+ "coloraxis": {
+ "cmid": 0,
+ "colorscale": [
+ [
+ 0,
+ "rgb(103,0,31)"
+ ],
+ [
+ 0.1,
+ "rgb(178,24,43)"
+ ],
+ [
+ 0.2,
+ "rgb(214,96,77)"
+ ],
+ [
+ 0.3,
+ "rgb(244,165,130)"
+ ],
+ [
+ 0.4,
+ "rgb(253,219,199)"
+ ],
+ [
+ 0.5,
+ "rgb(247,247,247)"
+ ],
+ [
+ 0.6,
+ "rgb(209,229,240)"
+ ],
+ [
+ 0.7,
+ "rgb(146,197,222)"
+ ],
+ [
+ 0.8,
+ "rgb(67,147,195)"
+ ],
+ [
+ 0.9,
+ "rgb(33,102,172)"
+ ],
+ [
+ 1,
+ "rgb(5,48,97)"
+ ]
+ ]
+ },
+ "template": {
+ "data": {
+ "bar": [
+ {
+ "error_x": {
+ "color": "#2a3f5f"
+ },
+ "error_y": {
+ "color": "#2a3f5f"
+ },
+ "marker": {
+ "line": {
+ "color": "#E5ECF6",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "bar"
+ }
+ ],
+ "barpolar": [
+ {
+ "marker": {
+ "line": {
+ "color": "#E5ECF6",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "barpolar"
+ }
+ ],
+ "carpet": [
+ {
+ "aaxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "white",
+ "linecolor": "white",
+ "minorgridcolor": "white",
+ "startlinecolor": "#2a3f5f"
+ },
+ "baxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "white",
+ "linecolor": "white",
+ "minorgridcolor": "white",
+ "startlinecolor": "#2a3f5f"
+ },
+ "type": "carpet"
+ }
+ ],
+ "choropleth": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "choropleth"
+ }
+ ],
+ "contour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "contour"
+ }
+ ],
+ "contourcarpet": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "contourcarpet"
+ }
+ ],
+ "heatmap": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmap"
+ }
+ ],
+ "heatmapgl": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmapgl"
+ }
+ ],
+ "histogram": [
+ {
+ "marker": {
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "histogram"
+ }
+ ],
+ "histogram2d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2d"
+ }
+ ],
+ "histogram2dcontour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2dcontour"
+ }
+ ],
+ "mesh3d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "mesh3d"
+ }
+ ],
+ "parcoords": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "parcoords"
+ }
+ ],
+ "pie": [
+ {
+ "automargin": true,
+ "type": "pie"
+ }
+ ],
+ "scatter": [
+ {
+ "fillpattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ },
+ "type": "scatter"
+ }
+ ],
+ "scatter3d": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatter3d"
+ }
+ ],
+ "scattercarpet": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattercarpet"
+ }
+ ],
+ "scattergeo": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergeo"
+ }
+ ],
+ "scattergl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergl"
+ }
+ ],
+ "scattermapbox": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattermapbox"
+ }
+ ],
+ "scatterpolar": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolar"
+ }
+ ],
+ "scatterpolargl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolargl"
+ }
+ ],
+ "scatterternary": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterternary"
+ }
+ ],
+ "surface": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "surface"
+ }
+ ],
+ "table": [
+ {
+ "cells": {
+ "fill": {
+ "color": "#EBF0F8"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "header": {
+ "fill": {
+ "color": "#C8D4E3"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "type": "table"
+ }
+ ]
+ },
+ "layout": {
+ "annotationdefaults": {
+ "arrowcolor": "#2a3f5f",
+ "arrowhead": 0,
+ "arrowwidth": 1
+ },
+ "autotypenumbers": "strict",
+ "coloraxis": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "colorscale": {
+ "diverging": [
+ [
+ 0,
+ "#8e0152"
+ ],
+ [
+ 0.1,
+ "#c51b7d"
+ ],
+ [
+ 0.2,
+ "#de77ae"
+ ],
+ [
+ 0.3,
+ "#f1b6da"
+ ],
+ [
+ 0.4,
+ "#fde0ef"
+ ],
+ [
+ 0.5,
+ "#f7f7f7"
+ ],
+ [
+ 0.6,
+ "#e6f5d0"
+ ],
+ [
+ 0.7,
+ "#b8e186"
+ ],
+ [
+ 0.8,
+ "#7fbc41"
+ ],
+ [
+ 0.9,
+ "#4d9221"
+ ],
+ [
+ 1,
+ "#276419"
+ ]
+ ],
+ "sequential": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "sequentialminus": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ]
+ },
+ "colorway": [
+ "#636efa",
+ "#EF553B",
+ "#00cc96",
+ "#ab63fa",
+ "#FFA15A",
+ "#19d3f3",
+ "#FF6692",
+ "#B6E880",
+ "#FF97FF",
+ "#FECB52"
+ ],
+ "font": {
+ "color": "#2a3f5f"
+ },
+ "geo": {
+ "bgcolor": "white",
+ "lakecolor": "white",
+ "landcolor": "#E5ECF6",
+ "showlakes": true,
+ "showland": true,
+ "subunitcolor": "white"
+ },
+ "hoverlabel": {
+ "align": "left"
+ },
+ "hovermode": "closest",
+ "mapbox": {
+ "style": "light"
+ },
+ "paper_bgcolor": "white",
+ "plot_bgcolor": "#E5ECF6",
+ "polar": {
+ "angularaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "bgcolor": "#E5ECF6",
+ "radialaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ }
+ },
+ "scene": {
+ "xaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ },
+ "yaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ },
+ "zaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ }
+ },
+ "shapedefaults": {
+ "line": {
+ "color": "#2a3f5f"
+ }
+ },
+ "ternary": {
+ "aaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "baxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "bgcolor": "#E5ECF6",
+ "caxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ }
+ },
+ "title": {
+ "x": 0.05
+ },
+ "xaxis": {
+ "automargin": true,
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "white",
+ "zerolinewidth": 2
+ },
+ "yaxis": {
+ "automargin": true,
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "white",
+ "zerolinewidth": 2
+ }
+ }
+ },
+ "title": {
+ "text": "attribution patching"
+ },
+ "xaxis": {
+ "anchor": "y",
+ "constrain": "domain",
+ "domain": [
+ 0,
+ 1
+ ],
+ "scaleanchor": "y",
+ "title": {
+ "text": "Feature Idx"
+ }
+ },
+ "yaxis": {
+ "anchor": "x",
+ "autorange": "reversed",
+ "constrain": "domain",
+ "domain": [
+ 0,
+ 1
+ ],
+ "title": {
+ "text": "Prompt Idx"
+ }
+ }
+ }
+ },
+ "text/html": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "def attr_patch_sae_acts(\n",
+ " clean_cache: ActivationCache, \n",
+ " clean_grad_cache: ActivationCache,\n",
+ " site: str, layer: int\n",
+ " ):\n",
+ " clean_sae_acts_post = clean_cache[utils.get_act_name(site, layer) + \".hook_sae_acts_post\"] \n",
+ " clean_grad_sae_acts_post = clean_grad_cache[utils.get_act_name(site, layer) + \".hook_sae_acts_post\"] \n",
+ " sae_act_attr = clean_grad_sae_acts_post * (0 - clean_sae_acts_post)\n",
+ " return sae_act_attr\n",
+ "\n",
+ "site = \"z\"\n",
+ "layer = 5\n",
+ "sae_act_attr = attr_patch_sae_acts(clean_cache, clean_grad_cache, site, layer)\n",
+ "\n",
+ "imshow(\n",
+ " sae_act_attr[:, s2_pos, all_live_features],\n",
+ " title=\"attribution patching\",\n",
+ " xaxis=\"Feature Idx\", yaxis=\"Prompt Idx\", x=list(map(str, all_live_features.tolist())))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.plotly.v1+json": {
+ "config": {
+ "plotlyServerURL": "https://plot.ly"
+ },
+ "data": [
+ {
+ "hovertemplate": "Activation Patch=%{x}<br>Attribution Patch=%{y}",
+ "legendgroup": "",
+ "marker": {
+ "color": "#636efa",
+ "symbol": "circle"
+ },
+ "mode": "markers",
+ "name": "",
+ "showlegend": false,
+ "type": "scattergl",
+ "x": [
+ 0.0012617111206054688,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ 0.0016908645629882812,
+ -0.0002231597900390625,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -0.00029659271240234375,
+ -9.5367431640625e-7,
+ -0.03279590606689453,
+ -0.07254886627197266,
+ -9.5367431640625e-7,
+ 0.00013065338134765625,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -0.014922142028808594,
+ -0.0044403076171875,
+ 0.0007047653198242188,
+ -0.00428009033203125,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -0.039069175720214844,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -0.007334709167480469,
+ 0.00033092498779296875,
+ -0.0017004013061523438,
+ 0.0026845932006835938,
+ 0.00043010711669921875,
+ -0.11128997802734375,
+ -0.0038976669311523438,
+ -0.006033897399902344,
+ -9.5367431640625e-7,
+ -0.00027751922607421875,
+ -9.5367431640625e-7,
+ 0.0006570816040039062,
+ -0.0004291534423828125,
+ -9.5367431640625e-7,
+ -0.0035734176635742188,
+ -0.0033063888549804688,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ 0.0033960342407226562,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -0.0030546188354492188,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -0.0000972747802734375,
+ -0.0001811981201171875,
+ -9.5367431640625e-7,
+ -0.004569053649902344,
+ -0.013583183288574219,
+ -9.5367431640625e-7,
+ 0.02047252655029297,
+ -0.02572154998779297,
+ -9.5367431640625e-7,
+ -0.0006608963012695312,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ 0.02255725860595703,
+ -0.05519580841064453,
+ -0.0033473968505859375,
+ -0.0000057220458984375,
+ -0.0026073455810546875,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -0.02097320556640625,
+ 0.008440971374511719,
+ -9.5367431640625e-7,
+ -0.004597663879394531,
+ 0.00159454345703125,
+ 0.0001544952392578125,
+ 0.005199432373046875,
+ 0.0007762908935546875,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -0.0032625198364257812,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -0.015192985534667969,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -0.018138885498046875,
+ -9.5367431640625e-7,
+ 0.010298728942871094,
+ -9.5367431640625e-7,
+ -0.0031423568725585938,
+ -9.5367431640625e-7,
+ 0.004242897033691406,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -0.010041236877441406,
+ 0.0010347366333007812,
+ 0.006011962890625,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ 0.00301361083984375,
+ -0.04584026336669922,
+ 0.0002079010009765625,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -0.0002574920654296875,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -0.45942211151123047,
+ -0.0008325576782226562,
+ 0.00041484832763671875,
+ -9.5367431640625e-7,
+ -0.023777008056640625,
+ 0.0000514984130859375,
+ -9.5367431640625e-7,
+ -9.5367431640625e-7,
+ -0.00030422210693359375,
+ 0.0006666183471679688,
+ -9.5367431640625e-7,
+ 0.004633903503417969,
+ -9.5367431640625e-7,
+ -0.008234977722167969,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ -0.00208282470703125,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ -0.0012912750244140625,
+ -0.01760101318359375,
+ 0.000003814697265625,
+ 0.057277679443359375,
+ 0.013429641723632812,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ -0.0000457763671875,
+ -0.0027828216552734375,
+ -0.0055084228515625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ -0.2744255065917969,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.0021514892578125,
+ 0.000003814697265625,
+ 0.06994247436523438,
+ 0.0048542022705078125,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ -0.0567169189453125,
+ 0.000003814697265625,
+ 0.012315750122070312,
+ 0.0066585540771484375,
+ 0.07937240600585938,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.028867721557617188,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.0074901580810546875,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.009624481201171875,
+ 0.000003814697265625,
+ -0.009510040283203125,
+ 0.0032100677490234375,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.10918617248535156,
+ 0.000003814697265625,
+ 0.026102066040039062,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000946044921875,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ -0.041675567626953125,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ -0.0066776275634765625,
+ 0.03926849365234375,
+ 0.03615379333496094,
+ 0.027612686157226562,
+ 0.000003814697265625,
+ -0.0004673004150390625,
+ -0.1435985565185547,
+ -0.00030517578125,
+ 0.059326171875,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.020435333251953125,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ -0.11923980712890625,
+ -0.009393692016601562,
+ 0.000003814697265625,
+ 0.011783599853515625,
+ 0.06122589111328125,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.0002918243408203125,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.001491546630859375,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ -0.0050716400146484375,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.025064468383789062,
+ 0.000003814697265625,
+ -0.0467529296875,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.0014934539794921875,
+ 0.00043487548828125,
+ 0.028188705444335938,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.001995086669921875,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.13014602661132812,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.0005893707275390625,
+ 0.012182235717773438,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.11103057861328125,
+ 0.042850494384765625,
+ 0.030099868774414062,
+ 0.000003814697265625,
+ -0.0047321319580078125,
+ 0.0000133514404296875,
+ -0.0320587158203125,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.000003814697265625,
+ 0.031030654907226562,
+ 0.007018089294433594,
+ 0,
+ 0,
+ 0.0028057098388671875,
+ 0,
+ 0,
+ 0,
+ -0.010999679565429688,
+ 0,
+ -0.1419973373413086,
+ -0.24188613891601562,
+ 0,
+ 0.0003147125244140625,
+ 0,
+ 0,
+ 0,
+ 0.009432792663574219,
+ 0,
+ -0.000125885009765625,
+ 0,
+ 0,
+ 0.00017070770263671875,
+ 0.011651992797851562,
+ -0.00225830078125,
+ 0,
+ -0.0014581680297851562,
+ 0.00020122528076171875,
+ -0.030771255493164062,
+ -0.03744316101074219,
+ -0.034499168395996094,
+ -0.00374603271484375,
+ 0,
+ 0.0011348724365234375,
+ 0,
+ -0.0302276611328125,
+ -0.08229637145996094,
+ -0.00048160552978515625,
+ 0,
+ -0.00640869140625,
+ 0.0001277923583984375,
+ 0,
+ -0.0008974075317382812,
+ 0.00022983551025390625,
+ -0.2322559356689453,
+ -0.0050449371337890625,
+ -0.010677337646484375,
+ 0,
+ 0.014942169189453125,
+ 0,
+ 0.0008764266967773438,
+ 0.00417327880859375,
+ 0,
+ -0.015301704406738281,
+ 0,
+ -0.0008974075317382812,
+ -0.04426097869873047,
+ 0.005242347717285156,
+ 0,
+ 0,
+ 0,
+ -0.009447097778320312,
+ 0,
+ 0,
+ -0.0011806488037109375,
+ 0,
+ -0.0045909881591796875,
+ 0.015285491943359375,
+ -0.034976959228515625,
+ -0.013401985168457031,
+ 0,
+ 0.1357421875,
+ -0.09111690521240234,
+ 0,
+ 0.00013065338134765625,
+ 0.0002460479736328125,
+ 0,
+ 0.04656982421875,
+ -0.09346866607666016,
+ 0,
+ -0.005030632019042969,
+ 0.0001125335693359375,
+ 0,
+ 0,
+ 0,
+ -0.07491683959960938,
+ 0.006598472595214844,
+ 0,
+ -0.014060020446777344,
+ -0.008306503295898438,
+ -0.0054874420166015625,
+ -0.0004930496215820312,
+ 0,
+ 0,
+ 0,
+ -0.008953094482421875,
+ 0,
+ 0,
+ 0,
+ -0.03713417053222656,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.028200149536132812,
+ 0,
+ 0.036255836486816406,
+ 0,
+ -0.03178215026855469,
+ 0,
+ -0.012192726135253906,
+ -0.002147674560546875,
+ -0.0005474090576171875,
+ -0.0021409988403320312,
+ -0.030725479125976562,
+ 0,
+ 0.0008029937744140625,
+ 0,
+ 0,
+ 0,
+ -0.29135894775390625,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.0027914047241210938,
+ -0.00022125244140625,
+ -0.8653240203857422,
+ 0,
+ -0.05593109130859375,
+ 0,
+ -0.04123210906982422,
+ 0,
+ 0,
+ 0.015351295471191406,
+ 0,
+ 0,
+ 0,
+ 0.018423080444335938,
+ -0.0000476837158203125,
+ -0.0023584365844726562,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.0001983642578125,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ -0.010341644287109375,
+ 0.07198715209960938,
+ 0.14725303649902344,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.0002918243408203125,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.011704444885253906,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ -0.3150959014892578,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ -0.039947509765625,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.15607547760009766,
+ 9.5367431640625e-7,
+ 0.09917640686035156,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.019521713256835938,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.012205123901367188,
+ 9.5367431640625e-7,
+ -0.0005893707275390625,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.07062149047851562,
+ 0.000492095947265625,
+ 0.014776229858398438,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.0557098388671875,
+ 0.15409469604492188,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ -0.0007076263427734375,
+ -0.24256324768066406,
+ 9.5367431640625e-7,
+ 0.0858917236328125,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.007343292236328125,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ -0.11646080017089844,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.05528736114501953,
+ 0.0847921371459961,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.00428009033203125,
+ -0.0056171417236328125,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.0066967010498046875,
+ 9.5367431640625e-7,
+ -0.006005287170410156,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.01735687255859375,
+ -0.0037336349487304688,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.09533309936523438,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.009324073791503906,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 0.007989883422851562,
+ 9.5367431640625e-7,
+ 0.0064525604248046875,
+ 9.5367431640625e-7,
+ -0.06574440002441406,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ 9.5367431640625e-7,
+ -0.0009012222290039062,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.0006313323974609375,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.000461578369140625,
+ 0.00001239776611328125,
+ -0.055993080139160156,
+ -0.24974536895751953,
+ 0.00001239776611328125,
+ 0.0011262893676757812,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.0025796890258789062,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.030013084411621094,
+ -0.012925148010253906,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.0253448486328125,
+ 0.00001239776611328125,
+ 0.0012464523315429688,
+ 0.021536827087402344,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.00009822845458984375,
+ 0.00001239776611328125,
+ -0.09924793243408203,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.006188392639160156,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.0010576248168945312,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.0008172988891601562,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.0020704269409179688,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.09985160827636719,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.036945343017578125,
+ 0.025011062622070312,
+ 0.00001239776611328125,
+ 0.004599571228027344,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.027939796447753906,
+ -0.07974910736083984,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.00038242340087890625,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.035175323486328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.0047245025634765625,
+ -0.008166313171386719,
+ -0.008578300476074219,
+ 0.0018529891967773438,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.0016679763793945312,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.0028676986694335938,
+ -0.04880046844482422,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.0053462982177734375,
+ 0.00001239776611328125,
+ 0.1658468246459961,
+ 0.00001239776611328125,
+ -0.0024824142456054688,
+ 0.00001239776611328125,
+ 0.025139808654785156,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.027915000915527344,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.14544200897216797,
+ 0.00001239776611328125,
+ 0.020270347595214844,
+ 0.007473945617675781,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.8424196243286133,
+ 0.00001239776611328125,
+ -0.007409095764160156,
+ -0.00318145751953125,
+ -0.015982627868652344,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ 0.00034046173095703125,
+ 0.10727787017822266,
+ 0.00001239776611328125,
+ 0.00001239776611328125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ 0.0019397735595703125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.022940635681152344,
+ 0.07428932189941406,
+ 0.29994869232177734,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.016974449157714844,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.4772310256958008,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ 0.05463600158691406,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ 0.004734992980957031,
+ -0.12352275848388672,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ 0.15236186981201172,
+ -0.00000762939453125,
+ 0.27855396270751953,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ 0.001430511474609375,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.016387939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ 0.0008668899536132812,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ 0.06814861297607422,
+ -0.00000762939453125,
+ 0.00351715087890625,
+ -0.00000762939453125,
+ 0.0061588287353515625,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ 0.02480602264404297,
+ 0.31668567657470703,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.4413900375366211,
+ -0.00000762939453125,
+ 0.1517791748046875,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ 0.010898590087890625,
+ -0.00000762939453125,
+ 0.006583213806152344,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.04946422576904297,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ 0.040429115295410156,
+ 0.1020956039428711,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.0008649826049804688,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.0054836273193359375,
+ -0.010519981384277344,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ 0.004588127136230469,
+ -0.006558418273925781,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ 0.07750797271728516,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ 0.03235149383544922,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ 0.02773571014404297,
+ 0.08978557586669922,
+ -0.00000762939453125,
+ 0.008780479431152344,
+ -0.00000762939453125,
+ -0.0327301025390625,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ -0.00000762939453125,
+ 0.035370826721191406,
+ 0.0041904449462890625,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ 0.008458137512207031,
+ -0.0000057220458984375,
+ -0.0042858123779296875,
+ -0.0000057220458984375,
+ 0.002468109130859375,
+ -0.0000057220458984375,
+ -0.03716564178466797,
+ -0.10456657409667969,
+ -0.0047702789306640625,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ 0.017671585083007812,
+ 0.0004062652587890625,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.00655364990234375,
+ 0.001873016357421875,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.04653644561767578,
+ -0.01836395263671875,
+ 0.014448165893554688,
+ -0.0209197998046875,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.006618499755859375,
+ 0.02408599853515625,
+ -0.0000057220458984375,
+ 0.0012884140014648438,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0050144195556640625,
+ -0.0000057220458984375,
+ 0.036708831787109375,
+ -0.0000057220458984375,
+ -0.0056591033935546875,
+ -0.0004215240478515625,
+ 0.0014057159423828125,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.033232688903808594,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.008130073547363281,
+ 0.016930580139160156,
+ -0.0000057220458984375,
+ -0.0012025833129882812,
+ -0.0000057220458984375,
+ 0.000545501708984375,
+ -0.0000057220458984375,
+ -0.0004673004150390625,
+ -0.0000057220458984375,
+ 0.0038089752197265625,
+ -0.008646011352539062,
+ -0.0000057220458984375,
+ -0.008909225463867188,
+ -0.011255264282226562,
+ -0.0000057220458984375,
+ 0.0925750732421875,
+ -0.0064563751220703125,
+ -0.0000057220458984375,
+ 0.0011615753173828125,
+ 0.00002956390380859375,
+ -0.0000057220458984375,
+ 0.07063961029052734,
+ -0.030902862548828125,
+ -0.0000057220458984375,
+ -0.0010814666748046875,
+ 0.00038909912109375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0013380050659179688,
+ -0.022397994995117188,
+ 0.027740478515625,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ 0.01797771453857422,
+ 0.009552955627441406,
+ 0.01857471466064453,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0055103302001953125,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.01697063446044922,
+ -0.0159149169921875,
+ -0.011240959167480469,
+ 0.000301361083984375,
+ 0.020501136779785156,
+ -0.0000057220458984375,
+ -0.0006427764892578125,
+ -0.0000057220458984375,
+ 0.04800224304199219,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.00001811981201171875,
+ 0.0005950927734375,
+ 0.00732421875,
+ -0.0000057220458984375,
+ 0.001216888427734375,
+ 0.00897216796875,
+ -0.1255035400390625,
+ 0.001003265380859375,
+ 0.006274223327636719,
+ 0.0026502609252929688,
+ -0.00449371337890625,
+ -0.0023517608642578125,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ -0.6521244049072266,
+ -0.009072303771972656,
+ 0.013387680053710938,
+ -0.0000057220458984375,
+ -0.022745132446289062,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ 0.000606536865234375,
+ -0.0011501312255859375,
+ -0.0000057220458984375,
+ -0.0000057220458984375,
+ 0.023046493530273438,
+ -0.0000057220458984375,
+ -0.008263587951660156,
+ -0.0037221908569335938,
+ 0.00225830078125,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.05941295623779297,
+ 0.04140281677246094,
+ 0.24284648895263672,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.11462688446044922,
+ 0.012240409851074219,
+ 0.0012884140014648438,
+ -0.00001049041748046875,
+ 0.01781749725341797,
+ 0.005211830139160156,
+ -0.0016298294067382812,
+ -0.2994966506958008,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ 0.26962947845458984,
+ -0.00001049041748046875,
+ 0.050202369689941406,
+ 0.04053211212158203,
+ -0.30355358123779297,
+ -0.00001049041748046875,
+ -0.0013666152954101562,
+ -0.00001049041748046875,
+ 0.06442928314208984,
+ -0.00001049041748046875,
+ 0.04406547546386719,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00763702392578125,
+ -0.00001049041748046875,
+ -0.03402233123779297,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.005751609802246094,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ 0.13191986083984375,
+ -0.00001049041748046875,
+ -0.031653404235839844,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.02394390106201172,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.1073293685913086,
+ 0.20270729064941406,
+ 0.02746295928955078,
+ -0.00001049041748046875,
+ 0.020377159118652344,
+ -0.31055259704589844,
+ -0.043480873107910156,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ 0.04507160186767578,
+ -0.0014734268188476562,
+ -0.00001049041748046875,
+ 0.048813819885253906,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.1407604217529297,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ 0.013548851013183594,
+ 0.016210556030273438,
+ -0.00001049041748046875,
+ -0.011261940002441406,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00029277801513671875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ 0.008993148803710938,
+ -0.020813941955566406,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.008435249328613281,
+ -0.021961212158203125,
+ -0.04410362243652344,
+ 0.1307668685913086,
+ 0.005297660827636719,
+ -0.00001049041748046875,
+ 0.006031990051269531,
+ 0.016150474548339844,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ 0.01802349090576172,
+ 0.0018205642700195312,
+ 0.0016574859619140625,
+ 0.0005712509155273438,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ 0.02598094940185547,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ 0.02737903594970703,
+ 0.039580345153808594,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ 0.09876728057861328,
+ 0.035803794860839844,
+ -0.00001049041748046875,
+ -0.027251243591308594,
+ -0.00001049041748046875,
+ -0.07061004638671875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ -0.00001049041748046875,
+ 0.08719158172607422
+ ],
+ "xaxis": "x",
+ "y": [
+ 0.001567811705172062,
+ 0,
+ 0,
+ 0.001697835512459278,
+ 0.00011560246639419347,
+ 0,
+ 0,
+ -0.0002851475146599114,
+ 0,
+ -0.030827227979898453,
+ -0.06409652531147003,
+ 0,
+ 0.00015167289529927075,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.013627146370708942,
+ -0.004393726587295532,
+ 0.0015328703448176384,
+ -0.0038613511715084314,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.02049136720597744,
+ 0,
+ 0,
+ -0.007114107254892588,
+ 0.0003477374848444015,
+ -0.001384311355650425,
+ 0.003183899214491248,
+ 0.0004558839718811214,
+ -0.059277813881635666,
+ -0.0035793157294392586,
+ -0.00589390005916357,
+ 0,
+ -0.0001910730206873268,
+ 0,
+ 0.0006608504336327314,
+ -0.0004212319909129292,
+ 0,
+ -0.003545185085386038,
+ -0.00327106611803174,
+ 0,
+ 0,
+ 0.0040074847638607025,
+ 0,
+ 0,
+ 0,
+ -0.0026069351006299257,
+ 0,
+ 0,
+ 0,
+ -0.00008433026232523844,
+ -0.00018646706303115934,
+ 0,
+ -0.00439279293641448,
+ -0.013254894874989986,
+ 0,
+ 0.050094299018383026,
+ -0.021308520808815956,
+ 0,
+ -0.0006410681526176631,
+ 0,
+ 0,
+ 0.02329532988369465,
+ -0.05166983604431152,
+ -0.002982117934152484,
+ -0.000014124364497547504,
+ -0.0020334068685770035,
+ 0,
+ 0,
+ 0,
+ -0.02020590752363205,
+ 0.00998645182698965,
+ 0,
+ -0.004585121292620897,
+ 0.005916096270084381,
+ 0.0018219061894342303,
+ 0.005700498353689909,
+ 0.0008085825829766691,
+ 0,
+ 0,
+ -0.0032405084930360317,
+ 0,
+ 0,
+ 0,
+ -0.014961971901357174,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.016915086656808853,
+ 0,
+ 0.016825370490550995,
+ 0,
+ -0.00311169121414423,
+ 0,
+ 0.005266942549496889,
+ 0,
+ 0,
+ 0,
+ -0.009660078212618828,
+ 0.0010975055629387498,
+ 0.006078756880015135,
+ 0,
+ 0,
+ 0.003166533075273037,
+ -0.044512320309877396,
+ 0.0002630578528624028,
+ 0,
+ 0,
+ -0.00025422731414437294,
+ 0,
+ 0,
+ 0,
+ -0.3718416392803192,
+ -0.0008081833366304636,
+ 0.00043700754758901894,
+ 0,
+ -0.023154418915510178,
+ 0.00004691413778346032,
+ 0,
+ 0,
+ -0.0002914638607762754,
+ 0.0006733346963301301,
+ 0,
+ 0.008972969837486744,
+ 0,
+ -0.008168808184564114,
+ 0,
+ 0,
+ -0.0006953283445909619,
+ 0,
+ 0,
+ 0,
+ -0.001286927843466401,
+ -0.017273705452680588,
+ 0,
+ 0.05898163467645645,
+ 0.013462062925100327,
+ 0,
+ 0,
+ -0.00003325308352941647,
+ -0.0027551515959203243,
+ -0.004652985371649265,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.21421866118907928,
+ 0,
+ 0,
+ 0.002191215055063367,
+ 0,
+ 0.07645706832408905,
+ 0.0052618952468037605,
+ 0,
+ 0,
+ -0.020269982516765594,
+ 0,
+ 0.013446477241814137,
+ 0.0068704248405992985,
+ 0.08710267394781113,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.028982989490032196,
+ 0,
+ 0,
+ 0.014961526729166508,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.011233230121433735,
+ 0,
+ -0.009112805128097534,
+ 0.003226917004212737,
+ 0,
+ 0,
+ 0,
+ 0.112985759973526,
+ 0,
+ 0.028253009542822838,
+ 0,
+ 0,
+ 0,
+ 0.0009787877788767219,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.03986968472599983,
+ 0,
+ 0,
+ -0.006135094445198774,
+ 0.04977395758032799,
+ 0.0397123359143734,
+ 0.027974072843790054,
+ 0,
+ -0.00044811973930336535,
+ -0.10083132237195969,
+ 0.000008234118467953522,
+ 0.06165996566414833,
+ 0,
+ 0,
+ 0.021058127284049988,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.08074336498975754,
+ -0.009298793971538544,
+ 0,
+ 0.012482613325119019,
+ 0.06513619422912598,
+ 0,
+ 0,
+ 0.00029019018984399736,
+ 0,
+ 0,
+ 0.0014882637187838554,
+ 0,
+ 0,
+ 0,
+ -0.004803473129868507,
+ 0,
+ 0,
+ 0.025678949430584908,
+ 0,
+ -0.04240157827734947,
+ 0,
+ 0,
+ 0.0015190609265118837,
+ 0.0006482255994342268,
+ 0.03654245659708977,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.0020186977926641703,
+ 0,
+ 0,
+ 0.17831696569919586,
+ 0,
+ 0,
+ 0,
+ 0.0005887048901058733,
+ 0.012331255711615086,
+ 0,
+ 0,
+ 0.11619613319635391,
+ 0.04687207192182541,
+ 0.03033648431301117,
+ 0,
+ -0.004195880610495806,
+ 0.00006391256465576589,
+ -0.03162289038300514,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.03672636300325394,
+ 0.00788492988795042,
+ 0,
+ 0,
+ 0.003685369621962309,
+ 0,
+ 0,
+ 0,
+ -0.010384900495409966,
+ 0,
+ -0.1327948272228241,
+ -0.22788244485855103,
+ 0,
+ 0.0003893508983310312,
+ 0,
+ 0,
+ 0,
+ 0.009530982933938503,
+ 0,
+ -0.0001286355109186843,
+ 0,
+ 0,
+ 0.0001596187794348225,
+ 0.011789986863732338,
+ -0.0022452236153185368,
+ 0,
+ -0.0014552043285220861,
+ 0.0002036036894423887,
+ -0.03003234602510929,
+ -0.036742936819791794,
+ -0.028862446546554565,
+ -0.003727517556399107,
+ 0,
+ 0.0011460097739472985,
+ 0,
+ -0.027142589911818504,
+ -0.054151974618434906,
+ -0.0004727205669041723,
+ 0,
+ -0.006094601005315781,
+ 0.00013960858632344753,
+ 0,
+ -0.0003665595140773803,
+ 0.00028091753483749926,
+ -0.17846877872943878,
+ -0.004990901332348585,
+ -0.010615025646984577,
+ 0,
+ 0.015916047617793083,
+ 0,
+ 0.0008773574372753501,
+ 0.004459311719983816,
+ 0,
+ -0.015235064551234245,
+ 0,
+ -0.0008741968194954097,
+ -0.04074608162045479,
+ 0.007227533031255007,
+ 0,
+ 0,
+ 0,
+ -0.007763775996863842,
+ 0,
+ 0,
+ -0.0011336231837049127,
+ 0,
+ -0.004542750306427479,
+ 0.016146792098879814,
+ -0.032868705689907074,
+ -0.013282506726682186,
+ 0,
+ 0.1884474903345108,
+ -0.07819699496030807,
+ 0,
+ 0.00013099861098453403,
+ 0.00024322106037288904,
+ 0,
+ 0.04764547944068909,
+ -0.09056885540485382,
+ 0,
+ -0.005007788073271513,
+ 0.000487087934743613,
+ 0,
+ 0,
+ 0,
+ -0.07196655869483948,
+ 0.007451012264937162,
+ 0,
+ -0.013892672955989838,
+ -0.005596193019300699,
+ -0.005349555052816868,
+ -0.00015437132969964296,
+ 0,
+ 0,
+ 0,
+ -0.00894666463136673,
+ 0,
+ 0,
+ 0,
+ -0.036862581968307495,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.026162482798099518,
+ 0,
+ 0.046491872519254684,
+ 0,
+ -0.030160455033183098,
+ 0,
+ -0.009029642678797245,
+ -0.0021479984279721975,
+ -0.0005375721957534552,
+ -0.002135993679985404,
+ -0.027962258085608482,
+ 0,
+ 0.0008057129452936351,
+ 0,
+ 0,
+ 0,
+ -0.26795026659965515,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.0027670287527143955,
+ -0.0002252299600513652,
+ -0.7548060417175293,
+ 0,
+ -0.05009680241346359,
+ 0,
+ -0.03914204612374306,
+ 0,
+ 0,
+ 0.016279445961117744,
+ 0,
+ 0,
+ 0,
+ 0.025662390515208244,
+ -0.000049459828005637974,
+ -0.0023572721984237432,
+ 0,
+ 0,
+ 0.0009027881897054613,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.01007400918751955,
+ 0.07334298640489578,
+ 0.15174342691898346,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.0007311829249374568,
+ 0,
+ 0,
+ 0,
+ 0.011839455924928188,
+ 0,
+ 0,
+ -0.2282165139913559,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.017542533576488495,
+ 0,
+ 0,
+ 0,
+ 0.1636323779821396,
+ 0,
+ 0.10289037227630615,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.024433566257357597,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.013018166646361351,
+ 0,
+ -0.0005916667287237942,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.07111621648073196,
+ 0.0004984873230569065,
+ 0.015917964279651642,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.06262800097465515,
+ 0.17253385484218597,
+ 0,
+ 0,
+ -0.0007970984443090856,
+ -0.1451263427734375,
+ 0,
+ 0.08718064427375793,
+ 0,
+ 0,
+ 0.007446629460901022,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.09546831995248795,
+ 0,
+ 0,
+ 0.06110787391662598,
+ 0.08931172639131546,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.005256101489067078,
+ -0.00553735950961709,
+ 0,
+ 0,
+ 0.006732907146215439,
+ 0,
+ -0.005547903478145599,
+ 0,
+ 0,
+ 0,
+ 0.01766844280064106,
+ -0.0034187675919383764,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.1122211441397667,
+ 0,
+ 0,
+ 0,
+ 0.009442206472158432,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.00800288561731577,
+ 0,
+ 0.006613056641072035,
+ 0,
+ -0.06462590396404266,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.0009047402418218553,
+ 0,
+ 0,
+ -0.0005877931835129857,
+ 0,
+ 0,
+ 0,
+ -0.0004729636711999774,
+ 0,
+ -0.05036322772502899,
+ -0.24687804281711578,
+ 0,
+ 0.001115482416935265,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.0024291854351758957,
+ 0,
+ 0,
+ 0,
+ -0.029154174029827118,
+ -0.011211197823286057,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.0075091151520609856,
+ 0,
+ 0.0037634933833032846,
+ 0.022711526602506638,
+ 0,
+ 0,
+ -0.00011145337339257821,
+ 0,
+ -0.08350298553705215,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.0063380529172718525,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.0010615212377160788,
+ 0,
+ 0,
+ 0,
+ 0.001314864493906498,
+ 0,
+ 0,
+ -0.0020079570822417736,
+ 0,
+ 0,
+ 0,
+ -0.095857173204422,
+ 0,
+ 0,
+ 0.04977884888648987,
+ 0.04924672096967697,
+ 0,
+ 0.00675918348133564,
+ 0,
+ 0,
+ 0.02823697216808796,
+ -0.07869893312454224,
+ 0,
+ 0,
+ -0.00039145027403719723,
+ 0,
+ 0,
+ 0,
+ -0.03502006456255913,
+ 0,
+ 0,
+ -0.004709419794380665,
+ -0.007543480955064297,
+ -0.007213911972939968,
+ 0.0026987697929143906,
+ 0,
+ 0,
+ 0,
+ -0.0016787010245025158,
+ 0,
+ 0,
+ -0.002866228111088276,
+ -0.04759479686617851,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.005348640959709883,
+ 0,
+ 0.17661413550376892,
+ 0,
+ -0.0024743194226175547,
+ 0,
+ 0.0269751138985157,
+ 0,
+ 0,
+ 0,
+ -0.025461290031671524,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.14607883989810944,
+ 0,
+ 0.020490022376179695,
+ 0.007573024369776249,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.8939738869667053,
+ 0,
+ -0.006900197826325893,
+ -0.0031849159859120846,
+ -0.015817783772945404,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.00032859406201168895,
+ 0.11629504710435867,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.0020032059401273727,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.02256190776824951,
+ 0.07616151124238968,
+ 0.3106333911418915,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.014044971205294132,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.3483165502548218,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.05930393189191818,
+ 0,
+ 0,
+ 0.004992437083274126,
+ -0.08404884487390518,
+ 0,
+ 0,
+ 0,
+ 0.16281214356422424,
+ 0,
+ 0.28443410992622375,
+ 0,
+ 0,
+ 0.0014393558958545327,
+ 0,
+ 0,
+ -0.009063852950930595,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.001169737195596099,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.06898342072963715,
+ 0,
+ 0.007991905324161053,
+ 0,
+ 0.006260615773499012,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.037955716252326965,
+ 0.3505173921585083,
+ 0,
+ 0,
+ 0,
+ -0.338177889585495,
+ 0,
+ 0.158599853515625,
+ 0,
+ 0,
+ 0.01131439208984375,
+ 0,
+ 0.006751265376806259,
+ 0,
+ 0,
+ 0,
+ -0.04573351889848709,
+ 0,
+ 0,
+ 0.04386100172996521,
+ 0.11277603358030319,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.0003205372195225209,
+ 0,
+ 0,
+ 0,
+ -0.005409737583249807,
+ -0.009204162284731865,
+ 0,
+ 0,
+ 0,
+ 0.004804544150829315,
+ -0.005810749251395464,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.09645535796880722,
+ 0,
+ 0,
+ 0,
+ 0.032931435853242874,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.028524864464998245,
+ 0.09402520954608917,
+ 0,
+ 0.008998546749353409,
+ 0,
+ -0.03251685947179794,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.037343256175518036,
+ 0.0045840502716600895,
+ 0,
+ 0,
+ 0.009021798148751259,
+ 0,
+ -0.004217533860355616,
+ 0,
+ 0.0025705555453896523,
+ 0,
+ -0.035309672355651855,
+ -0.09942735731601715,
+ -0.004700342193245888,
+ 0,
+ 0,
+ 0,
+ 0.018288278952240944,
+ 0.0004021169152110815,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.005593586713075638,
+ 0.0018821493722498417,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.04561242088675499,
+ -0.01815006509423256,
+ 0.016583485528826714,
+ -0.020843051373958588,
+ 0,
+ 0,
+ 0,
+ -0.006372869946062565,
+ 0.04272369295358658,
+ 0,
+ 0.0013309348141774535,
+ 0,
+ 0,
+ 0,
+ -0.0031638317741453648,
+ 0,
+ 0.08714215457439423,
+ 0,
+ -0.005442100111395121,
+ -0.00039313771412707865,
+ 0.0014464370906352997,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.03132649511098862,
+ 0,
+ 0,
+ -0.007972904480993748,
+ 0.01753396727144718,
+ 0,
+ -0.0011563192820176482,
+ 0,
+ 0.0017362519865855575,
+ 0,
+ -0.0004587600124068558,
+ 0,
+ 0.0038881096988916397,
+ -0.008516360074281693,
+ 0,
+ -0.008183307014405727,
+ -0.010095844976603985,
+ 0,
+ 0.10722006857395172,
+ -0.002898464212194085,
+ 0,
+ 0.0012827662285417318,
+ 0.00004252225699019618,
+ 0,
+ 0.07567721605300903,
+ -0.030121177434921265,
+ 0,
+ -0.0010666534071788192,
+ 0.0006539365276694298,
+ 0,
+ 0,
+ -0.0011567147448658943,
+ -0.021622339263558388,
+ 0.028687214478850365,
+ 0,
+ 0,
+ 0.018764594569802284,
+ 0.010613140650093555,
+ 0.019510075449943542,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.005288010463118553,
+ 0,
+ 0,
+ -0.016743114218115807,
+ -0.015873711556196213,
+ -0.009877816773951054,
+ 0.0003150522243231535,
+ 0.023689158260822296,
+ 0,
+ -0.00033418016391806304,
+ 0,
+ 0.04904749244451523,
+ 0,
+ 0,
+ 0,
+ 0.0006500506424345076,
+ 0.000622213410679251,
+ 0.00738720316439867,
+ 0,
+ 0.0012243357487022877,
+ 0.009066173806786537,
+ -0.12073952704668045,
+ 0.0010678119724616408,
+ 0.006296947598457336,
+ 0.002682592486962676,
+ -0.00444818427786231,
+ -0.0023324599023908377,
+ 0,
+ 0,
+ -0.5609893798828125,
+ -0.008780602365732193,
+ 0.015986066311597824,
+ 0,
+ -0.02213476411998272,
+ 0,
+ 0,
+ 0.0006705078994855285,
+ -0.0011221399763599038,
+ 0,
+ 0,
+ 0.025299811735749245,
+ 0,
+ -0.008218510076403618,
+ -0.0034782839938998222,
+ 0.0022423912305384874,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.05859537422657013,
+ 0.0421387143433094,
+ 0.26256099343299866,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.10330676287412643,
+ 0.012355834245681763,
+ 0.0013472040882334113,
+ 0,
+ 0.019914263859391212,
+ 0.005261276848614216,
+ 0.001149827498011291,
+ -0.03320133313536644,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.32198745012283325,
+ 0,
+ 0.05401667580008507,
+ 0.04610951617360115,
+ -0.2326284795999527,
+ 0,
+ 0.0000856258993735537,
+ 0,
+ 0.074106365442276,
+ 0,
+ 0.044469863176345825,
+ 0,
+ 0,
+ 0,
+ -0.006453251000493765,
+ 0,
+ -0.018431225791573524,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.005704954732209444,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.13457728922367096,
+ 0,
+ -0.029186677187681198,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ -0.022995056584477425,
+ 0,
+ 0,
+ 0,
+ -0.09004921466112137,
+ 0.24257110059261322,
+ 0.02852930873632431,
+ 0,
+ 0.021270141005516052,
+ -0.13564155995845795,
+ -0.03098711557686329,
+ 0,
+ 0,
+ 0,
+ 0.0486220121383667,
+ -0.001395023544318974,
+ 0,
+ 0.04929636791348457,
+ 0,
+ 0,
+ -0.13068373501300812,
+ 0,
+ 0,
+ 0.016955919563770294,
+ 0.03848254308104515,
+ 0,
+ -0.011160435155034065,
+ 0,
+ 0,
+ -0.0002991429646499455,
+ 0,
+ 0,
+ 0,
+ 0.01138608530163765,
+ -0.020150866359472275,
+ 0,
+ 0,
+ -0.007353566121309996,
+ -0.021389631554484367,
+ -0.042083244770765305,
+ 0.13586723804473877,
+ 0.005315479822456837,
+ 0,
+ 0.008157049305737019,
+ 0.022239860147237778,
+ 0,
+ 0,
+ 0.01896926946938038,
+ 0.0018052944215014577,
+ 0.0016496418975293636,
+ 0.0005593635141849518,
+ 0,
+ 0,
+ 0.07655386626720428,
+ 0,
+ 0,
+ 0,
+ 0.02781328558921814,
+ 0.04012482985854149,
+ 0,
+ 0,
+ 0,
+ 0.10631410032510757,
+ 0.03608629107475281,
+ 0,
+ -0.02651066705584526,
+ 0,
+ -0.0690990686416626,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0.1022648885846138
+ ],
+ "yaxis": "y"
+ }
+ ],
+ "layout": {
+ "legend": {
+ "tracegroupgap": 0
+ },
+ "shapes": [
+ {
+ "line": {
+ "color": "gray",
+ "dash": "dot",
+ "width": 1
+ },
+ "type": "line",
+ "x0": -0.8653240203857422,
+ "x1": 0.31668567657470703,
+ "y0": -0.8653240203857422,
+ "y1": 0.31668567657470703
+ }
+ ],
+ "template": {
+ "data": {
+ "bar": [
+ {
+ "error_x": {
+ "color": "#2a3f5f"
+ },
+ "error_y": {
+ "color": "#2a3f5f"
+ },
+ "marker": {
+ "line": {
+ "color": "#E5ECF6",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "bar"
+ }
+ ],
+ "barpolar": [
+ {
+ "marker": {
+ "line": {
+ "color": "#E5ECF6",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "barpolar"
+ }
+ ],
+ "carpet": [
+ {
+ "aaxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "white",
+ "linecolor": "white",
+ "minorgridcolor": "white",
+ "startlinecolor": "#2a3f5f"
+ },
+ "baxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "white",
+ "linecolor": "white",
+ "minorgridcolor": "white",
+ "startlinecolor": "#2a3f5f"
+ },
+ "type": "carpet"
+ }
+ ],
+ "choropleth": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "choropleth"
+ }
+ ],
+ "contour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "contour"
+ }
+ ],
+ "contourcarpet": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "contourcarpet"
+ }
+ ],
+ "heatmap": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmap"
+ }
+ ],
+ "heatmapgl": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmapgl"
+ }
+ ],
+ "histogram": [
+ {
+ "marker": {
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "histogram"
+ }
+ ],
+ "histogram2d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2d"
+ }
+ ],
+ "histogram2dcontour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2dcontour"
+ }
+ ],
+ "mesh3d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "mesh3d"
+ }
+ ],
+ "parcoords": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "parcoords"
+ }
+ ],
+ "pie": [
+ {
+ "automargin": true,
+ "type": "pie"
+ }
+ ],
+ "scatter": [
+ {
+ "fillpattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ },
+ "type": "scatter"
+ }
+ ],
+ "scatter3d": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatter3d"
+ }
+ ],
+ "scattercarpet": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattercarpet"
+ }
+ ],
+ "scattergeo": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergeo"
+ }
+ ],
+ "scattergl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergl"
+ }
+ ],
+ "scattermapbox": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattermapbox"
+ }
+ ],
+ "scatterpolar": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolar"
+ }
+ ],
+ "scatterpolargl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolargl"
+ }
+ ],
+ "scatterternary": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterternary"
+ }
+ ],
+ "surface": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "surface"
+ }
+ ],
+ "table": [
+ {
+ "cells": {
+ "fill": {
+ "color": "#EBF0F8"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "header": {
+ "fill": {
+ "color": "#C8D4E3"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "type": "table"
+ }
+ ]
+ },
+ "layout": {
+ "annotationdefaults": {
+ "arrowcolor": "#2a3f5f",
+ "arrowhead": 0,
+ "arrowwidth": 1
+ },
+ "autotypenumbers": "strict",
+ "coloraxis": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "colorscale": {
+ "diverging": [
+ [
+ 0,
+ "#8e0152"
+ ],
+ [
+ 0.1,
+ "#c51b7d"
+ ],
+ [
+ 0.2,
+ "#de77ae"
+ ],
+ [
+ 0.3,
+ "#f1b6da"
+ ],
+ [
+ 0.4,
+ "#fde0ef"
+ ],
+ [
+ 0.5,
+ "#f7f7f7"
+ ],
+ [
+ 0.6,
+ "#e6f5d0"
+ ],
+ [
+ 0.7,
+ "#b8e186"
+ ],
+ [
+ 0.8,
+ "#7fbc41"
+ ],
+ [
+ 0.9,
+ "#4d9221"
+ ],
+ [
+ 1,
+ "#276419"
+ ]
+ ],
+ "sequential": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "sequentialminus": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ]
+ },
+ "colorway": [
+ "#636efa",
+ "#EF553B",
+ "#00cc96",
+ "#ab63fa",
+ "#FFA15A",
+ "#19d3f3",
+ "#FF6692",
+ "#B6E880",
+ "#FF97FF",
+ "#FECB52"
+ ],
+ "font": {
+ "color": "#2a3f5f"
+ },
+ "geo": {
+ "bgcolor": "white",
+ "lakecolor": "white",
+ "landcolor": "#E5ECF6",
+ "showlakes": true,
+ "showland": true,
+ "subunitcolor": "white"
+ },
+ "hoverlabel": {
+ "align": "left"
+ },
+ "hovermode": "closest",
+ "mapbox": {
+ "style": "light"
+ },
+ "paper_bgcolor": "white",
+ "plot_bgcolor": "#E5ECF6",
+ "polar": {
+ "angularaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "bgcolor": "#E5ECF6",
+ "radialaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ }
+ },
+ "scene": {
+ "xaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ },
+ "yaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ },
+ "zaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ }
+ },
+ "shapedefaults": {
+ "line": {
+ "color": "#2a3f5f"
+ }
+ },
+ "ternary": {
+ "aaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "baxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "bgcolor": "#E5ECF6",
+ "caxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ }
+ },
+ "title": {
+ "x": 0.05
+ },
+ "xaxis": {
+ "automargin": true,
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "white",
+ "zerolinewidth": 2
+ },
+ "yaxis": {
+ "automargin": true,
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "white",
+ "zerolinewidth": 2
+ }
+ }
+ },
+ "title": {
+ "text": "Attribution vs Activation Patching Per SAE feature (L5 S2 Pos, all prompts)"
+ },
+ "xaxis": {
+ "anchor": "y",
+ "domain": [
+ 0,
+ 1
+ ],
+ "title": {
+ "text": "Activation Patch"
+ }
+ },
+ "yaxis": {
+ "anchor": "x",
+ "domain": [
+ 0,
+ 1
+ ],
+ "title": {
+ "text": "Attribution Patch"
+ }
+ }
+ }
+ },
+ "text/html": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "fig = scatter(\n",
+ " y=sae_act_attr[:, s2_pos, all_live_features].flatten(), \n",
+ " x=causal_effects.flatten(),\n",
+ " title=\"Attribution vs Activation Patching Per SAE feature (L5 S2 Pos, all prompts)\",\n",
+ " xaxis=\"Activation Patch\",\n",
+ " yaxis=\"Attribution Patch\",\n",
+ " return_fig=True\n",
+ ")\n",
+ "fig.add_shape(\n",
+ " type='line',\n",
+ " x0=causal_effects.min(),\n",
+ " y0=causal_effects.min(),\n",
+ " x1=causal_effects.max(),\n",
+ " y1=causal_effects.max(),\n",
+ " line=dict(\n",
+ " color='gray',\n",
+ " width=1,\n",
+ " dash='dot'\n",
+ " )\n",
+ ")\n",
+ "fig.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "base",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/demos/LLaMA.ipynb b/demos/LLaMA.ipynb
index 9df019d7c..9e9f428e6 100644
--- a/demos/LLaMA.ipynb
+++ b/demos/LLaMA.ipynb
@@ -14,73 +14,40 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "# LLaMA and Llama-2 in TransformerLens\n",
- "\n",
- "This demo requires `transformers` version 4.31.0 (which adds Llama-2 support). This tutorial has part a) for LLaMA and b) for Llama-2. Currently the only Llama-2 support is the 7B chat model, as this notebook is being tested.\n",
- "\n",
- "Steps to run this demo:\n",
- "\n",
- "1a. Get LLaMA weights here: https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform\n",
- "\n",
- "1b. Get Llama-2 weights here: https://ai.meta.com/resources/models-and-libraries/llama-downloads/\n",
- "\n",
- "2a. Convert the official weights to huggingface. \n",
- "\n",
- "```bash\n",
- "python src/transformers/models/llama/convert_llama_weights_to_hf.py \\\n",
- " --input_dir /path/to/downloaded/llama/weights \\\n",
- " --model_size 7B \\\n",
- " --output_dir /llama/weights/directory/\n",
- "```\n",
- "\n",
- "2b. Same step for Llama-2, we'll use `7Bf` the 7B chat version\n",
- "\n",
- "```bash\n",
- "python src/transformers/models/llama/convert_llama_weights_to_hf.py \\\n",
- " --input_dir /path/to/downloaded/llama-2/weights \\\n",
- " --model_size 7Bf \\\n",
- " --output_dir /llama/weights/directory/\n",
- "```\n",
- "\n",
- "Note: this didn't work for Arthur by default (even though HF doesn't seem to show this anywhere). I had to change this line of my pip installed `src/transformers/models/llama/convert_llama_weights_to_hf.py` file (which was found at `/opt/conda/envs/arthurenv/lib/python3.10/site-packages/transformers/models/llama/convert_llama_weights_to_hf.py`) from \n",
- "\n",
- "`input_base_path=os.path.join(args.input_dir, args.model_size),` to `input_base_path=os.path.join(args.input_dir),`\n",
- "\n",
- "3. Change the ```MODEL_PATH``` variable in the cell below to where the converted weights are stored."
+ "# LLaMA and Llama-2 in TransformerLens"
]
},
{
- "cell_type": "code",
- "execution_count": 1,
+ "attachments": {},
+ "cell_type": "markdown",
"metadata": {},
- "outputs": [],
"source": [
- "from typing import Literal\n",
- "\n",
- "MODE: Literal[\"LLaMA\", \"Llama-2\"] = \"Llama-2\" # change to LLaMA for original LLaMA\n",
- "MODEL_PATH: str = \"\" # Set the path to the /llama/weights/directory/ that you used in the command"
+ "## Setup (skip)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
- "outputs": [],
- "source": [
- "!pip install transformers>=4.31.0 # Llama requires transformers>=4.31.0 and transformers in turn requires Python 3.8"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Note: you may need to restart the kernel to use updated packages.\n",
+ "Requirement already satisfied: sentencepiece in /root/TransformerLens/.venv/lib/python3.10/site-packages (0.1.99)\n",
+ "Note: you may need to restart the kernel to use updated packages.\n"
+ ]
+ }
+ ],
"source": [
- "## Setup (skip)"
+ "%pip install transformers>=4.31.0 # Llama requires transformers>=4.31.0 and transformers in turn requires Python 3.8\n",
+ "%pip install sentencepiece # Llama tokenizer requires sentencepiece"
]
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 2,
"metadata": {},
"outputs": [
{
@@ -94,9 +61,9 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "/tmp/ipykernel_20722/410710250.py:21: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
+ "/tmp/ipykernel_16979/572068249.py:21: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
" ipython.magic(\"load_ext autoreload\")\n",
- "/tmp/ipykernel_20722/410710250.py:22: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
+ "/tmp/ipykernel_16979/572068249.py:22: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
" ipython.magic(\"autoreload 2\")\n"
]
}
@@ -128,7 +95,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 3,
"metadata": {},
"outputs": [
{
@@ -153,45 +120,25 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Import stuff\n",
"import torch\n",
- "import torch.nn as nn\n",
- "import torch.nn.functional as F\n",
- "import torch.optim as optim\n",
- "import numpy as np\n",
- "import einops\n",
- "from fancy_einsum import einsum\n",
"import tqdm.auto as tqdm\n",
- "from tqdm import tqdm\n",
- "import random\n",
- "from pathlib import Path\n",
"import plotly.express as px\n",
- "from torch.utils.data import DataLoader\n",
- "\n",
- "from torchtyping import TensorType as TT\n",
- "from typing import List, Union, Optional\n",
- "from jaxtyping import Float, Int\n",
- "from functools import partial\n",
- "import copy\n",
"\n",
- "import itertools\n",
- "from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer\n",
- "import dataclasses\n",
- "import datasets\n",
- "from IPython.display import HTML\n",
- "# import circuitsvis as cv\n",
+ "from transformers import LlamaForCausalLM, LlamaTokenizer\n",
+ "from tqdm import tqdm\n",
+ "from jaxtyping import Float\n",
"\n",
"import transformer_lens\n",
"import transformer_lens.utils as utils\n",
"from transformer_lens.hook_points import (\n",
- " HookedRootModule,\n",
" HookPoint,\n",
") # Hooking utilities\n",
- "from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache\n",
+ "from transformer_lens import HookedTransformer\n",
"\n",
"torch.set_grad_enabled(False)\n",
"\n",
@@ -208,29 +155,95 @@
]
},
{
- "attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
- "## Loading model"
+ "## Loading LLaMA"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "LLaMA weights are not available on HuggingFace, so you'll need to download and convert them\n",
+ "manually:\n",
+ "\n",
+ "1. Get LLaMA weights here: https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform\n",
+ "\n",
+ "2. Convert the official weights to huggingface:\n",
+ "\n",
+ "```bash\n",
+ "python src/transformers/models/llama/convert_llama_weights_to_hf.py \\\n",
+ " --input_dir /path/to/downloaded/llama/weights \\\n",
+ " --model_size 7B \\\n",
+ " --output_dir /llama/weights/directory/\n",
+ "```\n",
+ "\n",
+ "Note: this didn't work for Arthur by default (even though HF doesn't seem to show this anywhere). I\n",
+ "had to change this\n",
+ "line of my pip installed `src/transformers/models/llama/convert_llama_weights_to_hf.py` file (which\n",
+ "was found at\n",
+ "`/opt/conda/envs/arthurenv/lib/python3.10/site-packages/transformers/models/llama/convert_llama_weights_to_hf.py`)\n",
+ "from `input_base_path=os.path.join(args.input_dir, args.model_size),` to `input_base_path=os.path.join(args.input_dir),`\n",
+ "\n",
+ "3. Change the ```MODEL_PATH``` variable in the cell below to where the converted weights are stored."
]
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "MODEL_PATH=''\n",
+ "\n",
+ "tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)\n",
+ "hf_model = LlamaForCausalLM.from_pretrained(MODEL_PATH, low_cpu_mem_usage=True)\n",
+ "\n",
+ "model = HookedTransformer.from_pretrained(\"llama-7b\", hf_model=hf_model, device=\"cpu\", fold_ln=False, center_writing_weights=False, center_unembed=False, tokenizer=tokenizer)\n",
+ "\n",
+ "model = model.to(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ "model.generate(\"The capital of Germany is\", max_new_tokens=20, temperature=0)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Loading LLaMA-2\n",
+ "LLaMA-2 is hosted on HuggingFace, but gated by login.\n",
+ "\n",
+ "Before running the notebook, log in to HuggingFace via the cli on your machine:\n",
+ "```bash\n",
+ "transformers-cli login\n",
+ "```\n",
+ "This will cache your HuggingFace credentials, and enable you to download LLaMA-2."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
"metadata": {},
"outputs": [
{
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "You are using the legacy behaviour of the . This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565\n"
- ]
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "1821773a30ad4a56960ccae34e8e6a3d",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "b70ece404dc04bf2956493196d7baae5",
+ "model_id": "6151bf52dddd49aaa9a88d536ec7bdd8",
"version_major": 2,
"version_minor": 0
},
@@ -241,35 +254,10 @@
"metadata": {},
"output_type": "display_data"
},
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Some weights of LlamaForCausalLM were not initialized from the model checkpoint at /root/lam_out and are newly initialized: ['model.layers.10.self_attn.rotary_emb.cos_cached', 'model.layers.19.self_attn.rotary_emb.sin_cached', 'model.layers.28.self_attn.rotary_emb.cos_cached', 'model.layers.2.self_attn.rotary_emb.cos_cached', 'model.layers.15.self_attn.rotary_emb.sin_cached', 'model.layers.30.self_attn.rotary_emb.cos_cached', 'model.layers.6.self_attn.rotary_emb.sin_cached', 'model.layers.19.self_attn.rotary_emb.cos_cached', 'model.layers.13.self_attn.rotary_emb.cos_cached', 'model.layers.13.self_attn.rotary_emb.sin_cached', 'model.layers.26.self_attn.rotary_emb.sin_cached', 'model.layers.23.self_attn.rotary_emb.sin_cached', 'model.layers.1.self_attn.rotary_emb.sin_cached', 'model.layers.24.self_attn.rotary_emb.cos_cached', 'model.layers.7.self_attn.rotary_emb.sin_cached', 'model.layers.1.self_attn.rotary_emb.cos_cached', 'model.layers.10.self_attn.rotary_emb.sin_cached', 'model.layers.27.self_attn.rotary_emb.sin_cached', 'model.layers.8.self_attn.rotary_emb.sin_cached', 'model.layers.11.self_attn.rotary_emb.cos_cached', 'model.layers.22.self_attn.rotary_emb.cos_cached', 'model.layers.12.self_attn.rotary_emb.sin_cached', 'model.layers.6.self_attn.rotary_emb.cos_cached', 'model.layers.2.self_attn.rotary_emb.sin_cached', 'model.layers.22.self_attn.rotary_emb.sin_cached', 'model.layers.11.self_attn.rotary_emb.sin_cached', 'model.layers.5.self_attn.rotary_emb.cos_cached', 'model.layers.3.self_attn.rotary_emb.sin_cached', 'model.layers.9.self_attn.rotary_emb.cos_cached', 'model.layers.14.self_attn.rotary_emb.sin_cached', 'model.layers.21.self_attn.rotary_emb.cos_cached', 'model.layers.8.self_attn.rotary_emb.cos_cached', 'model.layers.0.self_attn.rotary_emb.cos_cached', 'model.layers.21.self_attn.rotary_emb.sin_cached', 'model.layers.16.self_attn.rotary_emb.sin_cached', 'model.layers.15.self_attn.rotary_emb.cos_cached', 
'model.layers.18.self_attn.rotary_emb.sin_cached', 'model.layers.25.self_attn.rotary_emb.sin_cached', 'model.layers.28.self_attn.rotary_emb.sin_cached', 'model.layers.12.self_attn.rotary_emb.cos_cached', 'model.layers.17.self_attn.rotary_emb.cos_cached', 'model.layers.9.self_attn.rotary_emb.sin_cached', 'model.layers.30.self_attn.rotary_emb.sin_cached', 'model.layers.3.self_attn.rotary_emb.cos_cached', 'model.layers.5.self_attn.rotary_emb.sin_cached', 'model.layers.7.self_attn.rotary_emb.cos_cached', 'model.layers.16.self_attn.rotary_emb.cos_cached', 'model.layers.29.self_attn.rotary_emb.cos_cached', 'model.layers.31.self_attn.rotary_emb.cos_cached', 'model.layers.0.self_attn.rotary_emb.sin_cached', 'model.layers.17.self_attn.rotary_emb.sin_cached', 'model.layers.4.self_attn.rotary_emb.sin_cached', 'model.layers.31.self_attn.rotary_emb.sin_cached', 'model.layers.25.self_attn.rotary_emb.cos_cached', 'model.layers.23.self_attn.rotary_emb.cos_cached', 'model.layers.20.self_attn.rotary_emb.cos_cached', 'model.layers.24.self_attn.rotary_emb.sin_cached', 'model.layers.18.self_attn.rotary_emb.cos_cached', 'model.layers.27.self_attn.rotary_emb.cos_cached', 'model.layers.4.self_attn.rotary_emb.cos_cached', 'model.layers.29.self_attn.rotary_emb.sin_cached', 'model.layers.14.self_attn.rotary_emb.cos_cached', 'model.layers.26.self_attn.rotary_emb.cos_cached', 'model.layers.20.self_attn.rotary_emb.sin_cached']\n",
- "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
- ]
- }
- ],
- "source": [
- "from transformers import LlamaForCausalLM, LlamaTokenizer\n",
- "import os\n",
- "\n",
- "MODEL_PATH=''\n",
- "\n",
- "tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)\n",
- "hf_model = LlamaForCausalLM.from_pretrained(MODEL_PATH, low_cpu_mem_usage=True)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Warning: LLaMA tokenizer not loaded. Please load manually.\n",
"Loaded pretrained model meta-llama/Llama-2-7b-chat-hf into HookedTransformer\n",
"Moving model to device: cuda\n"
]
@@ -277,7 +265,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "325b8468e6e149afb313208d5a4d2744",
+ "model_id": "9ef4cd62eb5a45c2b971bd652aa40479",
"version_major": 2,
"version_minor": 0
},
@@ -291,22 +279,22 @@
{
"data": {
"text/plain": [
- "'The capital of Germany is Berlin. Berlin is the largest city in Germany and is known for its rich history, cultural attractions'"
+ "'The capital of Germany is Berlin. Berlin is the largest city in Germany and is known for its rich history, cultural attractions'"
]
},
- "execution_count": 7,
+ "execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "# Loading on CPU is cheapest memory wise in transformer_lens \n",
- "if MODE == \"LLaMA\":\n",
- " model = HookedTransformer.from_pretrained(\"llama-7b\", hf_model=hf_model, device=\"cpu\", fold_ln=False, center_writing_weights=False, center_unembed=False, tokenizer=tokenizer)\n",
+ "LLAMA_2_7B_CHAT_PATH = \"meta-llama/Llama-2-7b-chat-hf\"\n",
+ "\n",
+ "tokenizer = LlamaTokenizer.from_pretrained(LLAMA_2_7B_CHAT_PATH)\n",
+ "hf_model = LlamaForCausalLM.from_pretrained(LLAMA_2_7B_CHAT_PATH, low_cpu_mem_usage=True)\n",
+ "\n",
+ "model = HookedTransformer.from_pretrained(LLAMA_2_7B_CHAT_PATH, device=\"cpu\", fold_ln=False, center_writing_weights=False, center_unembed=False)\n",
"\n",
- "elif MODE == \"Llama-2\":\n",
- " model = HookedTransformer.from_pretrained(\"meta-llama/Llama-2-7b-chat-hf\", hf_model=hf_model, device=\"cpu\", fold_ln=False, center_writing_weights=False, center_unembed=False, tokenizer=tokenizer)\n",
- " \n",
"model = model.to(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"model.generate(\"The capital of Germany is\", max_new_tokens=20, temperature=0)"
]
@@ -321,22 +309,15 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- " 0%| | 0/4 [00:00, ?it/s]"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|āāāāāāāāāā| 4/4 [00:00<00:00, 9.85it/s]\n",
- "100%|āāāāāāāāāā| 4/4 [01:11<00:00, 17.86s/it]\n"
+ "100%|āāāāāāāāāā| 4/4 [00:00<00:00, 9.56it/s]\n",
+ "100%|āāāāāāāāāā| 4/4 [00:22<00:00, 5.70s/it]\n"
]
}
],
@@ -357,7 +338,7 @@
"logits = [hf_model(prompt_ids).logits.detach().cpu() for prompt_ids in tqdm(prompt_ids)]\n",
"\n",
"for i in range(len(prompts)): \n",
- " assert torch.allclose(logits[i], tl_logits[i], atol=1e-2, rtol=1e-2) # Llama-2 doesn't seem to pass at 1e-3 anymore"
+ " assert torch.allclose(logits[i], tl_logits[i], atol=1e-4, rtol=1e-2)"
]
},
{
@@ -378,7 +359,7 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 7,
"metadata": {},
"outputs": [
{
@@ -391,23 +372,22 @@
{
"data": {
"text/html": [
- "\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 11,
"metadata": {},
- "output_type": "execute_result"
+ "output_type": "display_data"
}
],
"source": [
@@ -419,7 +399,7 @@
"llama_str_tokens = model.to_str_tokens(llama_text)\n",
"\n",
"print(\"Layer 0 Head Attention Patterns:\")\n",
- "cv.attention.attention_patterns(tokens=llama_str_tokens, attention=attention_pattern)"
+ "display(cv.attention.attention_patterns(tokens=llama_str_tokens, attention=attention_pattern))"
]
},
{
@@ -432,7 +412,7 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 8,
"metadata": {},
"outputs": [
{
@@ -440,8 +420,8 @@
"output_type": "stream",
"text": [
"Shape of the value tensor: torch.Size([1, 34, 32, 128])\n",
- "Original Loss: 2.933\n",
- "Ablated Loss: 2.881\n"
+ "Original Loss: 2.931\n",
+ "Ablated Loss: 2.879\n"
]
}
],
@@ -490,7 +470,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.8"
+ "version": "3.10.13"
},
"orig_nbformat": 4,
"vscode": {
diff --git a/demos/LLaMA2_GPU_quantized.ipynb b/demos/LLaMA2_GPU_quantized.ipynb
new file mode 100644
index 000000000..58631a21e
--- /dev/null
+++ b/demos/LLaMA2_GPU_quantized.ipynb
@@ -0,0 +1,4806 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "EyASOtpeCUsO"
+ },
+ "source": [
+ "# LLaMA and Llama-2 in TransformerLens"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "QnUOM0-RCUsO"
+ },
+ "source": [
+ "## Setup (skip)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "HssVtL08CUsP",
+ "outputId": "5ad91c32-95e8-4970-99ec-242f9e2ebab2"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (0.1.99)\n"
+ ]
+ }
+ ],
+ "source": [
+ "%pip install transformers>=4.31.0 # Llama requires transformers>=4.31.0 and transformers in turn requires Python 3.8\n",
+ "%pip install sentencepiece # Llama tokenizer requires sentencepiece"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "DawCWbiaCUsR",
+ "outputId": "3f527879-cbd3-42b5-8e72-ba70dc906d79"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Running as a Colab notebook\n",
+ "Collecting git+https://github.com/coolvision/TransformerLens.git@llama_4bit_v2\n",
+ " Cloning https://github.com/coolvision/TransformerLens.git (to revision llama_4bit_v2) to /tmp/pip-req-build-lpt2rmoh\n",
+ " Running command git clone --filter=blob:none --quiet https://github.com/coolvision/TransformerLens.git /tmp/pip-req-build-lpt2rmoh\n",
+ " Running command git checkout -b llama_4bit_v2 --track origin/llama_4bit_v2\n",
+ " Switched to a new branch 'llama_4bit_v2'\n",
+ " Branch 'llama_4bit_v2' set up to track remote branch 'llama_4bit_v2' from 'origin'.\n",
+ " Resolved https://github.com/coolvision/TransformerLens.git to commit b2b80cb92f4aa6d63a456196f0c3472b3d34c6eb\n",
+ " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
+ " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
+ " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ "Requirement already satisfied: accelerate>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (0.26.1)\n",
+ "Requirement already satisfied: beartype<0.15.0,>=0.14.1 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (0.14.1)\n",
+ "Requirement already satisfied: datasets>=2.7.1 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (2.16.1)\n",
+ "Requirement already satisfied: einops>=0.6.0 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (0.7.0)\n",
+ "Requirement already satisfied: fancy-einsum>=0.0.3 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (0.0.3)\n",
+ "Requirement already satisfied: jaxtyping>=0.2.11 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (0.2.25)\n",
+ "Requirement already satisfied: numpy>=1.24 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (1.26.3)\n",
+ "Requirement already satisfied: pandas>=1.1.5 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (1.5.3)\n",
+ "Requirement already satisfied: rich>=12.6.0 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (13.7.0)\n",
+ "Requirement already satisfied: torch!=2.0,!=2.1.0,>=1.10 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (2.1.2)\n",
+ "Requirement already satisfied: tqdm>=4.64.1 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (4.66.1)\n",
+ "Requirement already satisfied: transformers>=4.25.1 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (4.35.2)\n",
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (4.5.0)\n",
+ "Requirement already satisfied: wandb>=0.13.5 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (0.16.2)\n",
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.23.0->transformer-lens==0.0.0) (23.2)\n",
+ "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.23.0->transformer-lens==0.0.0) (5.9.5)\n",
+ "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.23.0->transformer-lens==0.0.0) (6.0.1)\n",
+ "Requirement already satisfied: huggingface-hub in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.23.0->transformer-lens==0.0.0) (0.20.2)\n",
+ "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.23.0->transformer-lens==0.0.0) (0.4.1)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (3.13.1)\n",
+ "Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (10.0.1)\n",
+ "Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (0.6)\n",
+ "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (0.3.7)\n",
+ "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (2.31.0)\n",
+ "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (3.4.1)\n",
+ "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (0.70.15)\n",
+ "Requirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (2023.6.0)\n",
+ "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (3.9.1)\n",
+ "Requirement already satisfied: typeguard<3,>=2.13.3 in /usr/local/lib/python3.10/dist-packages (from jaxtyping>=0.2.11->transformer-lens==0.0.0) (2.13.3)\n",
+ "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.1.5->transformer-lens==0.0.0) (2.8.2)\n",
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.1.5->transformer-lens==0.0.0) (2023.3.post1)\n",
+ "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich>=12.6.0->transformer-lens==0.0.0) (3.0.0)\n",
+ "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich>=12.6.0->transformer-lens==0.0.0) (2.16.1)\n",
+ "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (1.12)\n",
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (3.2.1)\n",
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (3.1.3)\n",
+ "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (8.9.2.26)\n",
+ "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (12.1.3.1)\n",
+ "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (11.0.2.54)\n",
+ "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (10.3.2.106)\n",
+ "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (11.4.5.107)\n",
+ "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (12.1.0.106)\n",
+ "Requirement already satisfied: nvidia-nccl-cu12==2.18.1 in /usr/local/lib/python3.10/dist-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (2.18.1)\n",
+ "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (12.1.105)\n",
+ "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (2.1.0)\n",
+ "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (12.3.101)\n",
+ "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.25.1->transformer-lens==0.0.0) (2023.6.3)\n",
+ "Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.25.1->transformer-lens==0.0.0) (0.15.0)\n",
+ "Requirement already satisfied: Click!=8.0.0,>=7.1 in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (8.1.7)\n",
+ "Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (3.1.41)\n",
+ "Requirement already satisfied: sentry-sdk>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (1.39.2)\n",
+ "Requirement already satisfied: docker-pycreds>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (0.4.0)\n",
+ "Requirement already satisfied: setproctitle in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (1.3.3)\n",
+ "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (67.7.2)\n",
+ "Requirement already satisfied: appdirs>=1.4.3 in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (1.4.4)\n",
+ "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (3.20.3)\n",
+ "Requirement already satisfied: six>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from docker-pycreds>=0.4.0->wandb>=0.13.5->transformer-lens==0.0.0) (1.16.0)\n",
+ "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer-lens==0.0.0) (23.2.0)\n",
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer-lens==0.0.0) (6.0.4)\n",
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer-lens==0.0.0) (1.9.4)\n",
+ "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer-lens==0.0.0) (1.4.1)\n",
+ "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer-lens==0.0.0) (1.3.1)\n",
+ "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer-lens==0.0.0) (4.0.3)\n",
+ "Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.10/dist-packages (from GitPython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer-lens==0.0.0) (4.0.11)\n",
+ "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich>=12.6.0->transformer-lens==0.0.0) (0.1.2)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets>=2.7.1->transformer-lens==0.0.0) (3.3.2)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets>=2.7.1->transformer-lens==0.0.0) (3.6)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets>=2.7.1->transformer-lens==0.0.0) (2.0.7)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets>=2.7.1->transformer-lens==0.0.0) (2023.11.17)\n",
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (2.1.3)\n",
+ "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (1.3.0)\n",
+ "Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.10/dist-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer-lens==0.0.0) (5.0.1)\n",
+ "Requirement already satisfied: circuitsvis in /usr/local/lib/python3.10/dist-packages (1.43.2)\n",
+ "Requirement already satisfied: importlib-metadata>=5.1.0 in /usr/local/lib/python3.10/dist-packages (from circuitsvis) (7.0.1)\n",
+ "Requirement already satisfied: numpy>=1.24 in /usr/local/lib/python3.10/dist-packages (from circuitsvis) (1.26.3)\n",
+ "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from circuitsvis) (12.1.3.1)\n",
+ "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from circuitsvis) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from circuitsvis) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from circuitsvis) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from circuitsvis) (8.9.2.26)\n",
+ "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from circuitsvis) (11.0.2.54)\n",
+ "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from circuitsvis) (10.3.2.106)\n",
+ "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from circuitsvis) (11.4.5.107)\n",
+ "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from circuitsvis) (12.1.0.106)\n",
+ "Requirement already satisfied: nvidia-nccl-cu12==2.18.1 in /usr/local/lib/python3.10/dist-packages (from circuitsvis) (2.18.1)\n",
+ "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from circuitsvis) (12.1.105)\n",
+ "Requirement already satisfied: torch>=1.10 in /usr/local/lib/python3.10/dist-packages (from circuitsvis) (2.1.2)\n",
+ "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from circuitsvis) (2.1.0)\n",
+ "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->circuitsvis) (12.3.101)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from triton==2.1.0->circuitsvis) (3.13.1)\n",
+ "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.10/dist-packages (from importlib-metadata>=5.1.0->circuitsvis) (3.17.0)\n",
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->circuitsvis) (4.5.0)\n",
+ "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->circuitsvis) (1.12)\n",
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->circuitsvis) (3.2.1)\n",
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->circuitsvis) (3.1.3)\n",
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->circuitsvis) (2023.6.0)\n",
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10->circuitsvis) (2.1.3)\n",
+ "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10->circuitsvis) (1.3.0)\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Janky code to do different setup when run in a Colab notebook vs VSCode\n",
+ "DEVELOPMENT_MODE = False\n",
+ "IN_VSCODE = False\n",
+ "try:\n",
+ " import google.colab\n",
+ " IN_COLAB = True\n",
+ " print(\"Running as a Colab notebook\")\n",
+ " # %pip install git+https://github.com/neelnanda-io/TransformerLens.git``\n",
+ " %pip install git+https://github.com/coolvision/TransformerLens.git@llama_4bit_v2``\n",
+ " %pip install circuitsvis\n",
+ "\n",
+ " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n",
+ " # # Install another version of node that makes PySvelte work way faster\n",
+ " # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n",
+ " # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n",
+ "except:\n",
+ " IN_COLAB = False\n",
+ " print(\"Running as a Jupyter notebook - intended for development only!\")\n",
+ " from IPython import get_ipython\n",
+ "\n",
+ " ipython = get_ipython()\n",
+ " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n",
+ " ipython.magic(\"load_ext autoreload\")\n",
+ " ipython.magic(\"autoreload 2\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "_OukKSMfCUsR",
+ "outputId": "27a2a59f-e635-4b80-c759-f00d542352bd"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Using renderer: colab\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n",
+ "import plotly.io as pio\n",
+ "if IN_COLAB or not DEVELOPMENT_MODE:\n",
+ " pio.renderers.default = \"colab\"\n",
+ "elif IN_VSCODE:\n",
+ " pio.renderers.default = \"notebook_connected\"\n",
+ "print(f\"Using renderer: {pio.renderers.default}\")\n",
+ "\n",
+ "import circuitsvis as cv"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "id": "P8zS3MPkCUsR"
+ },
+ "outputs": [],
+ "source": [
+ "# Import stuff\n",
+ "import torch\n",
+ "import tqdm.auto as tqdm\n",
+ "import plotly.express as px\n",
+ "\n",
+ "from transformers import LlamaForCausalLM, LlamaTokenizer\n",
+ "from tqdm import tqdm\n",
+ "from jaxtyping import Float\n",
+ "\n",
+ "import transformer_lens\n",
+ "import transformer_lens.utils as utils\n",
+ "from transformer_lens.hook_points import (\n",
+ " HookPoint,\n",
+ ") # Hooking utilities\n",
+ "from transformer_lens import HookedTransformer\n",
+ "\n",
+ "torch.set_grad_enabled(False)\n",
+ "\n",
+ "def imshow(tensor, renderer=None, xaxis=\"\", yaxis=\"\", **kwargs):\n",
+ " px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale=\"RdBu\", labels={\"x\":xaxis, \"y\":yaxis}, **kwargs).show(renderer)\n",
+ "\n",
+ "def line(tensor, renderer=None, xaxis=\"\", yaxis=\"\", **kwargs):\n",
+ " px.line(utils.to_numpy(tensor), labels={\"x\":xaxis, \"y\":yaxis}, **kwargs).show(renderer)\n",
+ "\n",
+ "def scatter(x, y, xaxis=\"\", yaxis=\"\", caxis=\"\", renderer=None, **kwargs):\n",
+ " x = utils.to_numpy(x)\n",
+ " y = utils.to_numpy(y)\n",
+ " px.scatter(y=y, x=x, labels={\"x\":xaxis, \"y\":yaxis, \"color\":caxis}, **kwargs).show(renderer)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "iXCfIfKKCUsS",
+ "jp-MarkdownHeadingCollapsed": true
+ },
+ "source": [
+ "## Loading LLaMA"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "QH3kyhzFCUsS"
+ },
+ "source": [
+ "LLaMA weights are not available on HuggingFace, so you'll need to download and convert them\n",
+ "manually:\n",
+ "\n",
+ "1. Get LLaMA weights here: https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform\n",
+ "\n",
+ "2. Convert the official weights to huggingface:\n",
+ "\n",
+ "```bash\n",
+ "python src/transformers/models/llama/convert_llama_weights_to_hf.py \\\n",
+ " --input_dir /path/to/downloaded/llama/weights \\\n",
+ " --model_size 7B \\\n",
+ " --output_dir /llama/weights/directory/\n",
+ "```\n",
+ "\n",
+ "Note: this didn't work for Arthur by default (even though HF doesn't seem to show this anywhere). I\n",
+ "had to change this\n",
+ "line of my pip installed `src/transformers/models/llama/convert_llama_weights_to_hf.py` file (which\n",
+ "was found at\n",
+ "`/opt/conda/envs/arthurenv/lib/python3.10/site-packages/transformers/models/llama/convert_llama_weights_to_hf.py`)\n",
+ "from `input_base_path=os.path.join(args.input_dir, args.model_size),` to `input_base_path=os.path.join(args.input_dir),`\n",
+ "\n",
+ "3. Change the ```MODEL_PATH``` variable in the cell below to where the converted weights are stored."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "id": "RdJ0AuW_CUsS"
+ },
+ "outputs": [],
+ "source": [
+ "# MODEL_PATH=''\n",
+ "\n",
+ "# tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)\n",
+ "# hf_model = LlamaForCausalLM.from_pretrained(MODEL_PATH, low_cpu_mem_usage=True)\n",
+ "\n",
+ "# model = HookedTransformer.from_pretrained(\"llama-7b\", hf_model=hf_model, device=\"cpu\", fold_ln=False, center_writing_weights=False, center_unembed=False, tokenizer=tokenizer)\n",
+ "\n",
+ "# model = model.to(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ "# model.generate(\"The capital of Germany is\", max_new_tokens=20, temperature=0)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "UmOqXE9wCUsS"
+ },
+ "source": [
+ "## Loading LLaMA-2\n",
+ "LLaMA-2 is hosted on HuggingFace, but gated by login.\n",
+ "\n",
+ "Before running the notebook, log in to HuggingFace via the cli on your machine:\n",
+ "```bash\n",
+ "transformers-cli login\n",
+ "```\n",
+ "This will cache your HuggingFace credentials, and enable you to download LLaMA-2."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "KH6evHq1GGQi"
+ },
+ "source": [
 + "## Install additional dependencies required for quantization"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "n26wTL_3GYAO",
+ "outputId": "da381126-148a-43f8-8506-1990d39317f6"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Requirement already satisfied: bitsandbytes in /usr/local/lib/python3.10/dist-packages (0.42.0)\n",
+ "Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from bitsandbytes) (1.11.4)\n",
+ "Requirement already satisfied: numpy<1.28.0,>=1.21.6 in /usr/local/lib/python3.10/dist-packages (from scipy->bitsandbytes) (1.26.3)\n",
+ "Requirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages (0.26.1)\n",
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from accelerate) (1.26.3)\n",
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (23.2)\n",
+ "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.5)\n",
+ "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from accelerate) (6.0.1)\n",
+ "Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (2.1.2)\n",
+ "Requirement already satisfied: huggingface-hub in /usr/local/lib/python3.10/dist-packages (from accelerate) (0.20.2)\n",
+ "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from accelerate) (0.4.1)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.13.1)\n",
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (4.5.0)\n",
+ "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (1.12)\n",
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.2.1)\n",
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1.3)\n",
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (2023.6.0)\n",
+ "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (8.9.2.26)\n",
+ "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (12.1.3.1)\n",
+ "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (11.0.2.54)\n",
+ "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (10.3.2.106)\n",
+ "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (11.4.5.107)\n",
+ "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (12.1.0.106)\n",
+ "Requirement already satisfied: nvidia-nccl-cu12==2.18.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (2.18.1)\n",
+ "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (12.1.105)\n",
+ "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (2.1.0)\n",
+ "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.10.0->accelerate) (12.3.101)\n",
+ "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface-hub->accelerate) (2.31.0)\n",
+ "Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub->accelerate) (4.66.1)\n",
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.3)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (3.3.2)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (3.6)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (2.0.7)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (2023.11.17)\n",
+ "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10.0->accelerate) (1.3.0)\n"
+ ]
+ }
+ ],
+ "source": [
+ "%pip install bitsandbytes\n",
+ "%pip install accelerate"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "iOntzU3lGZA6"
+ },
+ "source": [
+ "## Load quantized model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 406,
+ "referenced_widgets": [
+ "27ae531ccf1848b79a636a731bc635b9",
+ "cc67c0d024c848898c5298691db43ed6",
+ "0e3321d5dcc74ce09baeadce8bd5d6e1",
+ "4b3dc4ad2d2f4ac59612b6392a6a77f7",
+ "cfb7c854917e4e3abca3158c0554c081",
+ "fc7ac85b780a4605809bf68f006e1ef3",
+ "2582ee70618844ff8168b8637ddbbae4",
+ "e6b8c1942a54415f998e1f44e796ce1b",
+ "2c9f2db0723d49fbb902d658703cdcb6",
+ "37618acd36d44f68b0dbe07cedd60111",
+ "a31c3597f0a546b0b47bbb873b7c3848",
+ "1b918c0947a54ae89af225ec372351bf",
+ "49502d07f7b143e784f4bf76ed6cc272",
+ "9a3264f31b044a73be31ca4fba3e7f1d",
+ "0eaef657be2742b1867943854370d1ed",
+ "e31ca54d82f64114ab694ce50ba62987",
+ "96e00b8ba2e640d79cd80a1c401dcedd",
+ "537df88aa48b4ae883ec9f1b86a82d11",
+ "55f5712956834686ba2b19a36438b726",
+ "f4a21799b8f84e4687e36314cfc873f0",
+ "143ecc2bee0c42489185fbbe8c785578",
+ "fb16992539ed404383974588a5c9263a",
+ "4e17ea3af1b94b52944a816aba2525ba",
+ "186ff73f09504161b24be544f4707858",
+ "24670f0283144c7f97d2ba4733bda25d",
+ "70a8f83e1235458ab87e69a1c05780a4",
+ "c09f3c527dce4bed8f6d72922218317e",
+ "e850e6dcd47f4da2842e5e2b8c8a43c5",
+ "744091704d0347d6ad0a6032ee362257",
+ "843f9fad9fea4db88e4201c4aad858b9",
+ "c0173c0a1f334f6fb3c1e8e0163da260",
+ "f733175715f640668ae119c03542fcb5",
+ "09d337ea6f0945c79bad129283c85f0b",
+ "aec34d7e6c4444778daa42043fb595f2",
+ "99e864d7de4d461fbb8ccc3f4f765645",
+ "368b7e51e67749cc8952784c042e3410",
+ "4ff270b1c5dd484996ce971ed7014317",
+ "5f9f9e24d6834de3a058a63fecd3a19f",
+ "6c8c9155c1f04f0a97cd996e741446c0",
+ "fb3f514d931e4c8192f7cd3557ece592",
+ "a8d929695aee466fbdc0324dab55aa15",
+ "d27922254bcb4c30aebff5da1ee34401",
+ "838dfdb7cec44c86bf41608c17f9852b",
+ "1720ff23ff344e498c6f80a98bbd2534",
+ "0f43034866264e019a364912d63ec11e",
+ "7f0d44e575034332bc737cf702a79468",
+ "201941842a4044148b6fc36e18fe3187",
+ "0b83c8135ef44951964738be07783832",
+ "d63f72800b854c7c980ed08f77dd61bf",
+ "18f4b44f74ff4b3c9198a1ecc8d672a8",
+ "1d2bfdbba1dd4e3db05a1a0ce052fad7",
+ "236fbd6f27c14dcba6ec274fe34852ca",
+ "96bc194611d2414c979340ad2b6416bf",
+ "2255bd5f2c6841bd9af4bc8963945936",
+ "687a8ab9cfcc45619d453fef2c8d04ea",
+ "739dde14dc7e4e46aff9ee8858f0f334",
+ "6624e2d4d0a04e858c08a709f6bcf31d",
+ "fdb572e99dde402199a43ac369891538",
+ "cf227ad8761f4f80a7884ec915add0f4",
+ "3951d93995c345f3a001ec7fb9a63df1",
+ "bbad5c1e5c4d46a49e500d5b3890ad0b",
+ "cca945b5469b41549da4b2bd369c83ff",
+ "2b903ea4e2934e4ca655296945bea0cd",
+ "bef5c99c8ba3495b85454d8f10ff17bd",
+ "0604d0239806455889c09137ebee2815",
+ "9a3e9cc371da4801ab642ea09101a40f",
+ "b7daa28812e54fd6941f33d1b8325666",
+ "f660c1ba3a664866a5abcdd3c35d0e72",
+ "707287071f58451b8e5da6688cba286b",
+ "21f81e0c00e94b5d82039aca95e44bfb",
+ "71a49a876cc1413fad925c15672d5919",
+ "1174909ea64c4afa9ba244564eeafaf5",
+ "a76b6b9e474a4cf79c98b52758adeab5",
+ "4abc540a7c404642b7a988937efdd196",
+ "d98b1abde30a4dd28468c7c12d76c822",
+ "ce270b54645e4f1d87a0a10416739c1f",
+ "056409f93e7f48f9a9b8556ed27fcd08",
+ "d5bb62948e5545319030671e17be5ced",
+ "e354464414fd479388da58d033700024",
+ "07b6cec0de3040efa472867d07cd0495",
+ "2292ea67682e44bc8c61ce31bd18371c",
+ "a6182b17067c431fb933d36cef4d4438",
+ "eb69099d7fd9416881bdc9635838da1b",
+ "6f9ff8cbf6b3427ab3d30aa19b7703db",
+ "162b1719fad54fcabe0c9b0646956985",
+ "e2a836342061410c86827b364d3feb29",
+ "9d350966e16c409eb32e0642ee908f24",
+ "cce9b0e3791e442a9e7b9b87a3ffec41",
+ "b525a8f8d07b4abcb162fcca0bfab28d",
+ "e86dea79f86143cab0b88c2d5f8992e8",
+ "9d0f7c1505c9436ca8c04724381dd70f",
+ "dd5b7da397c849adabcf33c2b7c3aeb8",
+ "1be40306cb6844319d44ba1f4f0164cb",
+ "7b2d516d0c5046dc816801790aa4c2b5",
+ "9e1c77e44d884d4d80684179c6f4c96c",
+ "bb637c634d6f4efebafa3aa503c9862e",
+ "8b28096911874bae9798969414e385de",
+ "1d118df455dd4d6aa395df422c711573",
+ "e990b322c8104b329227207e442003ee",
+ "22c5623f724647689c50a0e7c37cb371",
+ "fceb5194627442d79ee5ecd91ee01a10",
+ "76516b2b4e104548a4b892f3979f41c8",
+ "54ffa55dfd154b8a83edf45c157bec11",
+ "042a11af13fe4a1ab7c2a3beef4ee1ff",
+ "51664a883bcb4d52b55e8826f17726bb",
+ "951f41fc154c4f8aa38124c11e481223",
+ "96f61db2adf54c45bf60630919a67c95",
+ "a055f35c82bb43cbb2ff3f8a98fa7a90",
+ "da72027fa75c4687bee46ab832297882",
+ "251d31dade3348acaec1eb247ebb33ab",
+ "d6603805bb654ea9a7e0ae2b60dd5be6",
+ "76e5f131438d4086bfcf52f7f4969b3a",
+ "19f9d3421cef4c38abb500910e251f22",
+ "c3520f624fd043bfaac26503ff10f254",
+ "f4e419873d6f49a99b8718666811073e",
+ "6485458203c946bba61f5ec96959abfb",
+ "c4bcdc27cad7466e8bf0f8655373a71b",
+ "c92a8c987de548239ec78c93ee5bf660",
+ "1dc24cb487dc4b919bbd3c7e6a89e150",
+ "1c861f3ec5214c148a714f04bdd9696c",
+ "92f27d77468c407aae1208893996a2a7"
+ ]
+ },
+ "id": "urpZu9jECUsT",
+ "outputId": "d4e9217b-a099-4148-89ba-cd790dcde7e5"
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "27ae531ccf1848b79a636a731bc635b9",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "model.safetensors.index.json: 0%| | 0.00/26.8k [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "1b918c0947a54ae89af225ec372351bf",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading shards: 0%| | 0/2 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "4e17ea3af1b94b52944a816aba2525ba",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "model-00001-of-00002.safetensors: 0%| | 0.00/9.98G [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "aec34d7e6c4444778daa42043fb595f2",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "model-00002-of-00002.safetensors: 0%| | 0.00/3.50G [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "0f43034866264e019a364912d63ec11e",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "739dde14dc7e4e46aff9ee8858f0f334",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "generation_config.json: 0%| | 0.00/188 [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "b7daa28812e54fd6941f33d1b8325666",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "tokenizer_config.json: 0%| | 0.00/1.62k [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "d5bb62948e5545319030671e17be5ced",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "tokenizer.model: 0%| | 0.00/500k [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "b525a8f8d07b4abcb162fcca0bfab28d",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "tokenizer.json: 0%| | 0.00/1.84M [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "22c5623f724647689c50a0e7c37cb371",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "special_tokens_map.json: 0%| | 0.00/414 [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Loaded pretrained model meta-llama/Llama-2-7b-chat-hf into HookedTransformer\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "d6603805bb654ea9a7e0ae2b60dd5be6",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/2 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.google.colaboratory.intrinsic+json": {
+ "type": "string"
+ },
+ "text/plain": [
+ "'The capital of Germany is Berlin.'"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "\n",
+ "from transformers import AutoModelForCausalLM\n",
+ "from transformers import AutoTokenizer\n",
+ "\n",
+ "LLAMA_2_7B_CHAT_PATH = \"meta-llama/Llama-2-7b-chat-hf\"\n",
+ "inference_dtype = torch.float32\n",
+ "# inference_dtype = torch.float32\n",
+ "# inference_dtype = torch.float16\n",
+ "\n",
+ "hf_model = AutoModelForCausalLM.from_pretrained(LLAMA_2_7B_CHAT_PATH,\n",
+ " torch_dtype=inference_dtype,\n",
+ " device_map = \"cuda:0\",\n",
+ " load_in_4bit=True)\n",
+ "\n",
+ "tokenizer = AutoTokenizer.from_pretrained(LLAMA_2_7B_CHAT_PATH)\n",
+ "\n",
+ "model = HookedTransformer.from_pretrained(LLAMA_2_7B_CHAT_PATH,\n",
+ " hf_model=hf_model,\n",
+ " dtype=inference_dtype,\n",
+ " fold_ln=False,\n",
+ " fold_value_biases=False,\n",
+ " center_writing_weights=False,\n",
+ " center_unembed=False,\n",
+ " tokenizer=tokenizer)\n",
+ "\n",
+ "model.generate(\"The capital of Germany is\", max_new_tokens=2, temperature=0)\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "qpe_ZL3FCUsT"
+ },
+ "source": [
+ "### Verify GPU memory use"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "iEgCpdJ0CUsT",
+ "outputId": "2214d6ec-3384-4aea-f453-b61e7d5220bf"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "free(Gb): 9.29988608 total(Gb): 15.835660288\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"free(Gb):\", torch.cuda.mem_get_info()[0]/1000000000, \"total(Gb):\", torch.cuda.mem_get_info()[1]/1000000000)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "t_ZUUccfCUsT"
+ },
+ "source": [
+ "### Compare logits with HuggingFace model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "qm_Ng6SiCUsT",
+ "outputId": "9b2461e4-9e25-4457-efe0-7597bb7fe9ae"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|āāāāāāāāāā| 4/4 [00:02<00:00, 1.67it/s]\n",
+ "100%|āāāāāāāāāā| 4/4 [00:02<00:00, 1.78it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "logits[i] 0 torch.float32 tensor([[[ 2.6141e-01, -1.3572e+00, -2.5338e-03, ..., 1.5789e+00,\n",
+ " 1.7533e+00, 6.7760e-01],\n",
+ " [-9.6479e+00, -5.7782e+00, -4.9753e+00, ..., -6.4692e+00,\n",
+ " -6.1699e+00, -6.2673e+00],\n",
+ " [-7.2656e+00, -5.5785e+00, 1.7281e+00, ..., -1.5952e+00,\n",
+ " -6.6061e+00, -3.6781e+00],\n",
+ " [-1.4183e+00, -1.1532e-01, 3.3884e+00, ..., -2.7672e+00,\n",
+ " -1.5437e+00, -2.0235e+00],\n",
+ " [-1.5692e-01, -4.4547e-02, 1.2237e+01, ..., 4.0257e-01,\n",
+ " -5.5462e-01, 6.4535e-01],\n",
+ " [-4.0320e+00, -3.9415e+00, 7.8710e+00, ..., -1.0899e+00,\n",
+ " -3.5164e+00, -1.2452e+00]]])\n",
+ "tl_logits[i] 0 torch.float32 tensor([[[ 2.6141e-01, -1.3572e+00, -2.5325e-03, ..., 1.5789e+00,\n",
+ " 1.7533e+00, 6.7759e-01],\n",
+ " [-9.6479e+00, -5.7782e+00, -4.9753e+00, ..., -6.4692e+00,\n",
+ " -6.1700e+00, -6.2673e+00],\n",
+ " [-7.2656e+00, -5.5785e+00, 1.7281e+00, ..., -1.5952e+00,\n",
+ " -6.6061e+00, -3.6781e+00],\n",
+ " [-1.4183e+00, -1.1532e-01, 3.3884e+00, ..., -2.7672e+00,\n",
+ " -1.5437e+00, -2.0235e+00],\n",
+ " [-1.5692e-01, -4.4552e-02, 1.2237e+01, ..., 4.0256e-01,\n",
+ " -5.5462e-01, 6.4535e-01],\n",
+ " [-4.0320e+00, -3.9415e+00, 7.8710e+00, ..., -1.0899e+00,\n",
+ " -3.5164e+00, -1.2452e+00]]])\n"
+ ]
+ }
+ ],
+ "source": [
+ "prompts = [\n",
+ " \"The capital of Germany is\",\n",
+ " \"2 * 42 = \",\n",
+ " \"My favorite\",\n",
+ " \"aosetuhaosuh aostud aoestuaoentsudhasuh aos tasat naostutshaosuhtnaoe usaho uaotsnhuaosntuhaosntu haouaoshat u saotheu saonuh aoesntuhaosut aosu thaosu thaoustaho usaothusaothuao sutao sutaotduaoetudet uaosthuao uaostuaoeu aostouhsaonh aosnthuaoscnuhaoshkbaoesnit haosuhaoe uasotehusntaosn.p.uo ksoentudhao ustahoeuaso usant.hsa otuhaotsi aostuhs\",\n",
+ "]\n",
+ "\n",
+ "model.eval()\n",
+ "hf_model.eval()\n",
+ "prompt_ids = [tokenizer.encode(prompt, return_tensors=\"pt\") for prompt in prompts]\n",
+ "tl_logits = [model(prompt_ids).detach().cpu() for prompt_ids in tqdm(prompt_ids)]\n",
+ "\n",
+ "# hf logits are really slow as it's on CPU. If you have a big/multi-GPU machine, run `hf_model = hf_model.to(\"cuda\")` to speed this up\n",
+ "logits = [hf_model(prompt_ids).logits.detach().cpu() for prompt_ids in tqdm(prompt_ids)]\n",
+ "\n",
+ "for i in range(len(prompts)):\n",
+ " if i == 0:\n",
+ " print(\"logits[i]\", i, logits[i].dtype, logits[i])\n",
+ " print(\"tl_logits[i]\", i, tl_logits[i].dtype, tl_logits[i])\n",
+ " assert torch.allclose(logits[i], tl_logits[i], atol=1e-4, rtol=1e-2)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Jdbw9hClCUsU"
+ },
+ "source": [
+ "## TransformerLens Demo"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "5e0Q1EGeCUsU"
+ },
+ "source": [
+ "### Reading from hooks"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 580
+ },
+ "id": "ymC0G-h7CUsU",
+ "outputId": "a45ce29f-5fc2-403c-f137-0e7c61627ee8"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Layer 0 Head Attention Patterns:\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "llama_text = \"Natural language processing tasks, such as question answering, machine translation, reading comprehension, and summarization, are typically approached with supervised learning on taskspecific datasets.\"\n",
+ "llama_tokens = model.to_tokens(llama_text)\n",
+ "llama_logits, llama_cache = model.run_with_cache(llama_tokens, remove_batch_dim=True)\n",
+ "\n",
+ "attention_pattern = llama_cache[\"pattern\", 0, \"attn\"]\n",
+ "llama_str_tokens = model.to_str_tokens(llama_text)\n",
+ "\n",
+ "print(\"Layer 0 Head Attention Patterns:\")\n",
+ "display(cv.attention.attention_patterns(tokens=llama_str_tokens, attention=attention_pattern))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "0f_t5QZeCUsU"
+ },
+ "source": [
+ "### Writing to hooks"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "YVHW-VKJCUsU",
+ "outputId": "346b37b7-0bfd-4dcf-e21c-fa930dec14c1"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Shape of the value tensor: torch.Size([1, 34, 32, 128])\n",
+ "Original Loss: 2.841\n",
+ "Ablated Loss: 2.806\n"
+ ]
+ }
+ ],
+ "source": [
+ "layer_to_ablate = 0\n",
+ "head_index_to_ablate = 31\n",
+ "\n",
+ "# We define a head ablation hook\n",
+ "# The type annotations are NOT necessary, they're just a useful guide to the reader\n",
+ "#\n",
+ "def head_ablation_hook(\n",
+ " value: Float[torch.Tensor, \"batch pos head_index d_head\"],\n",
+ " hook: HookPoint\n",
+ ") -> Float[torch.Tensor, \"batch pos head_index d_head\"]:\n",
+ " print(f\"Shape of the value tensor: {value.shape}\")\n",
+ " value[:, :, head_index_to_ablate, :] = 0.\n",
+ " return value\n",
+ "\n",
+ "original_loss = model(llama_tokens, return_type=\"loss\")\n",
+ "ablated_loss = model.run_with_hooks(\n",
+ " llama_tokens,\n",
+ " return_type=\"loss\",\n",
+ " fwd_hooks=[(\n",
+ " utils.get_act_name(\"v\", layer_to_ablate),\n",
+ " head_ablation_hook\n",
+ " )]\n",
+ " )\n",
+ "print(f\"Original Loss: {original_loss.item():.3f}\")\n",
+ "print(f\"Ablated Loss: {ablated_loss.item():.3f}\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "T4",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.4"
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "f03ec946e3b5caa7cc710a963f479e62a68fff56c790a7066e03c8b5c22adad9"
+ }
+ },
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "042a11af13fe4a1ab7c2a3beef4ee1ff": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "056409f93e7f48f9a9b8556ed27fcd08": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "0604d0239806455889c09137ebee2815": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "07b6cec0de3040efa472867d07cd0495": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_162b1719fad54fcabe0c9b0646956985",
+ "max": 499723,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_e2a836342061410c86827b364d3feb29",
+ "value": 499723
+ }
+ },
+ "09d337ea6f0945c79bad129283c85f0b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "0b83c8135ef44951964738be07783832": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_2255bd5f2c6841bd9af4bc8963945936",
+ "placeholder": "ā",
+ "style": "IPY_MODEL_687a8ab9cfcc45619d453fef2c8d04ea",
+ "value": " 2/2 [01:25<00:00, 38.63s/it]"
+ }
+ },
+ "0e3321d5dcc74ce09baeadce8bd5d6e1": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_e6b8c1942a54415f998e1f44e796ce1b",
+ "max": 26788,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_2c9f2db0723d49fbb902d658703cdcb6",
+ "value": 26788
+ }
+ },
+ "0eaef657be2742b1867943854370d1ed": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_143ecc2bee0c42489185fbbe8c785578",
+ "placeholder": "ā",
+ "style": "IPY_MODEL_fb16992539ed404383974588a5c9263a",
+ "value": " 2/2 [02:09<00:00, 59.74s/it]"
+ }
+ },
+ "0f43034866264e019a364912d63ec11e": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_7f0d44e575034332bc737cf702a79468",
+ "IPY_MODEL_201941842a4044148b6fc36e18fe3187",
+ "IPY_MODEL_0b83c8135ef44951964738be07783832"
+ ],
+ "layout": "IPY_MODEL_d63f72800b854c7c980ed08f77dd61bf"
+ }
+ },
+ "1174909ea64c4afa9ba244564eeafaf5": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "143ecc2bee0c42489185fbbe8c785578": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "162b1719fad54fcabe0c9b0646956985": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "1720ff23ff344e498c6f80a98bbd2534": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "186ff73f09504161b24be544f4707858": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_e850e6dcd47f4da2842e5e2b8c8a43c5",
+ "placeholder": "ā",
+ "style": "IPY_MODEL_744091704d0347d6ad0a6032ee362257",
+ "value": "model-00001-of-00002.safetensors: 100%"
+ }
+ },
+ "18f4b44f74ff4b3c9198a1ecc8d672a8": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "19f9d3421cef4c38abb500910e251f22": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_c92a8c987de548239ec78c93ee5bf660",
+ "max": 2,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_1dc24cb487dc4b919bbd3c7e6a89e150",
+ "value": 2
+ }
+ },
+ "1b918c0947a54ae89af225ec372351bf": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_49502d07f7b143e784f4bf76ed6cc272",
+ "IPY_MODEL_9a3264f31b044a73be31ca4fba3e7f1d",
+ "IPY_MODEL_0eaef657be2742b1867943854370d1ed"
+ ],
+ "layout": "IPY_MODEL_e31ca54d82f64114ab694ce50ba62987"
+ }
+ },
+ "1be40306cb6844319d44ba1f4f0164cb": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "1c861f3ec5214c148a714f04bdd9696c": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "1d118df455dd4d6aa395df422c711573": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "1d2bfdbba1dd4e3db05a1a0ce052fad7": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "1dc24cb487dc4b919bbd3c7e6a89e150": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "201941842a4044148b6fc36e18fe3187": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_236fbd6f27c14dcba6ec274fe34852ca",
+ "max": 2,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_96bc194611d2414c979340ad2b6416bf",
+ "value": 2
+ }
+ },
+ "21f81e0c00e94b5d82039aca95e44bfb": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_ce270b54645e4f1d87a0a10416739c1f",
+ "placeholder": "ā",
+ "style": "IPY_MODEL_056409f93e7f48f9a9b8556ed27fcd08",
+ "value": " 1.62k/1.62k [00:00<00:00, 98.1kB/s]"
+ }
+ },
+ "2255bd5f2c6841bd9af4bc8963945936": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "2292ea67682e44bc8c61ce31bd18371c": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_9d350966e16c409eb32e0642ee908f24",
+ "placeholder": "ā",
+ "style": "IPY_MODEL_cce9b0e3791e442a9e7b9b87a3ffec41",
+ "value": " 500k/500k [00:00<00:00, 31.8MB/s]"
+ }
+ },
+ "22c5623f724647689c50a0e7c37cb371": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_fceb5194627442d79ee5ecd91ee01a10",
+ "IPY_MODEL_76516b2b4e104548a4b892f3979f41c8",
+ "IPY_MODEL_54ffa55dfd154b8a83edf45c157bec11"
+ ],
+ "layout": "IPY_MODEL_042a11af13fe4a1ab7c2a3beef4ee1ff"
+ }
+ },
+ "236fbd6f27c14dcba6ec274fe34852ca": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "24670f0283144c7f97d2ba4733bda25d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_843f9fad9fea4db88e4201c4aad858b9",
+ "max": 9976576152,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_c0173c0a1f334f6fb3c1e8e0163da260",
+ "value": 9976576152
+ }
+ },
+ "251d31dade3348acaec1eb247ebb33ab": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "2582ee70618844ff8168b8637ddbbae4": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "27ae531ccf1848b79a636a731bc635b9": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_cc67c0d024c848898c5298691db43ed6",
+ "IPY_MODEL_0e3321d5dcc74ce09baeadce8bd5d6e1",
+ "IPY_MODEL_4b3dc4ad2d2f4ac59612b6392a6a77f7"
+ ],
+ "layout": "IPY_MODEL_cfb7c854917e4e3abca3158c0554c081"
+ }
+ },
+ "2b903ea4e2934e4ca655296945bea0cd": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "2c9f2db0723d49fbb902d658703cdcb6": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "368b7e51e67749cc8952784c042e3410": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_a8d929695aee466fbdc0324dab55aa15",
+ "max": 3500296424,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_d27922254bcb4c30aebff5da1ee34401",
+ "value": 3500296424
+ }
+ },
+ "37618acd36d44f68b0dbe07cedd60111": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "3951d93995c345f3a001ec7fb9a63df1": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "49502d07f7b143e784f4bf76ed6cc272": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_96e00b8ba2e640d79cd80a1c401dcedd",
+ "placeholder": "ā",
+ "style": "IPY_MODEL_537df88aa48b4ae883ec9f1b86a82d11",
+ "value": "Downloading shards: 100%"
+ }
+ },
+ "4abc540a7c404642b7a988937efdd196": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "4b3dc4ad2d2f4ac59612b6392a6a77f7": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_37618acd36d44f68b0dbe07cedd60111",
+ "placeholder": "ā",
+ "style": "IPY_MODEL_a31c3597f0a546b0b47bbb873b7c3848",
+ "value": " 26.8k/26.8k [00:00<00:00, 1.55MB/s]"
+ }
+ },
+ "4e17ea3af1b94b52944a816aba2525ba": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_186ff73f09504161b24be544f4707858",
+ "IPY_MODEL_24670f0283144c7f97d2ba4733bda25d",
+ "IPY_MODEL_70a8f83e1235458ab87e69a1c05780a4"
+ ],
+ "layout": "IPY_MODEL_c09f3c527dce4bed8f6d72922218317e"
+ }
+ },
+ "4ff270b1c5dd484996ce971ed7014317": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_838dfdb7cec44c86bf41608c17f9852b",
+ "placeholder": "ā",
+ "style": "IPY_MODEL_1720ff23ff344e498c6f80a98bbd2534",
+ "value": " 3.50G/3.50G [00:36<00:00, 132MB/s]"
+ }
+ },
+ "51664a883bcb4d52b55e8826f17726bb": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "537df88aa48b4ae883ec9f1b86a82d11": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "54ffa55dfd154b8a83edf45c157bec11": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_da72027fa75c4687bee46ab832297882",
+ "placeholder": "ā",
+ "style": "IPY_MODEL_251d31dade3348acaec1eb247ebb33ab",
+ "value": " 414/414 [00:00<00:00, 30.9kB/s]"
+ }
+ },
+ "55f5712956834686ba2b19a36438b726": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "5f9f9e24d6834de3a058a63fecd3a19f": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "6485458203c946bba61f5ec96959abfb": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "6624e2d4d0a04e858c08a709f6bcf31d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_bbad5c1e5c4d46a49e500d5b3890ad0b",
+ "placeholder": "ā",
+ "style": "IPY_MODEL_cca945b5469b41549da4b2bd369c83ff",
+ "value": "generation_config.json: 100%"
+ }
+ },
+ "687a8ab9cfcc45619d453fef2c8d04ea": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "6c8c9155c1f04f0a97cd996e741446c0": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "6f9ff8cbf6b3427ab3d30aa19b7703db": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "707287071f58451b8e5da6688cba286b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_4abc540a7c404642b7a988937efdd196",
+ "max": 1618,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_d98b1abde30a4dd28468c7c12d76c822",
+ "value": 1618
+ }
+ },
+ "70a8f83e1235458ab87e69a1c05780a4": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_f733175715f640668ae119c03542fcb5",
+ "placeholder": "ā",
+ "style": "IPY_MODEL_09d337ea6f0945c79bad129283c85f0b",
+ "value": " 9.98G/9.98G [01:33<00:00, 167MB/s]"
+ }
+ },
+ "71a49a876cc1413fad925c15672d5919": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "739dde14dc7e4e46aff9ee8858f0f334": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_6624e2d4d0a04e858c08a709f6bcf31d",
+ "IPY_MODEL_fdb572e99dde402199a43ac369891538",
+ "IPY_MODEL_cf227ad8761f4f80a7884ec915add0f4"
+ ],
+ "layout": "IPY_MODEL_3951d93995c345f3a001ec7fb9a63df1"
+ }
+ },
+ "744091704d0347d6ad0a6032ee362257": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "76516b2b4e104548a4b892f3979f41c8": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_96f61db2adf54c45bf60630919a67c95",
+ "max": 414,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_a055f35c82bb43cbb2ff3f8a98fa7a90",
+ "value": 414
+ }
+ },
+ "76e5f131438d4086bfcf52f7f4969b3a": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_6485458203c946bba61f5ec96959abfb",
+ "placeholder": "ā",
+ "style": "IPY_MODEL_c4bcdc27cad7466e8bf0f8655373a71b",
+ "value": "100%"
+ }
+ },
+ "7b2d516d0c5046dc816801790aa4c2b5": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "7f0d44e575034332bc737cf702a79468": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_18f4b44f74ff4b3c9198a1ecc8d672a8",
+ "placeholder": "ā",
+ "style": "IPY_MODEL_1d2bfdbba1dd4e3db05a1a0ce052fad7",
+ "value": "Loading checkpoint shards: 100%"
+ }
+ },
+ "838dfdb7cec44c86bf41608c17f9852b": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "843f9fad9fea4db88e4201c4aad858b9": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "8b28096911874bae9798969414e385de": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "92f27d77468c407aae1208893996a2a7": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "951f41fc154c4f8aa38124c11e481223": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "96bc194611d2414c979340ad2b6416bf": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "96e00b8ba2e640d79cd80a1c401dcedd": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "96f61db2adf54c45bf60630919a67c95": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "99e864d7de4d461fbb8ccc3f4f765645": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_6c8c9155c1f04f0a97cd996e741446c0",
+ "placeholder": "ā",
+ "style": "IPY_MODEL_fb3f514d931e4c8192f7cd3557ece592",
+ "value": "model-00002-of-00002.safetensors: 100%"
+ }
+ },
+ "9a3264f31b044a73be31ca4fba3e7f1d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_55f5712956834686ba2b19a36438b726",
+ "max": 2,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_f4a21799b8f84e4687e36314cfc873f0",
+ "value": 2
+ }
+ },
+ "9a3e9cc371da4801ab642ea09101a40f": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "9d0f7c1505c9436ca8c04724381dd70f": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_bb637c634d6f4efebafa3aa503c9862e",
+ "max": 1842767,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_8b28096911874bae9798969414e385de",
+ "value": 1842767
+ }
+ },
+ "9d350966e16c409eb32e0642ee908f24": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "9e1c77e44d884d4d80684179c6f4c96c": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "a055f35c82bb43cbb2ff3f8a98fa7a90": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "a31c3597f0a546b0b47bbb873b7c3848": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "a6182b17067c431fb933d36cef4d4438": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "a76b6b9e474a4cf79c98b52758adeab5": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "a8d929695aee466fbdc0324dab55aa15": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "aec34d7e6c4444778daa42043fb595f2": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_99e864d7de4d461fbb8ccc3f4f765645",
+ "IPY_MODEL_368b7e51e67749cc8952784c042e3410",
+ "IPY_MODEL_4ff270b1c5dd484996ce971ed7014317"
+ ],
+ "layout": "IPY_MODEL_5f9f9e24d6834de3a058a63fecd3a19f"
+ }
+ },
+ "b525a8f8d07b4abcb162fcca0bfab28d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_e86dea79f86143cab0b88c2d5f8992e8",
+ "IPY_MODEL_9d0f7c1505c9436ca8c04724381dd70f",
+ "IPY_MODEL_dd5b7da397c849adabcf33c2b7c3aeb8"
+ ],
+ "layout": "IPY_MODEL_1be40306cb6844319d44ba1f4f0164cb"
+ }
+ },
+ "b7daa28812e54fd6941f33d1b8325666": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_f660c1ba3a664866a5abcdd3c35d0e72",
+ "IPY_MODEL_707287071f58451b8e5da6688cba286b",
+ "IPY_MODEL_21f81e0c00e94b5d82039aca95e44bfb"
+ ],
+ "layout": "IPY_MODEL_71a49a876cc1413fad925c15672d5919"
+ }
+ },
+ "bb637c634d6f4efebafa3aa503c9862e": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "bbad5c1e5c4d46a49e500d5b3890ad0b": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "bef5c99c8ba3495b85454d8f10ff17bd": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "c0173c0a1f334f6fb3c1e8e0163da260": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "c09f3c527dce4bed8f6d72922218317e": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "c3520f624fd043bfaac26503ff10f254": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_1c861f3ec5214c148a714f04bdd9696c",
+ "placeholder": "ā",
+ "style": "IPY_MODEL_92f27d77468c407aae1208893996a2a7",
+ "value": " 2/2 [00:02<00:00, 1.09s/it]"
+ }
+ },
+ "c4bcdc27cad7466e8bf0f8655373a71b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "c92a8c987de548239ec78c93ee5bf660": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "cc67c0d024c848898c5298691db43ed6": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_fc7ac85b780a4605809bf68f006e1ef3",
+ "placeholder": "ā",
+ "style": "IPY_MODEL_2582ee70618844ff8168b8637ddbbae4",
+ "value": "model.safetensors.index.json: 100%"
+ }
+ },
+ "cca945b5469b41549da4b2bd369c83ff": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "cce9b0e3791e442a9e7b9b87a3ffec41": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "ce270b54645e4f1d87a0a10416739c1f": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "cf227ad8761f4f80a7884ec915add0f4": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_0604d0239806455889c09137ebee2815",
+ "placeholder": "ā",
+ "style": "IPY_MODEL_9a3e9cc371da4801ab642ea09101a40f",
+ "value": " 188/188 [00:00<00:00, 13.3kB/s]"
+ }
+ },
+ "cfb7c854917e4e3abca3158c0554c081": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "d27922254bcb4c30aebff5da1ee34401": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "d5bb62948e5545319030671e17be5ced": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_e354464414fd479388da58d033700024",
+ "IPY_MODEL_07b6cec0de3040efa472867d07cd0495",
+ "IPY_MODEL_2292ea67682e44bc8c61ce31bd18371c"
+ ],
+ "layout": "IPY_MODEL_a6182b17067c431fb933d36cef4d4438"
+ }
+ },
+ "d63f72800b854c7c980ed08f77dd61bf": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "d6603805bb654ea9a7e0ae2b60dd5be6": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_76e5f131438d4086bfcf52f7f4969b3a",
+ "IPY_MODEL_19f9d3421cef4c38abb500910e251f22",
+ "IPY_MODEL_c3520f624fd043bfaac26503ff10f254"
+ ],
+ "layout": "IPY_MODEL_f4e419873d6f49a99b8718666811073e"
+ }
+ },
+ "d98b1abde30a4dd28468c7c12d76c822": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "da72027fa75c4687bee46ab832297882": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "dd5b7da397c849adabcf33c2b7c3aeb8": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_1d118df455dd4d6aa395df422c711573",
+ "placeholder": "ā",
+ "style": "IPY_MODEL_e990b322c8104b329227207e442003ee",
+ "value": " 1.84M/1.84M [00:00<00:00, 25.3MB/s]"
+ }
+ },
+ "e2a836342061410c86827b364d3feb29": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "e31ca54d82f64114ab694ce50ba62987": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "e354464414fd479388da58d033700024": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_eb69099d7fd9416881bdc9635838da1b",
+ "placeholder": "ā",
+ "style": "IPY_MODEL_6f9ff8cbf6b3427ab3d30aa19b7703db",
+ "value": "tokenizer.model: 100%"
+ }
+ },
+ "e6b8c1942a54415f998e1f44e796ce1b": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "e850e6dcd47f4da2842e5e2b8c8a43c5": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "e86dea79f86143cab0b88c2d5f8992e8": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_7b2d516d0c5046dc816801790aa4c2b5",
+ "placeholder": "ā",
+ "style": "IPY_MODEL_9e1c77e44d884d4d80684179c6f4c96c",
+ "value": "tokenizer.json: 100%"
+ }
+ },
+ "e990b322c8104b329227207e442003ee": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "eb69099d7fd9416881bdc9635838da1b": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "f4a21799b8f84e4687e36314cfc873f0": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "f4e419873d6f49a99b8718666811073e": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "f660c1ba3a664866a5abcdd3c35d0e72": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_1174909ea64c4afa9ba244564eeafaf5",
+ "placeholder": "ā",
+ "style": "IPY_MODEL_a76b6b9e474a4cf79c98b52758adeab5",
+ "value": "tokenizer_config.json: 100%"
+ }
+ },
+ "f733175715f640668ae119c03542fcb5": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "fb16992539ed404383974588a5c9263a": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "fb3f514d931e4c8192f7cd3557ece592": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "fc7ac85b780a4605809bf68f006e1ef3": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "fceb5194627442d79ee5ecd91ee01a10": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_51664a883bcb4d52b55e8826f17726bb",
+ "placeholder": "ā",
+ "style": "IPY_MODEL_951f41fc154c4f8aa38124c11e481223",
+ "value": "special_tokens_map.json: 100%"
+ }
+ },
+ "fdb572e99dde402199a43ac369891538": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_2b903ea4e2934e4ca655296945bea0cd",
+ "max": 188,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_bef5c99c8ba3495b85454d8f10ff17bd",
+ "value": 188
+ }
+ }
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/demos/Main_Demo.ipynb b/demos/Main_Demo.ipynb
index d5d524c76..b2f89b695 100644
--- a/demos/Main_Demo.ipynb
+++ b/demos/Main_Demo.ipynb
@@ -45,7 +45,7 @@
},
{
"cell_type": "code",
- "execution_count": 292,
+ "execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@@ -74,13 +74,12 @@
" ip.extension_manager.load('autoreload')\n",
" %autoreload 2\n",
" \n",
- "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n",
- "IN_GITHUB = True\n"
+ "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n"
]
},
{
"cell_type": "code",
- "execution_count": 293,
+ "execution_count": 12,
"metadata": {},
"outputs": [
{
@@ -103,32 +102,28 @@
},
{
"cell_type": "code",
- "execution_count": 294,
+ "execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 294,
- "metadata": {
- "text/html": {
- "Content-Type": "text/html"
- }
- },
+ "execution_count": 13,
+ "metadata": {},
"output_type": "execute_result"
}
],
@@ -140,7 +135,7 @@
},
{
"cell_type": "code",
- "execution_count": 295,
+ "execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
@@ -158,7 +153,7 @@
},
{
"cell_type": "code",
- "execution_count": 296,
+ "execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
@@ -179,16 +174,16 @@
},
{
"cell_type": "code",
- "execution_count": 297,
+ "execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 297,
+ "execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
@@ -254,7 +249,7 @@
},
{
"cell_type": "code",
- "execution_count": 299,
+ "execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
@@ -263,7 +258,7 @@
},
{
"cell_type": "code",
- "execution_count": 300,
+ "execution_count": 18,
"metadata": {},
"outputs": [
{
@@ -1210,7 +1205,7 @@
},
{
"cell_type": "code",
- "execution_count": 314,
+ "execution_count": 19,
"metadata": {},
"outputs": [
{
@@ -1218,13 +1213,13 @@
"output_type": "stream",
"text": [
"blocks.0.attn.W_Q torch.Size([12, 768, 64])\n",
- "blocks.0.attn.W_K torch.Size([12, 768, 64])\n",
- "blocks.0.attn.W_V torch.Size([12, 768, 64])\n",
"blocks.0.attn.W_O torch.Size([12, 64, 768])\n",
"blocks.0.attn.b_Q torch.Size([12, 64])\n",
+ "blocks.0.attn.b_O torch.Size([768])\n",
+ "blocks.0.attn.W_K torch.Size([12, 768, 64])\n",
+ "blocks.0.attn.W_V torch.Size([12, 768, 64])\n",
"blocks.0.attn.b_K torch.Size([12, 64])\n",
"blocks.0.attn.b_V torch.Size([12, 64])\n",
- "blocks.0.attn.b_O torch.Size([768])\n",
"blocks.0.mlp.W_in torch.Size([768, 3072])\n",
"blocks.0.mlp.b_in torch.Size([3072])\n",
"blocks.0.mlp.W_out torch.Size([3072, 768])\n",
@@ -1247,7 +1242,7 @@
},
{
"cell_type": "code",
- "execution_count": 315,
+ "execution_count": 20,
"metadata": {},
"outputs": [
{
diff --git a/demos/No_Position_Experiment.ipynb b/demos/No_Position_Experiment.ipynb
index d784f2518..98b2ddf2a 100644
--- a/demos/No_Position_Experiment.ipynb
+++ b/demos/No_Position_Experiment.ipynb
@@ -28,6 +28,11 @@
"# Setup"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": []
+ },
{
"cell_type": "code",
"execution_count": 1,
@@ -39,7 +44,7 @@
"\n",
" IN_COLAB = True\n",
" !pip install einops\n",
- " !pip install https://github.com/neelnanda-io/TransformerLens@no-position-experiment\n",
+ " %pip install transformer_lens\n",
"except:\n",
" IN_COLAB = False\n",
"\n",
@@ -577,7 +582,7 @@
}
],
"source": [
- "cache[\"blocks.0.attn.hook_attn\"].shape"
+ "cache[\"blocks.0.attn.hook_pattern\"].shape"
]
},
{
@@ -717,7 +722,7 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -733,7 +738,7 @@
"logit_components = (\n",
" resid_stack[:, batch_index]\n",
" @ fold_W_U\n",
- " / cache[\"scale\", None, \"ln_final\"][batch_index]\n",
+ " / cache[\"scale\"][batch_index]\n",
")\n",
"print(logit_components.shape)"
]
@@ -1274,7 +1279,7 @@
"losses = []\n",
"loss_labels = []\n",
"for hook_name in hook_list:\n",
- " if hook_name != \"hook_pos_embed\" and \"result\" not in hook_name:\n",
+ " if hook_name in cache and hook_name != \"hook_pos_embed\" and \"result\" not in hook_name:\n",
" average_act = cache[hook_name].mean(0)\n",
"\n",
" def replacing_with_average_act(activation, hook):\n",
diff --git a/demos/Othello_GPT.ipynb b/demos/Othello_GPT.ipynb
index 7cbb68a5c..1b4400bc7 100644
--- a/demos/Othello_GPT.ipynb
+++ b/demos/Othello_GPT.ipynb
@@ -69,6 +69,7 @@
" print(\"Running as a Colab notebook\")\n",
" %pip install git+https://github.com/neelnanda-io/TransformerLens.git\n",
" %pip install circuitsvis\n",
+ " %pip install torchtyping\n",
" \n",
" # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n",
" # # Install another version of node that makes PySvelte work way faster\n",
diff --git a/demos/Qwen.ipynb b/demos/Qwen.ipynb
new file mode 100644
index 000000000..e8ef18f57
--- /dev/null
+++ b/demos/Qwen.ipynb
@@ -0,0 +1,376 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Requirement already satisfied: transformers_stream_generator in /root/TransformerLens/.venv/lib/python3.10/site-packages (0.0.4)\n",
+ "Requirement already satisfied: plotly in /root/TransformerLens/.venv/lib/python3.10/site-packages (5.18.0)\n",
+ "Requirement already satisfied: circuitsvis in /root/TransformerLens/.venv/lib/python3.10/site-packages (1.43.2)\n",
+ "Requirement already satisfied: huggingface_hub in /root/TransformerLens/.venv/lib/python3.10/site-packages (0.20.2)\n",
+ "Requirement already satisfied: einops in /root/TransformerLens/.venv/lib/python3.10/site-packages (0.7.0)\n",
+ "Requirement already satisfied: tiktoken in /root/TransformerLens/.venv/lib/python3.10/site-packages (0.5.2)\n",
+ "Requirement already satisfied: datasets in /root/TransformerLens/.venv/lib/python3.10/site-packages (2.14.4)\n",
+ "Requirement already satisfied: transformers>=4.26.1 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from transformers_stream_generator) (4.37.2)\n",
+ "Requirement already satisfied: tenacity>=6.2.0 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from plotly) (8.2.3)\n",
+ "Requirement already satisfied: packaging in /root/TransformerLens/.venv/lib/python3.10/site-packages (from plotly) (23.2)\n",
+ "Requirement already satisfied: importlib-metadata>=5.1.0 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from circuitsvis) (7.0.1)\n",
+ "Requirement already satisfied: numpy>=1.24 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from circuitsvis) (1.26.3)\n",
+ "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from circuitsvis) (12.1.3.1)\n",
+ "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n",
+ "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from circuitsvis) (8.9.2.26)\n",
+ "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from circuitsvis) (11.0.2.54)\n",
+ "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from circuitsvis) (10.3.2.106)\n",
+ "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from circuitsvis) (11.4.5.107)\n",
+ "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from circuitsvis) (12.1.0.106)\n",
+ "Requirement already satisfied: nvidia-nccl-cu12==2.18.1 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from circuitsvis) (2.18.1)\n",
+ "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n",
+ "Requirement already satisfied: torch>=1.10 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from circuitsvis) (2.1.2)\n",
+ "Requirement already satisfied: triton==2.1.0 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from circuitsvis) (2.1.0)\n",
+ "Requirement already satisfied: nvidia-nvjitlink-cu12 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->circuitsvis) (12.3.101)\n",
+ "Requirement already satisfied: filelock in /root/TransformerLens/.venv/lib/python3.10/site-packages (from triton==2.1.0->circuitsvis) (3.13.1)\n",
+ "Requirement already satisfied: fsspec>=2023.5.0 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from huggingface_hub) (2023.12.2)\n",
+ "Requirement already satisfied: requests in /root/TransformerLens/.venv/lib/python3.10/site-packages (from huggingface_hub) (2.31.0)\n",
+ "Requirement already satisfied: tqdm>=4.42.1 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from huggingface_hub) (4.66.1)\n",
+ "Requirement already satisfied: pyyaml>=5.1 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from huggingface_hub) (6.0.1)\n",
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from huggingface_hub) (4.9.0)\n",
+ "Requirement already satisfied: regex>=2022.1.18 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from tiktoken) (2023.12.25)\n",
+ "Requirement already satisfied: pyarrow>=8.0.0 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from datasets) (14.0.2)\n",
+ "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from datasets) (0.3.7)\n",
+ "Requirement already satisfied: pandas in /root/TransformerLens/.venv/lib/python3.10/site-packages (from datasets) (2.0.3)\n",
+ "Requirement already satisfied: xxhash in /root/TransformerLens/.venv/lib/python3.10/site-packages (from datasets) (3.4.1)\n",
+ "Requirement already satisfied: multiprocess in /root/TransformerLens/.venv/lib/python3.10/site-packages (from datasets) (0.70.15)\n",
+ "Requirement already satisfied: aiohttp in /root/TransformerLens/.venv/lib/python3.10/site-packages (from datasets) (3.9.1)\n",
+ "Requirement already satisfied: attrs>=17.3.0 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from aiohttp->datasets) (23.2.0)\n",
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from aiohttp->datasets) (6.0.4)\n",
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from aiohttp->datasets) (1.9.4)\n",
+ "Requirement already satisfied: frozenlist>=1.1.1 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from aiohttp->datasets) (1.4.1)\n",
+ "Requirement already satisfied: aiosignal>=1.1.2 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from aiohttp->datasets) (1.3.1)\n",
+ "Requirement already satisfied: async-timeout<5.0,>=4.0 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from aiohttp->datasets) (4.0.3)\n",
+ "Requirement already satisfied: zipp>=0.5 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from importlib-metadata>=5.1.0->circuitsvis) (3.17.0)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from requests->huggingface_hub) (3.3.2)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from requests->huggingface_hub) (3.6)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from requests->huggingface_hub) (2.1.0)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from requests->huggingface_hub) (2023.11.17)\n",
+ "Requirement already satisfied: sympy in /root/TransformerLens/.venv/lib/python3.10/site-packages (from torch>=1.10->circuitsvis) (1.12)\n",
+ "Requirement already satisfied: networkx in /root/TransformerLens/.venv/lib/python3.10/site-packages (from torch>=1.10->circuitsvis) (3.1)\n",
+ "Requirement already satisfied: jinja2 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from torch>=1.10->circuitsvis) (3.1.2)\n",
+ "Requirement already satisfied: tokenizers<0.19,>=0.14 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from transformers>=4.26.1->transformers_stream_generator) (0.15.0)\n",
+ "Requirement already satisfied: safetensors>=0.4.1 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from transformers>=4.26.1->transformers_stream_generator) (0.4.1)\n",
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from pandas->datasets) (2.8.2)\n",
+ "Requirement already satisfied: pytz>=2020.1 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from pandas->datasets) (2023.3.post1)\n",
+ "Requirement already satisfied: tzdata>=2022.1 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from pandas->datasets) (2023.4)\n",
+ "Requirement already satisfied: six>=1.5 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n",
+ "Requirement already satisfied: MarkupSafe>=2.0 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from jinja2->torch>=1.10->circuitsvis) (2.1.3)\n",
+ "Requirement already satisfied: mpmath>=0.19 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from sympy->torch>=1.10->circuitsvis) (1.3.0)\n",
+ "\n",
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n",
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
+ "Note: you may need to restart the kernel to use updated packages.\n"
+ ]
+ }
+ ],
+ "source": [
+ "%pip install transformers_stream_generator plotly circuitsvis huggingface_hub einops tiktoken datasets"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Running as a Jupyter notebook - intended for development only!\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/tmp/ipykernel_13850/410710250.py:21: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
+ " ipython.magic(\"load_ext autoreload\")\n",
+ "/tmp/ipykernel_13850/410710250.py:22: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
+ " ipython.magic(\"autoreload 2\")\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Janky code to do different setup when run in a Colab notebook vs VSCode\n",
+ "DEVELOPMENT_MODE = False\n",
+ "try:\n",
+ " import google.colab\n",
+ " IN_COLAB = True\n",
+ " print(\"Running as a Colab notebook\")\n",
+ " %pip install git+https://github.com/neelnanda-io/TransformerLens.git\n",
+ " %pip install circuitsvis\n",
+ " \n",
+ " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n",
+ " # # Install another version of node that makes PySvelte work way faster\n",
+ " # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n",
+ " # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n",
+ "except:\n",
+ " IN_COLAB = False\n",
+ " print(\"Running as a Jupyter notebook - intended for development only!\")\n",
+ " from IPython import get_ipython\n",
+ "\n",
+ " ipython = get_ipython()\n",
+ " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n",
+ " ipython.magic(\"load_ext autoreload\")\n",
+ " ipython.magic(\"autoreload 2\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Using renderer: colab\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n",
+ "import plotly.io as pio\n",
+ "if IN_COLAB or not DEVELOPMENT_MODE:\n",
+ " pio.renderers.default = \"colab\"\n",
+ "else:\n",
+ " pio.renderers.default = \"notebook_connected\"\n",
+ "print(f\"Using renderer: {pio.renderers.default}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "/root/TransformerLens\n"
+ ]
+ }
+ ],
+ "source": [
+ "%cd ~/TransformerLens\n",
+ "import torch\n",
+ "torch.set_grad_enabled(False)\n",
+ "\n",
+ "from transformers import AutoTokenizer\n",
+ "from transformer_lens import HookedTransformer\n",
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
+ "from transformers.generation import GenerationConfig\n",
+ "\n",
+ "from functools import partial"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def assert_hf_and_tl_model_are_close(\n",
+ " hf_model,\n",
+ " tl_model,\n",
+ " tokenizer,\n",
+ " prompt=\"This is a prompt to test out\",\n",
+ " atol=1e-3,\n",
+ "):\n",
+ " prompt_toks = tokenizer(prompt, return_tensors=\"pt\").input_ids\n",
+ "\n",
+ " hf_logits = hf_model(prompt_toks.to(hf_model.device)).logits\n",
+ " tl_logits = tl_model(prompt_toks).to(hf_logits)\n",
+ "\n",
+ " assert torch.allclose(torch.softmax(hf_logits, dim=-1), torch.softmax(tl_logits, dim=-1), atol=atol)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Qwen, first generation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "2cffaf8715b64623b6799822d7cf1cfe",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:root:Loading model Qwen/Qwen-1_8B-Chat requires setting trust_remote_code=True\n",
+ "WARNING:root:Loading model Qwen/Qwen-1_8B-Chat state dict requires setting trust_remote_code=True\n",
+ "WARNING:transformers_modules.Qwen.Qwen-1_8B-Chat.1d0f68de57b88cfde81f3c3e537f24464d889081.modeling_qwen:Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "619309048d964c0ca76bfa098d71f25a",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Loaded pretrained model Qwen/Qwen-1_8B-Chat into HookedTransformer\n",
+ "Moving model to device: cuda\n"
+ ]
+ }
+ ],
+ "source": [
+ "model_path = \"Qwen/Qwen-1_8B-Chat\"\n",
+ "device = \"cuda\"\n",
+ "\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\n",
+ " model_path,\n",
+ " trust_remote_code=True\n",
+ ")\n",
+ "\n",
+ "hf_model = AutoModelForCausalLM.from_pretrained(\n",
+ " model_path,\n",
+ " device_map=device,\n",
+ " fp32=True,\n",
+ " use_logn_attn=False,\n",
+ " use_dynamic_ntk = False,\n",
+ " scale_attn_weights = False,\n",
+ " trust_remote_code = True\n",
+ ").eval()\n",
+ "\n",
+ "tl_model = HookedTransformer.from_pretrained_no_processing(\n",
+ " model_path,\n",
+ " device=device,\n",
+ " fp32=True,\n",
+ " dtype=torch.float32,\n",
+ ").to(device)\n",
+ "\n",
+ "assert_hf_and_tl_model_are_close(hf_model, tl_model, tokenizer)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Qwen, new generation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Loaded pretrained model Qwen/Qwen1.5-1.8B-Chat into HookedTransformer\n",
+ "Moving model to device: cuda\n"
+ ]
+ }
+ ],
+ "source": [
+ "model_path = \"Qwen/Qwen1.5-1.8B-Chat\"\n",
+ "device = \"cuda\"\n",
+ "\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\n",
+ " model_path,\n",
+ ")\n",
+ "\n",
+ "hf_model = AutoModelForCausalLM.from_pretrained(\n",
+ " model_path,\n",
+ " device_map=device,\n",
+ ").eval()\n",
+ "\n",
+ "tl_model = HookedTransformer.from_pretrained_no_processing(\n",
+ " model_path,\n",
+ " device=device,\n",
+ " dtype=torch.float32,\n",
+ ").to(device)\n",
+ "\n",
+ "assert_hf_and_tl_model_are_close(hf_model, tl_model, tokenizer)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "base",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/demos/santacoder.ipynb b/demos/Santa_Coder.ipynb
similarity index 99%
rename from demos/santacoder.ipynb
rename to demos/Santa_Coder.ipynb
index 61b455035..a69071c38 100644
--- a/demos/santacoder.ipynb
+++ b/demos/Santa_Coder.ipynb
@@ -35,6 +35,7 @@
" print(\"Running as a Colab notebook\")\n",
" %pip install git+https://github.com/neelnanda-io/TransformerLens.git``\n",
" %pip install circuitsvis\n",
+ " %pip install torchtyping\n",
" \n",
" # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n",
" # # Install another version of node that makes PySvelte work way faster\n",
diff --git a/docs/make_docs.py b/docs/make_docs.py
index ea6c41fba..d15b7a124 100644
--- a/docs/make_docs.py
+++ b/docs/make_docs.py
@@ -93,10 +93,7 @@ def generate_model_table(_app: Optional[Any] = None):
]
df = pd.DataFrame(
{
- name: [
- get_property(name, model_name)
- for model_name in loading.DEFAULT_MODEL_ALIASES
- ]
+ name: [get_property(name, model_name) for model_name in loading.DEFAULT_MODEL_ALIASES]
for name in column_names
},
index=loading.DEFAULT_MODEL_ALIASES,
diff --git a/docs/source/conf.py b/docs/source/conf.py
index af38914c0..96d6b42e6 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -80,11 +80,13 @@
"convert_gptj_weights",
"convert_llama_weights",
"convert_mingpt_weights",
+ "convert_nanogpt_weights",
"convert_neel_solu_old_weights",
"convert_neo_weights",
"convert_neox_weights",
"convert_neel_model_config",
"convert_opt_weights",
+ "convert_gemma_weights",
"fill_missing_keys",
"get_basic_config",
"get_official_model_name",
diff --git a/docs/source/content/contributing.md b/docs/source/content/contributing.md
index f90c0be53..49bf28f99 100644
--- a/docs/source/content/contributing.md
+++ b/docs/source/content/contributing.md
@@ -32,6 +32,8 @@ quite slow (as we only have CPU actions) so the smaller models like `attn-only-1
- Unit tests only via `make unit-test`
- Acceptance tests only via `make acceptance-test`
- Docstring tests only via `make docstring-test`
+- Notebook tests only via `make notebook-test`
+- Run all of the above test suites via `make test`
## Formatting
@@ -41,6 +43,8 @@ actions.
- Format all files via `make format`
- Only check the formatting via `make check-format`
+Note that `black` line length is set to 100 in `pyproject.toml` (instead of the default 88).
+
## Documentation
Please make sure to add thorough documentation for any features you add. You should do this directly
diff --git a/docs/source/content/getting_started.md b/docs/source/content/getting_started.md
index 459b65b44..13952cd5e 100644
--- a/docs/source/content/getting_started.md
+++ b/docs/source/content/getting_started.md
@@ -19,3 +19,13 @@ One significant design decision made was to have a single transformer implementa
Import the library with `import transformer_lens`
(Note: This library used to be known as EasyTransformer, and some breaking changes have been made since the rename. If you need to use the old version with some legacy code, run `pip install git+https://github.com/neelnanda-io/TransformerLens@v1`.)
+
+## Huggingface Gated Access
+
+Some of the models available in TransformerLens require gated access to be used. Luckily TransformerLens provides a way to access those models via the configuration of an environment variable. Simply configure your access token found [here](https://huggingface.co/settings/tokens) as `HF_TOKEN` in your environment.
+
+You will need to make sure you accept the agreements for any gated models, but once you do, the models will work with TransformerLens without issue. If you attempt to use one of these models before you have accepted any related agreements, the console output will be very helpful and point you to the URL where you need to accept an agreement. As of 23/4/24, the current list of gated models supported by TransformerLens is as follows.
+
+* https://huggingface.co/mistralai/Mixtral-8x7B-v0.1
+* https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
+* https://huggingface.co/mistralai/Mistral-7B-v0.1
diff --git a/docs/source/content/special_cases.md b/docs/source/content/special_cases.md
new file mode 100644
index 000000000..a5eae2164
--- /dev/null
+++ b/docs/source/content/special_cases.md
@@ -0,0 +1,11 @@
+# Special Cases
+
+## Mixture of Experts error rates
+Due to the Top-K gating performed in the hidden layer of Mixture of Experts models, small errors can be amplified
+greatly in cases where a different expert is selected, which leads to a higher than normal variance in the error rate
+of the final logits. In testing done on Mixtral running in half precision, the standard deviation of the absolute error
+rate of the logits compared to those from the default model was found to be around 2e-3.
+
+There are two main ways to mitigate this:
+1. Disable preprocessing options by using `HookedTransformer.from_pretrained_no_processing` instead of `HookedTransformer.from_pretrained`
+2. Increase the precision of the data type used in the model
diff --git a/docs/source/index.md b/docs/source/index.md
index 09b00f20f..869ebc248 100644
--- a/docs/source/index.md
+++ b/docs/source/index.md
@@ -44,6 +44,7 @@ content/citation
content/contributing
generated/demos/Main_Demo
generated/demos/Exploratory_Analysis_Demo
+content/special_cases
```
```{toctree}
diff --git a/makefile b/makefile
index b786aa209..4cc4633dc 100644
--- a/makefile
+++ b/makefile
@@ -9,17 +9,21 @@ check-format:
poetry run black --check .
unit-test:
- poetry run pytest --cov=transformer_lens/ --cov-report=term-missing --cov-branch tests/unit
+ poetry run pytest tests/unit
acceptance-test:
- poetry run pytest --cov=transformer_lens/ --cov-report=term-missing --cov-branch tests/acceptance
+ poetry run pytest tests/acceptance
+
+coverage-report-test:
+ poetry run pytest --cov=transformer_lens/ --cov-report=html --cov-branch tests/unit tests/acceptance
docstring-test:
poetry run pytest transformer_lens/
notebook-test:
- poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Main_Demo.ipynb
+ poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/BERT.ipynb
poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Exploratory_Analysis_Demo.ipynb
+ poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Main_Demo.ipynb
test:
make unit-test
diff --git a/poetry.lock b/poetry.lock
index bac8ff0c5..6aba9fa7e 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,13 +1,14 @@
-# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
[[package]]
name = "accelerate"
-version = "0.24.0"
+version = "0.29.1"
description = "Accelerate"
optional = false
python-versions = ">=3.8.0"
files = [
- {file = "accelerate-0.24.0-py3-none-any.whl", hash = "sha256:04bb1483c90eacb3beb2687cb54950d8caf9a0b93432f6b2d42efebbb6c0491e"},
+ {file = "accelerate-0.29.1-py3-none-any.whl", hash = "sha256:7eda0c8bc62bc59129103310f1272a0fb7b3ebc55fc8920cfe1c102db30aca58"},
+ {file = "accelerate-0.29.1.tar.gz", hash = "sha256:d1d0e5a591177891812fd6d1bc843af191e1192c80e5180258f52fefcb653a9f"},
]
[package.dependencies]
@@ -16,125 +17,114 @@ numpy = ">=1.17"
packaging = ">=20.0"
psutil = "*"
pyyaml = "*"
+safetensors = ">=0.3.1"
torch = ">=1.10.0"
[package.extras]
-dev = ["bitsandbytes", "black (>=23.1,<24.0)", "datasets", "deepspeed", "evaluate", "hf-doc-builder (>=0.3.0)", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "rich", "ruff (>=0.0.241)", "scikit-learn", "scipy", "timm", "tqdm", "transformers", "urllib3 (<2.0.0)"]
-quality = ["black (>=23.1,<24.0)", "hf-doc-builder (>=0.3.0)", "ruff (>=0.0.241)", "urllib3 (<2.0.0)"]
+dev = ["bitsandbytes", "black (>=23.1,<24.0)", "datasets", "deepspeed", "evaluate", "hf-doc-builder (>=0.3.0)", "parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist", "rich", "ruff (>=0.2.1,<0.3.0)", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"]
+quality = ["black (>=23.1,<24.0)", "hf-doc-builder (>=0.3.0)", "ruff (>=0.2.1,<0.3.0)"]
rich = ["rich"]
sagemaker = ["sagemaker"]
-test-dev = ["bitsandbytes", "datasets", "deepspeed", "evaluate", "scikit-learn", "scipy", "timm", "tqdm", "transformers"]
-test-prod = ["parameterized", "pytest", "pytest-subtests", "pytest-xdist"]
-test-trackers = ["comet-ml", "tensorboard", "wandb"]
-testing = ["bitsandbytes", "datasets", "deepspeed", "evaluate", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "scikit-learn", "scipy", "timm", "tqdm", "transformers"]
+test-dev = ["bitsandbytes", "datasets", "deepspeed", "evaluate", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"]
+test-prod = ["parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist"]
+test-trackers = ["comet-ml", "dvclive", "tensorboard", "wandb"]
+testing = ["bitsandbytes", "datasets", "deepspeed", "evaluate", "parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"]
[[package]]
name = "aiohttp"
-version = "3.8.6"
+version = "3.9.3"
description = "Async http client/server framework (asyncio)"
optional = false
-python-versions = ">=3.6"
+python-versions = ">=3.8"
files = [
- {file = "aiohttp-3.8.6-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:41d55fc043954cddbbd82503d9cc3f4814a40bcef30b3569bc7b5e34130718c1"},
- {file = "aiohttp-3.8.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1d84166673694841d8953f0a8d0c90e1087739d24632fe86b1a08819168b4566"},
- {file = "aiohttp-3.8.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:253bf92b744b3170eb4c4ca2fa58f9c4b87aeb1df42f71d4e78815e6e8b73c9e"},
- {file = "aiohttp-3.8.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3fd194939b1f764d6bb05490987bfe104287bbf51b8d862261ccf66f48fb4096"},
- {file = "aiohttp-3.8.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6c5f938d199a6fdbdc10bbb9447496561c3a9a565b43be564648d81e1102ac22"},
- {file = "aiohttp-3.8.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2817b2f66ca82ee699acd90e05c95e79bbf1dc986abb62b61ec8aaf851e81c93"},
- {file = "aiohttp-3.8.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0fa375b3d34e71ccccf172cab401cd94a72de7a8cc01847a7b3386204093bb47"},
- {file = "aiohttp-3.8.6-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9de50a199b7710fa2904be5a4a9b51af587ab24c8e540a7243ab737b45844543"},
- {file = "aiohttp-3.8.6-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e1d8cb0b56b3587c5c01de3bf2f600f186da7e7b5f7353d1bf26a8ddca57f965"},
- {file = "aiohttp-3.8.6-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:8e31e9db1bee8b4f407b77fd2507337a0a80665ad7b6c749d08df595d88f1cf5"},
- {file = "aiohttp-3.8.6-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:7bc88fc494b1f0311d67f29fee6fd636606f4697e8cc793a2d912ac5b19aa38d"},
- {file = "aiohttp-3.8.6-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:ec00c3305788e04bf6d29d42e504560e159ccaf0be30c09203b468a6c1ccd3b2"},
- {file = "aiohttp-3.8.6-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ad1407db8f2f49329729564f71685557157bfa42b48f4b93e53721a16eb813ed"},
- {file = "aiohttp-3.8.6-cp310-cp310-win32.whl", hash = "sha256:ccc360e87341ad47c777f5723f68adbb52b37ab450c8bc3ca9ca1f3e849e5fe2"},
- {file = "aiohttp-3.8.6-cp310-cp310-win_amd64.whl", hash = "sha256:93c15c8e48e5e7b89d5cb4613479d144fda8344e2d886cf694fd36db4cc86865"},
- {file = "aiohttp-3.8.6-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6e2f9cc8e5328f829f6e1fb74a0a3a939b14e67e80832975e01929e320386b34"},
- {file = "aiohttp-3.8.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e6a00ffcc173e765e200ceefb06399ba09c06db97f401f920513a10c803604ca"},
- {file = "aiohttp-3.8.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:41bdc2ba359032e36c0e9de5a3bd00d6fb7ea558a6ce6b70acedf0da86458321"},
- {file = "aiohttp-3.8.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:14cd52ccf40006c7a6cd34a0f8663734e5363fd981807173faf3a017e202fec9"},
- {file = "aiohttp-3.8.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2d5b785c792802e7b275c420d84f3397668e9d49ab1cb52bd916b3b3ffcf09ad"},
- {file = "aiohttp-3.8.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1bed815f3dc3d915c5c1e556c397c8667826fbc1b935d95b0ad680787896a358"},
- {file = "aiohttp-3.8.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:96603a562b546632441926cd1293cfcb5b69f0b4159e6077f7c7dbdfb686af4d"},
- {file = "aiohttp-3.8.6-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d76e8b13161a202d14c9584590c4df4d068c9567c99506497bdd67eaedf36403"},
- {file = "aiohttp-3.8.6-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e3f1e3f1a1751bb62b4a1b7f4e435afcdade6c17a4fd9b9d43607cebd242924a"},
- {file = "aiohttp-3.8.6-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:76b36b3124f0223903609944a3c8bf28a599b2cc0ce0be60b45211c8e9be97f8"},
- {file = "aiohttp-3.8.6-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:a2ece4af1f3c967a4390c284797ab595a9f1bc1130ef8b01828915a05a6ae684"},
- {file = "aiohttp-3.8.6-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:16d330b3b9db87c3883e565340d292638a878236418b23cc8b9b11a054aaa887"},
- {file = "aiohttp-3.8.6-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:42c89579f82e49db436b69c938ab3e1559e5a4409eb8639eb4143989bc390f2f"},
- {file = "aiohttp-3.8.6-cp311-cp311-win32.whl", hash = "sha256:efd2fcf7e7b9d7ab16e6b7d54205beded0a9c8566cb30f09c1abe42b4e22bdcb"},
- {file = "aiohttp-3.8.6-cp311-cp311-win_amd64.whl", hash = "sha256:3b2ab182fc28e7a81f6c70bfbd829045d9480063f5ab06f6e601a3eddbbd49a0"},
- {file = "aiohttp-3.8.6-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:fdee8405931b0615220e5ddf8cd7edd8592c606a8e4ca2a00704883c396e4479"},
- {file = "aiohttp-3.8.6-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d25036d161c4fe2225d1abff2bd52c34ed0b1099f02c208cd34d8c05729882f0"},
- {file = "aiohttp-3.8.6-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5d791245a894be071d5ab04bbb4850534261a7d4fd363b094a7b9963e8cdbd31"},
- {file = "aiohttp-3.8.6-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0cccd1de239afa866e4ce5c789b3032442f19c261c7d8a01183fd956b1935349"},
- {file = "aiohttp-3.8.6-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f13f60d78224f0dace220d8ab4ef1dbc37115eeeab8c06804fec11bec2bbd07"},
- {file = "aiohttp-3.8.6-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8a9b5a0606faca4f6cc0d338359d6fa137104c337f489cd135bb7fbdbccb1e39"},
- {file = "aiohttp-3.8.6-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:13da35c9ceb847732bf5c6c5781dcf4780e14392e5d3b3c689f6d22f8e15ae31"},
- {file = "aiohttp-3.8.6-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:4d4cbe4ffa9d05f46a28252efc5941e0462792930caa370a6efaf491f412bc66"},
- {file = "aiohttp-3.8.6-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:229852e147f44da0241954fc6cb910ba074e597f06789c867cb7fb0621e0ba7a"},
- {file = "aiohttp-3.8.6-cp36-cp36m-musllinux_1_1_s390x.whl", hash = "sha256:713103a8bdde61d13490adf47171a1039fd880113981e55401a0f7b42c37d071"},
- {file = "aiohttp-3.8.6-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:45ad816b2c8e3b60b510f30dbd37fe74fd4a772248a52bb021f6fd65dff809b6"},
- {file = "aiohttp-3.8.6-cp36-cp36m-win32.whl", hash = "sha256:2b8d4e166e600dcfbff51919c7a3789ff6ca8b3ecce16e1d9c96d95dd569eb4c"},
- {file = "aiohttp-3.8.6-cp36-cp36m-win_amd64.whl", hash = "sha256:0912ed87fee967940aacc5306d3aa8ba3a459fcd12add0b407081fbefc931e53"},
- {file = "aiohttp-3.8.6-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e2a988a0c673c2e12084f5e6ba3392d76c75ddb8ebc6c7e9ead68248101cd446"},
- {file = "aiohttp-3.8.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ebf3fd9f141700b510d4b190094db0ce37ac6361a6806c153c161dc6c041ccda"},
- {file = "aiohttp-3.8.6-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3161ce82ab85acd267c8f4b14aa226047a6bee1e4e6adb74b798bd42c6ae1f80"},
- {file = "aiohttp-3.8.6-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d95fc1bf33a9a81469aa760617b5971331cdd74370d1214f0b3109272c0e1e3c"},
- {file = "aiohttp-3.8.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c43ecfef7deaf0617cee936836518e7424ee12cb709883f2c9a1adda63cc460"},
- {file = "aiohttp-3.8.6-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca80e1b90a05a4f476547f904992ae81eda5c2c85c66ee4195bb8f9c5fb47f28"},
- {file = "aiohttp-3.8.6-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:90c72ebb7cb3a08a7f40061079817133f502a160561d0675b0a6adf231382c92"},
- {file = "aiohttp-3.8.6-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:bb54c54510e47a8c7c8e63454a6acc817519337b2b78606c4e840871a3e15349"},
- {file = "aiohttp-3.8.6-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:de6a1c9f6803b90e20869e6b99c2c18cef5cc691363954c93cb9adeb26d9f3ae"},
- {file = "aiohttp-3.8.6-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:a3628b6c7b880b181a3ae0a0683698513874df63783fd89de99b7b7539e3e8a8"},
- {file = "aiohttp-3.8.6-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:fc37e9aef10a696a5a4474802930079ccfc14d9f9c10b4662169671ff034b7df"},
- {file = "aiohttp-3.8.6-cp37-cp37m-win32.whl", hash = "sha256:f8ef51e459eb2ad8e7a66c1d6440c808485840ad55ecc3cafefadea47d1b1ba2"},
- {file = "aiohttp-3.8.6-cp37-cp37m-win_amd64.whl", hash = "sha256:b2fe42e523be344124c6c8ef32a011444e869dc5f883c591ed87f84339de5976"},
- {file = "aiohttp-3.8.6-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:9e2ee0ac5a1f5c7dd3197de309adfb99ac4617ff02b0603fd1e65b07dc772e4b"},
- {file = "aiohttp-3.8.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:01770d8c04bd8db568abb636c1fdd4f7140b284b8b3e0b4584f070180c1e5c62"},
- {file = "aiohttp-3.8.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3c68330a59506254b556b99a91857428cab98b2f84061260a67865f7f52899f5"},
- {file = "aiohttp-3.8.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:89341b2c19fb5eac30c341133ae2cc3544d40d9b1892749cdd25892bbc6ac951"},
- {file = "aiohttp-3.8.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:71783b0b6455ac8f34b5ec99d83e686892c50498d5d00b8e56d47f41b38fbe04"},
- {file = "aiohttp-3.8.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f628dbf3c91e12f4d6c8b3f092069567d8eb17814aebba3d7d60c149391aee3a"},
- {file = "aiohttp-3.8.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b04691bc6601ef47c88f0255043df6f570ada1a9ebef99c34bd0b72866c217ae"},
- {file = "aiohttp-3.8.6-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7ee912f7e78287516df155f69da575a0ba33b02dd7c1d6614dbc9463f43066e3"},
- {file = "aiohttp-3.8.6-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:9c19b26acdd08dd239e0d3669a3dddafd600902e37881f13fbd8a53943079dbc"},
- {file = "aiohttp-3.8.6-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:99c5ac4ad492b4a19fc132306cd57075c28446ec2ed970973bbf036bcda1bcc6"},
- {file = "aiohttp-3.8.6-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:f0f03211fd14a6a0aed2997d4b1c013d49fb7b50eeb9ffdf5e51f23cfe2c77fa"},
- {file = "aiohttp-3.8.6-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:8d399dade330c53b4106160f75f55407e9ae7505263ea86f2ccca6bfcbdb4921"},
- {file = "aiohttp-3.8.6-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:ec4fd86658c6a8964d75426517dc01cbf840bbf32d055ce64a9e63a40fd7b771"},
- {file = "aiohttp-3.8.6-cp38-cp38-win32.whl", hash = "sha256:33164093be11fcef3ce2571a0dccd9041c9a93fa3bde86569d7b03120d276c6f"},
- {file = "aiohttp-3.8.6-cp38-cp38-win_amd64.whl", hash = "sha256:bdf70bfe5a1414ba9afb9d49f0c912dc524cf60141102f3a11143ba3d291870f"},
- {file = "aiohttp-3.8.6-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:d52d5dc7c6682b720280f9d9db41d36ebe4791622c842e258c9206232251ab2b"},
- {file = "aiohttp-3.8.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4ac39027011414dbd3d87f7edb31680e1f430834c8cef029f11c66dad0670aa5"},
- {file = "aiohttp-3.8.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3f5c7ce535a1d2429a634310e308fb7d718905487257060e5d4598e29dc17f0b"},
- {file = "aiohttp-3.8.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b30e963f9e0d52c28f284d554a9469af073030030cef8693106d918b2ca92f54"},
- {file = "aiohttp-3.8.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:918810ef188f84152af6b938254911055a72e0f935b5fbc4c1a4ed0b0584aed1"},
- {file = "aiohttp-3.8.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:002f23e6ea8d3dd8d149e569fd580c999232b5fbc601c48d55398fbc2e582e8c"},
- {file = "aiohttp-3.8.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4fcf3eabd3fd1a5e6092d1242295fa37d0354b2eb2077e6eb670accad78e40e1"},
- {file = "aiohttp-3.8.6-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:255ba9d6d5ff1a382bb9a578cd563605aa69bec845680e21c44afc2670607a95"},
- {file = "aiohttp-3.8.6-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d67f8baed00870aa390ea2590798766256f31dc5ed3ecc737debb6e97e2ede78"},
- {file = "aiohttp-3.8.6-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:86f20cee0f0a317c76573b627b954c412ea766d6ada1a9fcf1b805763ae7feeb"},
- {file = "aiohttp-3.8.6-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:39a312d0e991690ccc1a61f1e9e42daa519dcc34ad03eb6f826d94c1190190dd"},
- {file = "aiohttp-3.8.6-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:e827d48cf802de06d9c935088c2924e3c7e7533377d66b6f31ed175c1620e05e"},
- {file = "aiohttp-3.8.6-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:bd111d7fc5591ddf377a408ed9067045259ff2770f37e2d94e6478d0f3fc0c17"},
- {file = "aiohttp-3.8.6-cp39-cp39-win32.whl", hash = "sha256:caf486ac1e689dda3502567eb89ffe02876546599bbf915ec94b1fa424eeffd4"},
- {file = "aiohttp-3.8.6-cp39-cp39-win_amd64.whl", hash = "sha256:3f0e27e5b733803333bb2371249f41cf42bae8884863e8e8965ec69bebe53132"},
- {file = "aiohttp-3.8.6.tar.gz", hash = "sha256:b0cf2a4501bff9330a8a5248b4ce951851e415bdcce9dc158e76cfd55e15085c"},
+ {file = "aiohttp-3.9.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:939677b61f9d72a4fa2a042a5eee2a99a24001a67c13da113b2e30396567db54"},
+ {file = "aiohttp-3.9.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1f5cd333fcf7590a18334c90f8c9147c837a6ec8a178e88d90a9b96ea03194cc"},
+ {file = "aiohttp-3.9.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:82e6aa28dd46374f72093eda8bcd142f7771ee1eb9d1e223ff0fa7177a96b4a5"},
+ {file = "aiohttp-3.9.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f56455b0c2c7cc3b0c584815264461d07b177f903a04481dfc33e08a89f0c26b"},
+ {file = "aiohttp-3.9.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bca77a198bb6e69795ef2f09a5f4c12758487f83f33d63acde5f0d4919815768"},
+ {file = "aiohttp-3.9.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e083c285857b78ee21a96ba1eb1b5339733c3563f72980728ca2b08b53826ca5"},
+ {file = "aiohttp-3.9.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab40e6251c3873d86ea9b30a1ac6d7478c09277b32e14745d0d3c6e76e3c7e29"},
+ {file = "aiohttp-3.9.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:df822ee7feaaeffb99c1a9e5e608800bd8eda6e5f18f5cfb0dc7eeb2eaa6bbec"},
+ {file = "aiohttp-3.9.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:acef0899fea7492145d2bbaaaec7b345c87753168589cc7faf0afec9afe9b747"},
+ {file = "aiohttp-3.9.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:cd73265a9e5ea618014802ab01babf1940cecb90c9762d8b9e7d2cc1e1969ec6"},
+ {file = "aiohttp-3.9.3-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:a78ed8a53a1221393d9637c01870248a6f4ea5b214a59a92a36f18151739452c"},
+ {file = "aiohttp-3.9.3-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:6b0e029353361f1746bac2e4cc19b32f972ec03f0f943b390c4ab3371840aabf"},
+ {file = "aiohttp-3.9.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:7cf5c9458e1e90e3c390c2639f1017a0379a99a94fdfad3a1fd966a2874bba52"},
+ {file = "aiohttp-3.9.3-cp310-cp310-win32.whl", hash = "sha256:3e59c23c52765951b69ec45ddbbc9403a8761ee6f57253250c6e1536cacc758b"},
+ {file = "aiohttp-3.9.3-cp310-cp310-win_amd64.whl", hash = "sha256:055ce4f74b82551678291473f66dc9fb9048a50d8324278751926ff0ae7715e5"},
+ {file = "aiohttp-3.9.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6b88f9386ff1ad91ace19d2a1c0225896e28815ee09fc6a8932fded8cda97c3d"},
+ {file = "aiohttp-3.9.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c46956ed82961e31557b6857a5ca153c67e5476972e5f7190015018760938da2"},
+ {file = "aiohttp-3.9.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:07b837ef0d2f252f96009e9b8435ec1fef68ef8b1461933253d318748ec1acdc"},
+ {file = "aiohttp-3.9.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dad46e6f620574b3b4801c68255492e0159d1712271cc99d8bdf35f2043ec266"},
+ {file = "aiohttp-3.9.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5ed3e046ea7b14938112ccd53d91c1539af3e6679b222f9469981e3dac7ba1ce"},
+ {file = "aiohttp-3.9.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:039df344b45ae0b34ac885ab5b53940b174530d4dd8a14ed8b0e2155b9dddccb"},
+ {file = "aiohttp-3.9.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7943c414d3a8d9235f5f15c22ace69787c140c80b718dcd57caaade95f7cd93b"},
+ {file = "aiohttp-3.9.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:84871a243359bb42c12728f04d181a389718710129b36b6aad0fc4655a7647d4"},
+ {file = "aiohttp-3.9.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:5eafe2c065df5401ba06821b9a054d9cb2848867f3c59801b5d07a0be3a380ae"},
+ {file = "aiohttp-3.9.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:9d3c9b50f19704552f23b4eaea1fc082fdd82c63429a6506446cbd8737823da3"},
+ {file = "aiohttp-3.9.3-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:f033d80bc6283092613882dfe40419c6a6a1527e04fc69350e87a9df02bbc283"},
+ {file = "aiohttp-3.9.3-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:2c895a656dd7e061b2fd6bb77d971cc38f2afc277229ce7dd3552de8313a483e"},
+ {file = "aiohttp-3.9.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1f5a71d25cd8106eab05f8704cd9167b6e5187bcdf8f090a66c6d88b634802b4"},
+ {file = "aiohttp-3.9.3-cp311-cp311-win32.whl", hash = "sha256:50fca156d718f8ced687a373f9e140c1bb765ca16e3d6f4fe116e3df7c05b2c5"},
+ {file = "aiohttp-3.9.3-cp311-cp311-win_amd64.whl", hash = "sha256:5fe9ce6c09668063b8447f85d43b8d1c4e5d3d7e92c63173e6180b2ac5d46dd8"},
+ {file = "aiohttp-3.9.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:38a19bc3b686ad55804ae931012f78f7a534cce165d089a2059f658f6c91fa60"},
+ {file = "aiohttp-3.9.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:770d015888c2a598b377bd2f663adfd947d78c0124cfe7b959e1ef39f5b13869"},
+ {file = "aiohttp-3.9.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ee43080e75fc92bf36219926c8e6de497f9b247301bbf88c5c7593d931426679"},
+ {file = "aiohttp-3.9.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:52df73f14ed99cee84865b95a3d9e044f226320a87af208f068ecc33e0c35b96"},
+ {file = "aiohttp-3.9.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dc9b311743a78043b26ffaeeb9715dc360335e5517832f5a8e339f8a43581e4d"},
+ {file = "aiohttp-3.9.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b955ed993491f1a5da7f92e98d5dad3c1e14dc175f74517c4e610b1f2456fb11"},
+ {file = "aiohttp-3.9.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:504b6981675ace64c28bf4a05a508af5cde526e36492c98916127f5a02354d53"},
+ {file = "aiohttp-3.9.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a6fe5571784af92b6bc2fda8d1925cccdf24642d49546d3144948a6a1ed58ca5"},
+ {file = "aiohttp-3.9.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ba39e9c8627edc56544c8628cc180d88605df3892beeb2b94c9bc857774848ca"},
+ {file = "aiohttp-3.9.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:e5e46b578c0e9db71d04c4b506a2121c0cb371dd89af17a0586ff6769d4c58c1"},
+ {file = "aiohttp-3.9.3-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:938a9653e1e0c592053f815f7028e41a3062e902095e5a7dc84617c87267ebd5"},
+ {file = "aiohttp-3.9.3-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:c3452ea726c76e92f3b9fae4b34a151981a9ec0a4847a627c43d71a15ac32aa6"},
+ {file = "aiohttp-3.9.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ff30218887e62209942f91ac1be902cc80cddb86bf00fbc6783b7a43b2bea26f"},
+ {file = "aiohttp-3.9.3-cp312-cp312-win32.whl", hash = "sha256:38f307b41e0bea3294a9a2a87833191e4bcf89bb0365e83a8be3a58b31fb7f38"},
+ {file = "aiohttp-3.9.3-cp312-cp312-win_amd64.whl", hash = "sha256:b791a3143681a520c0a17e26ae7465f1b6f99461a28019d1a2f425236e6eedb5"},
+ {file = "aiohttp-3.9.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:0ed621426d961df79aa3b963ac7af0d40392956ffa9be022024cd16297b30c8c"},
+ {file = "aiohttp-3.9.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7f46acd6a194287b7e41e87957bfe2ad1ad88318d447caf5b090012f2c5bb528"},
+ {file = "aiohttp-3.9.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:feeb18a801aacb098220e2c3eea59a512362eb408d4afd0c242044c33ad6d542"},
+ {file = "aiohttp-3.9.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f734e38fd8666f53da904c52a23ce517f1b07722118d750405af7e4123933511"},
+ {file = "aiohttp-3.9.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b40670ec7e2156d8e57f70aec34a7216407848dfe6c693ef131ddf6e76feb672"},
+ {file = "aiohttp-3.9.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fdd215b7b7fd4a53994f238d0f46b7ba4ac4c0adb12452beee724ddd0743ae5d"},
+ {file = "aiohttp-3.9.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:017a21b0df49039c8f46ca0971b3a7fdc1f56741ab1240cb90ca408049766168"},
+ {file = "aiohttp-3.9.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e99abf0bba688259a496f966211c49a514e65afa9b3073a1fcee08856e04425b"},
+ {file = "aiohttp-3.9.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:648056db9a9fa565d3fa851880f99f45e3f9a771dd3ff3bb0c048ea83fb28194"},
+ {file = "aiohttp-3.9.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:8aacb477dc26797ee089721536a292a664846489c49d3ef9725f992449eda5a8"},
+ {file = "aiohttp-3.9.3-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:522a11c934ea660ff8953eda090dcd2154d367dec1ae3c540aff9f8a5c109ab4"},
+ {file = "aiohttp-3.9.3-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:5bce0dc147ca85caa5d33debc4f4d65e8e8b5c97c7f9f660f215fa74fc49a321"},
+ {file = "aiohttp-3.9.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:4b4af9f25b49a7be47c0972139e59ec0e8285c371049df1a63b6ca81fdd216a2"},
+ {file = "aiohttp-3.9.3-cp38-cp38-win32.whl", hash = "sha256:298abd678033b8571995650ccee753d9458dfa0377be4dba91e4491da3f2be63"},
+ {file = "aiohttp-3.9.3-cp38-cp38-win_amd64.whl", hash = "sha256:69361bfdca5468c0488d7017b9b1e5ce769d40b46a9f4a2eed26b78619e9396c"},
+ {file = "aiohttp-3.9.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:0fa43c32d1643f518491d9d3a730f85f5bbaedcbd7fbcae27435bb8b7a061b29"},
+ {file = "aiohttp-3.9.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:835a55b7ca49468aaaac0b217092dfdff370e6c215c9224c52f30daaa735c1c1"},
+ {file = "aiohttp-3.9.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:06a9b2c8837d9a94fae16c6223acc14b4dfdff216ab9b7202e07a9a09541168f"},
+ {file = "aiohttp-3.9.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:abf151955990d23f84205286938796c55ff11bbfb4ccfada8c9c83ae6b3c89a3"},
+ {file = "aiohttp-3.9.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:59c26c95975f26e662ca78fdf543d4eeaef70e533a672b4113dd888bd2423caa"},
+ {file = "aiohttp-3.9.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f95511dd5d0e05fd9728bac4096319f80615aaef4acbecb35a990afebe953b0e"},
+ {file = "aiohttp-3.9.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:595f105710293e76b9dc09f52e0dd896bd064a79346234b521f6b968ffdd8e58"},
+ {file = "aiohttp-3.9.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c7c8b816c2b5af5c8a436df44ca08258fc1a13b449393a91484225fcb7545533"},
+ {file = "aiohttp-3.9.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f1088fa100bf46e7b398ffd9904f4808a0612e1d966b4aa43baa535d1b6341eb"},
+ {file = "aiohttp-3.9.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f59dfe57bb1ec82ac0698ebfcdb7bcd0e99c255bd637ff613760d5f33e7c81b3"},
+ {file = "aiohttp-3.9.3-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:361a1026c9dd4aba0109e4040e2aecf9884f5cfe1b1b1bd3d09419c205e2e53d"},
+ {file = "aiohttp-3.9.3-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:363afe77cfcbe3a36353d8ea133e904b108feea505aa4792dad6585a8192c55a"},
+ {file = "aiohttp-3.9.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8e2c45c208c62e955e8256949eb225bd8b66a4c9b6865729a786f2aa79b72e9d"},
+ {file = "aiohttp-3.9.3-cp39-cp39-win32.whl", hash = "sha256:f7217af2e14da0856e082e96ff637f14ae45c10a5714b63c77f26d8884cf1051"},
+ {file = "aiohttp-3.9.3-cp39-cp39-win_amd64.whl", hash = "sha256:27468897f628c627230dba07ec65dc8d0db566923c48f29e084ce382119802bc"},
+ {file = "aiohttp-3.9.3.tar.gz", hash = "sha256:90842933e5d1ff760fae6caca4b2b3edba53ba8f4b71e95dacf2818a2aca06f7"},
]
[package.dependencies]
aiosignal = ">=1.1.2"
-async-timeout = ">=4.0.0a3,<5.0"
+async-timeout = {version = ">=4.0,<5.0", markers = "python_version < \"3.11\""}
attrs = ">=17.3.0"
-charset-normalizer = ">=2.0,<4.0"
frozenlist = ">=1.1.1"
multidict = ">=4.5,<7.0"
yarl = ">=1.0,<2.0"
[package.extras]
-speedups = ["Brotli", "aiodns", "cchardet"]
+speedups = ["Brotli", "aiodns", "brotlicffi"]
[[package]]
name = "aiosignal"
@@ -163,24 +153,25 @@ files = [
[[package]]
name = "anyio"
-version = "4.0.0"
+version = "4.3.0"
description = "High level compatibility layer for multiple asynchronous event loop implementations"
optional = false
python-versions = ">=3.8"
files = [
- {file = "anyio-4.0.0-py3-none-any.whl", hash = "sha256:cfdb2b588b9fc25ede96d8db56ed50848b0b649dca3dd1df0b11f683bb9e0b5f"},
- {file = "anyio-4.0.0.tar.gz", hash = "sha256:f7ed51751b2c2add651e5747c891b47e26d2a21be5d32d9311dfe9692f3e5d7a"},
+ {file = "anyio-4.3.0-py3-none-any.whl", hash = "sha256:048e05d0f6caeed70d731f3db756d35dcc1f35747c8c403364a8332c630441b8"},
+ {file = "anyio-4.3.0.tar.gz", hash = "sha256:f75253795a87df48568485fd18cdd2a3fa5c4f7c5be8e5e36637733fce06fed6"},
]
[package.dependencies]
exceptiongroup = {version = ">=1.0.2", markers = "python_version < \"3.11\""}
idna = ">=2.8"
sniffio = ">=1.1"
+typing-extensions = {version = ">=4.1", markers = "python_version < \"3.11\""}
[package.extras]
-doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)"]
-test = ["anyio[trio]", "coverage[toml] (>=7)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"]
-trio = ["trio (>=0.22)"]
+doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"]
+test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"]
+trio = ["trio (>=0.23)"]
[[package]]
name = "appdirs"
@@ -195,13 +186,13 @@ files = [
[[package]]
name = "appnope"
-version = "0.1.3"
+version = "0.1.4"
description = "Disable App Nap on macOS >= 10.9"
optional = false
-python-versions = "*"
+python-versions = ">=3.6"
files = [
- {file = "appnope-0.1.3-py2.py3-none-any.whl", hash = "sha256:265a455292d0bd8a72453494fa24df5a11eb18373a60c7c0430889f22548605e"},
- {file = "appnope-0.1.3.tar.gz", hash = "sha256:02bd91c4de869fbb1e1c50aafc4098827a7a54ab2f39d9dcba6c9547ed920e24"},
+ {file = "appnope-0.1.4-py2.py3-none-any.whl", hash = "sha256:502575ee11cd7a28c0205f379b525beefebab9d161b7c964670864014ed7213c"},
+ {file = "appnope-0.1.4.tar.gz", hash = "sha256:1de3860566df9caf38f01f86f65e0e13e379af54f9e4bee1e66b48f2efffd1ee"},
]
[[package]]
@@ -325,36 +316,36 @@ files = [
[[package]]
name = "attrs"
-version = "23.1.0"
+version = "23.2.0"
description = "Classes Without Boilerplate"
optional = false
python-versions = ">=3.7"
files = [
- {file = "attrs-23.1.0-py3-none-any.whl", hash = "sha256:1f28b4522cdc2fb4256ac1a020c78acf9cba2c6b461ccd2c126f3aa8e8335d04"},
- {file = "attrs-23.1.0.tar.gz", hash = "sha256:6279836d581513a26f1bf235f9acd333bc9115683f14f7e8fae46c98fc50e015"},
+ {file = "attrs-23.2.0-py3-none-any.whl", hash = "sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1"},
+ {file = "attrs-23.2.0.tar.gz", hash = "sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30"},
]
[package.extras]
cov = ["attrs[tests]", "coverage[toml] (>=5.3)"]
-dev = ["attrs[docs,tests]", "pre-commit"]
+dev = ["attrs[tests]", "pre-commit"]
docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"]
tests = ["attrs[tests-no-zope]", "zope-interface"]
-tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
+tests-mypy = ["mypy (>=1.6)", "pytest-mypy-plugins"]
+tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"]
[[package]]
name = "babel"
-version = "2.13.1"
+version = "2.14.0"
description = "Internationalization utilities"
optional = false
python-versions = ">=3.7"
files = [
- {file = "Babel-2.13.1-py3-none-any.whl", hash = "sha256:7077a4984b02b6727ac10f1f7294484f737443d7e2e66c5e4380e41a3ae0b4ed"},
- {file = "Babel-2.13.1.tar.gz", hash = "sha256:33e0952d7dd6374af8dbf6768cc4ddf3ccfefc244f9986d4074704f2fbd18900"},
+ {file = "Babel-2.14.0-py3-none-any.whl", hash = "sha256:efb1a25b7118e67ce3a259bed20545c29cb68be8ad2c784c83689981b7a57287"},
+ {file = "Babel-2.14.0.tar.gz", hash = "sha256:6919867db036398ba21eb5c7a0f6b28ab8cbc3ae7a73a44ebe34ae74a4e7d363"},
]
[package.dependencies]
pytz = {version = ">=2015.7", markers = "python_version < \"3.9\""}
-setuptools = {version = "*", markers = "python_version >= \"3.12\""}
[package.extras]
dev = ["freezegun (>=1.0,<2.0)", "pytest (>=6.0)", "pytest-cov"]
@@ -390,47 +381,65 @@ test-tox-coverage = ["coverage (>=5.5)"]
[[package]]
name = "beautifulsoup4"
-version = "4.12.2"
+version = "4.12.3"
description = "Screen-scraping library"
optional = false
python-versions = ">=3.6.0"
files = [
- {file = "beautifulsoup4-4.12.2-py3-none-any.whl", hash = "sha256:bd2520ca0d9d7d12694a53d44ac482d181b4ec1888909b035a3dbf40d0f57d4a"},
- {file = "beautifulsoup4-4.12.2.tar.gz", hash = "sha256:492bbc69dca35d12daac71c4db1bfff0c876c00ef4a2ffacce226d4638eb72da"},
+ {file = "beautifulsoup4-4.12.3-py3-none-any.whl", hash = "sha256:b80878c9f40111313e55da8ba20bdba06d8fa3969fc68304167741bbf9e082ed"},
+ {file = "beautifulsoup4-4.12.3.tar.gz", hash = "sha256:74e3d1928edc070d21748185c46e3fb33490f22f52a3addee9aee0f4f7781051"},
]
[package.dependencies]
soupsieve = ">1.2"
[package.extras]
+cchardet = ["cchardet"]
+chardet = ["chardet"]
+charset-normalizer = ["charset-normalizer"]
html5lib = ["html5lib"]
lxml = ["lxml"]
+[[package]]
+name = "better-abc"
+version = "0.0.3"
+description = "Python ABC plus abstract attributes"
+optional = false
+python-versions = "*"
+files = [
+ {file = "better-abc-0.0.3.tar.gz", hash = "sha256:a880fd6bc9675da2ec991e8712a555bffa0f12722efed78c739f78343cf989f6"},
+ {file = "better_abc-0.0.3-py3-none-any.whl", hash = "sha256:3ae73b473fbeb536a548f542984976e80b821676ae6e18f14e24d8e180647187"},
+]
+
[[package]]
name = "black"
-version = "23.10.1"
+version = "23.12.1"
description = "The uncompromising code formatter."
optional = false
python-versions = ">=3.8"
files = [
- {file = "black-23.10.1-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:ec3f8e6234c4e46ff9e16d9ae96f4ef69fa328bb4ad08198c8cee45bb1f08c69"},
- {file = "black-23.10.1-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:1b917a2aa020ca600483a7b340c165970b26e9029067f019e3755b56e8dd5916"},
- {file = "black-23.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c74de4c77b849e6359c6f01987e94873c707098322b91490d24296f66d067dc"},
- {file = "black-23.10.1-cp310-cp310-win_amd64.whl", hash = "sha256:7b4d10b0f016616a0d93d24a448100adf1699712fb7a4efd0e2c32bbb219b173"},
- {file = "black-23.10.1-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:b15b75fc53a2fbcac8a87d3e20f69874d161beef13954747e053bca7a1ce53a0"},
- {file = "black-23.10.1-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:e293e4c2f4a992b980032bbd62df07c1bcff82d6964d6c9496f2cd726e246ace"},
- {file = "black-23.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d56124b7a61d092cb52cce34182a5280e160e6aff3137172a68c2c2c4b76bcb"},
- {file = "black-23.10.1-cp311-cp311-win_amd64.whl", hash = "sha256:3f157a8945a7b2d424da3335f7ace89c14a3b0625e6593d21139c2d8214d55ce"},
- {file = "black-23.10.1-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:cfcce6f0a384d0da692119f2d72d79ed07c7159879d0bb1bb32d2e443382bf3a"},
- {file = "black-23.10.1-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:33d40f5b06be80c1bbce17b173cda17994fbad096ce60eb22054da021bf933d1"},
- {file = "black-23.10.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:840015166dbdfbc47992871325799fd2dc0dcf9395e401ada6d88fe11498abad"},
- {file = "black-23.10.1-cp38-cp38-win_amd64.whl", hash = "sha256:037e9b4664cafda5f025a1728c50a9e9aedb99a759c89f760bd83730e76ba884"},
- {file = "black-23.10.1-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:7cb5936e686e782fddb1c73f8aa6f459e1ad38a6a7b0e54b403f1f05a1507ee9"},
- {file = "black-23.10.1-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:7670242e90dc129c539e9ca17665e39a146a761e681805c54fbd86015c7c84f7"},
- {file = "black-23.10.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ed45ac9a613fb52dad3b61c8dea2ec9510bf3108d4db88422bacc7d1ba1243d"},
- {file = "black-23.10.1-cp39-cp39-win_amd64.whl", hash = "sha256:6d23d7822140e3fef190734216cefb262521789367fbdc0b3f22af6744058982"},
- {file = "black-23.10.1-py3-none-any.whl", hash = "sha256:d431e6739f727bb2e0495df64a6c7a5310758e87505f5f8cde9ff6c0f2d7e4fe"},
- {file = "black-23.10.1.tar.gz", hash = "sha256:1f8ce316753428ff68749c65a5f7844631aa18c8679dfd3ca9dc1a289979c258"},
+ {file = "black-23.12.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0aaf6041986767a5e0ce663c7a2f0e9eaf21e6ff87a5f95cbf3675bfd4c41d2"},
+ {file = "black-23.12.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c88b3711d12905b74206227109272673edce0cb29f27e1385f33b0163c414bba"},
+ {file = "black-23.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a920b569dc6b3472513ba6ddea21f440d4b4c699494d2e972a1753cdc25df7b0"},
+ {file = "black-23.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:3fa4be75ef2a6b96ea8d92b1587dd8cb3a35c7e3d51f0738ced0781c3aa3a5a3"},
+ {file = "black-23.12.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8d4df77958a622f9b5a4c96edb4b8c0034f8434032ab11077ec6c56ae9f384ba"},
+ {file = "black-23.12.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:602cfb1196dc692424c70b6507593a2b29aac0547c1be9a1d1365f0d964c353b"},
+ {file = "black-23.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c4352800f14be5b4864016882cdba10755bd50805c95f728011bcb47a4afd59"},
+ {file = "black-23.12.1-cp311-cp311-win_amd64.whl", hash = "sha256:0808494f2b2df923ffc5723ed3c7b096bd76341f6213989759287611e9837d50"},
+ {file = "black-23.12.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:25e57fd232a6d6ff3f4478a6fd0580838e47c93c83eaf1ccc92d4faf27112c4e"},
+ {file = "black-23.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2d9e13db441c509a3763a7a3d9a49ccc1b4e974a47be4e08ade2a228876500ec"},
+ {file = "black-23.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1bd9c210f8b109b1762ec9fd36592fdd528485aadb3f5849b2740ef17e674e"},
+ {file = "black-23.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:ae76c22bde5cbb6bfd211ec343ded2163bba7883c7bc77f6b756a1049436fbb9"},
+ {file = "black-23.12.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1fa88a0f74e50e4487477bc0bb900c6781dbddfdfa32691e780bf854c3b4a47f"},
+ {file = "black-23.12.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a4d6a9668e45ad99d2f8ec70d5c8c04ef4f32f648ef39048d010b0689832ec6d"},
+ {file = "black-23.12.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b18fb2ae6c4bb63eebe5be6bd869ba2f14fd0259bda7d18a46b764d8fb86298a"},
+ {file = "black-23.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:c04b6d9d20e9c13f43eee8ea87d44156b8505ca8a3c878773f68b4e4812a421e"},
+ {file = "black-23.12.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3e1b38b3135fd4c025c28c55ddfc236b05af657828a8a6abe5deec419a0b7055"},
+ {file = "black-23.12.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4f0031eaa7b921db76decd73636ef3a12c942ed367d8c3841a0739412b260a54"},
+ {file = "black-23.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97e56155c6b737854e60a9ab1c598ff2533d57e7506d97af5481141671abf3ea"},
+ {file = "black-23.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:dd15245c8b68fe2b6bd0f32c1556509d11bb33aec9b5d0866dd8e2ed3dba09c2"},
+ {file = "black-23.12.1-py3-none-any.whl", hash = "sha256:78baad24af0f033958cad29731e27363183e140962595def56423e626f4bee3e"},
+ {file = "black-23.12.1.tar.gz", hash = "sha256:4ce3ef14ebe8d9509188014d96af1c456a910d5b5cbf434a09fef7e024b3d0d5"},
]
[package.dependencies]
@@ -444,7 +453,7 @@ typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""}
[package.extras]
colorama = ["colorama (>=0.4.3)"]
-d = ["aiohttp (>=3.7.4)"]
+d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"]
jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
uvloop = ["uvloop (>=0.15.2)"]
@@ -468,13 +477,13 @@ css = ["tinycss2 (>=1.1.0,<1.3)"]
[[package]]
name = "certifi"
-version = "2023.7.22"
+version = "2024.2.2"
description = "Python package for providing Mozilla's CA Bundle."
optional = false
python-versions = ">=3.6"
files = [
- {file = "certifi-2023.7.22-py3-none-any.whl", hash = "sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9"},
- {file = "certifi-2023.7.22.tar.gz", hash = "sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082"},
+ {file = "certifi-2024.2.2-py3-none-any.whl", hash = "sha256:dc383c07b76109f368f6106eee2b593b04a011ea4d55f652c6ca24a754d1cdd1"},
+ {file = "certifi-2024.2.2.tar.gz", hash = "sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f"},
]
[[package]]
@@ -543,112 +552,112 @@ pycparser = "*"
[[package]]
name = "charset-normalizer"
-version = "3.3.1"
+version = "3.3.2"
description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
optional = false
python-versions = ">=3.7.0"
files = [
- {file = "charset-normalizer-3.3.1.tar.gz", hash = "sha256:d9137a876020661972ca6eec0766d81aef8a5627df628b664b234b73396e727e"},
- {file = "charset_normalizer-3.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:8aee051c89e13565c6bd366813c386939f8e928af93c29fda4af86d25b73d8f8"},
- {file = "charset_normalizer-3.3.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:352a88c3df0d1fa886562384b86f9a9e27563d4704ee0e9d56ec6fcd270ea690"},
- {file = "charset_normalizer-3.3.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:223b4d54561c01048f657fa6ce41461d5ad8ff128b9678cfe8b2ecd951e3f8a2"},
- {file = "charset_normalizer-3.3.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f861d94c2a450b974b86093c6c027888627b8082f1299dfd5a4bae8e2292821"},
- {file = "charset_normalizer-3.3.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1171ef1fc5ab4693c5d151ae0fdad7f7349920eabbaca6271f95969fa0756c2d"},
- {file = "charset_normalizer-3.3.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28f512b9a33235545fbbdac6a330a510b63be278a50071a336afc1b78781b147"},
- {file = "charset_normalizer-3.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0e842112fe3f1a4ffcf64b06dc4c61a88441c2f02f373367f7b4c1aa9be2ad5"},
- {file = "charset_normalizer-3.3.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3f9bc2ce123637a60ebe819f9fccc614da1bcc05798bbbaf2dd4ec91f3e08846"},
- {file = "charset_normalizer-3.3.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:f194cce575e59ffe442c10a360182a986535fd90b57f7debfaa5c845c409ecc3"},
- {file = "charset_normalizer-3.3.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:9a74041ba0bfa9bc9b9bb2cd3238a6ab3b7618e759b41bd15b5f6ad958d17605"},
- {file = "charset_normalizer-3.3.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:b578cbe580e3b41ad17b1c428f382c814b32a6ce90f2d8e39e2e635d49e498d1"},
- {file = "charset_normalizer-3.3.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:6db3cfb9b4fcecb4390db154e75b49578c87a3b9979b40cdf90d7e4b945656e1"},
- {file = "charset_normalizer-3.3.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:debb633f3f7856f95ad957d9b9c781f8e2c6303ef21724ec94bea2ce2fcbd056"},
- {file = "charset_normalizer-3.3.1-cp310-cp310-win32.whl", hash = "sha256:87071618d3d8ec8b186d53cb6e66955ef2a0e4fa63ccd3709c0c90ac5a43520f"},
- {file = "charset_normalizer-3.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:e372d7dfd154009142631de2d316adad3cc1c36c32a38b16a4751ba78da2a397"},
- {file = "charset_normalizer-3.3.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ae4070f741f8d809075ef697877fd350ecf0b7c5837ed68738607ee0a2c572cf"},
- {file = "charset_normalizer-3.3.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:58e875eb7016fd014c0eea46c6fa92b87b62c0cb31b9feae25cbbe62c919f54d"},
- {file = "charset_normalizer-3.3.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:dbd95e300367aa0827496fe75a1766d198d34385a58f97683fe6e07f89ca3e3c"},
- {file = "charset_normalizer-3.3.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de0b4caa1c8a21394e8ce971997614a17648f94e1cd0640fbd6b4d14cab13a72"},
- {file = "charset_normalizer-3.3.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:985c7965f62f6f32bf432e2681173db41336a9c2611693247069288bcb0c7f8b"},
- {file = "charset_normalizer-3.3.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a15c1fe6d26e83fd2e5972425a772cca158eae58b05d4a25a4e474c221053e2d"},
- {file = "charset_normalizer-3.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ae55d592b02c4349525b6ed8f74c692509e5adffa842e582c0f861751701a673"},
- {file = "charset_normalizer-3.3.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:be4d9c2770044a59715eb57c1144dedea7c5d5ae80c68fb9959515037cde2008"},
- {file = "charset_normalizer-3.3.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:851cf693fb3aaef71031237cd68699dded198657ec1e76a76eb8be58c03a5d1f"},
- {file = "charset_normalizer-3.3.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:31bbaba7218904d2eabecf4feec0d07469284e952a27400f23b6628439439fa7"},
- {file = "charset_normalizer-3.3.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:871d045d6ccc181fd863a3cd66ee8e395523ebfbc57f85f91f035f50cee8e3d4"},
- {file = "charset_normalizer-3.3.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:501adc5eb6cd5f40a6f77fbd90e5ab915c8fd6e8c614af2db5561e16c600d6f3"},
- {file = "charset_normalizer-3.3.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f5fb672c396d826ca16a022ac04c9dce74e00a1c344f6ad1a0fdc1ba1f332213"},
- {file = "charset_normalizer-3.3.1-cp311-cp311-win32.whl", hash = "sha256:bb06098d019766ca16fc915ecaa455c1f1cd594204e7f840cd6258237b5079a8"},
- {file = "charset_normalizer-3.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:8af5a8917b8af42295e86b64903156b4f110a30dca5f3b5aedea123fbd638bff"},
- {file = "charset_normalizer-3.3.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:7ae8e5142dcc7a49168f4055255dbcced01dc1714a90a21f87448dc8d90617d1"},
- {file = "charset_normalizer-3.3.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5b70bab78accbc672f50e878a5b73ca692f45f5b5e25c8066d748c09405e6a55"},
- {file = "charset_normalizer-3.3.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5ceca5876032362ae73b83347be8b5dbd2d1faf3358deb38c9c88776779b2e2f"},
- {file = "charset_normalizer-3.3.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:34d95638ff3613849f473afc33f65c401a89f3b9528d0d213c7037c398a51296"},
- {file = "charset_normalizer-3.3.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9edbe6a5bf8b56a4a84533ba2b2f489d0046e755c29616ef8830f9e7d9cf5728"},
- {file = "charset_normalizer-3.3.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f6a02a3c7950cafaadcd46a226ad9e12fc9744652cc69f9e5534f98b47f3bbcf"},
- {file = "charset_normalizer-3.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10b8dd31e10f32410751b3430996f9807fc4d1587ca69772e2aa940a82ab571a"},
- {file = "charset_normalizer-3.3.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:edc0202099ea1d82844316604e17d2b175044f9bcb6b398aab781eba957224bd"},
- {file = "charset_normalizer-3.3.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b891a2f68e09c5ef989007fac11476ed33c5c9994449a4e2c3386529d703dc8b"},
- {file = "charset_normalizer-3.3.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:71ef3b9be10070360f289aea4838c784f8b851be3ba58cf796262b57775c2f14"},
- {file = "charset_normalizer-3.3.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:55602981b2dbf8184c098bc10287e8c245e351cd4fdcad050bd7199d5a8bf514"},
- {file = "charset_normalizer-3.3.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:46fb9970aa5eeca547d7aa0de5d4b124a288b42eaefac677bde805013c95725c"},
- {file = "charset_normalizer-3.3.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:520b7a142d2524f999447b3a0cf95115df81c4f33003c51a6ab637cbda9d0bf4"},
- {file = "charset_normalizer-3.3.1-cp312-cp312-win32.whl", hash = "sha256:8ec8ef42c6cd5856a7613dcd1eaf21e5573b2185263d87d27c8edcae33b62a61"},
- {file = "charset_normalizer-3.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:baec8148d6b8bd5cee1ae138ba658c71f5b03e0d69d5907703e3e1df96db5e41"},
- {file = "charset_normalizer-3.3.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:63a6f59e2d01310f754c270e4a257426fe5a591dc487f1983b3bbe793cf6bac6"},
- {file = "charset_normalizer-3.3.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d6bfc32a68bc0933819cfdfe45f9abc3cae3877e1d90aac7259d57e6e0f85b1"},
- {file = "charset_normalizer-3.3.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4f3100d86dcd03c03f7e9c3fdb23d92e32abbca07e7c13ebd7ddfbcb06f5991f"},
- {file = "charset_normalizer-3.3.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:39b70a6f88eebe239fa775190796d55a33cfb6d36b9ffdd37843f7c4c1b5dc67"},
- {file = "charset_normalizer-3.3.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e12f8ee80aa35e746230a2af83e81bd6b52daa92a8afaef4fea4a2ce9b9f4fa"},
- {file = "charset_normalizer-3.3.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7b6cefa579e1237ce198619b76eaa148b71894fb0d6bcf9024460f9bf30fd228"},
- {file = "charset_normalizer-3.3.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:61f1e3fb621f5420523abb71f5771a204b33c21d31e7d9d86881b2cffe92c47c"},
- {file = "charset_normalizer-3.3.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:4f6e2a839f83a6a76854d12dbebde50e4b1afa63e27761549d006fa53e9aa80e"},
- {file = "charset_normalizer-3.3.1-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:1ec937546cad86d0dce5396748bf392bb7b62a9eeb8c66efac60e947697f0e58"},
- {file = "charset_normalizer-3.3.1-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:82ca51ff0fc5b641a2d4e1cc8c5ff108699b7a56d7f3ad6f6da9dbb6f0145b48"},
- {file = "charset_normalizer-3.3.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:633968254f8d421e70f91c6ebe71ed0ab140220469cf87a9857e21c16687c034"},
- {file = "charset_normalizer-3.3.1-cp37-cp37m-win32.whl", hash = "sha256:c0c72d34e7de5604df0fde3644cc079feee5e55464967d10b24b1de268deceb9"},
- {file = "charset_normalizer-3.3.1-cp37-cp37m-win_amd64.whl", hash = "sha256:63accd11149c0f9a99e3bc095bbdb5a464862d77a7e309ad5938fbc8721235ae"},
- {file = "charset_normalizer-3.3.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5a3580a4fdc4ac05f9e53c57f965e3594b2f99796231380adb2baaab96e22761"},
- {file = "charset_normalizer-3.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2465aa50c9299d615d757c1c888bc6fef384b7c4aec81c05a0172b4400f98557"},
- {file = "charset_normalizer-3.3.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cb7cd68814308aade9d0c93c5bd2ade9f9441666f8ba5aa9c2d4b389cb5e2a45"},
- {file = "charset_normalizer-3.3.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91e43805ccafa0a91831f9cd5443aa34528c0c3f2cc48c4cb3d9a7721053874b"},
- {file = "charset_normalizer-3.3.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:854cc74367180beb327ab9d00f964f6d91da06450b0855cbbb09187bcdb02de5"},
- {file = "charset_normalizer-3.3.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c15070ebf11b8b7fd1bfff7217e9324963c82dbdf6182ff7050519e350e7ad9f"},
- {file = "charset_normalizer-3.3.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c4c99f98fc3a1835af8179dcc9013f93594d0670e2fa80c83aa36346ee763d2"},
- {file = "charset_normalizer-3.3.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3fb765362688821404ad6cf86772fc54993ec11577cd5a92ac44b4c2ba52155b"},
- {file = "charset_normalizer-3.3.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:dced27917823df984fe0c80a5c4ad75cf58df0fbfae890bc08004cd3888922a2"},
- {file = "charset_normalizer-3.3.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a66bcdf19c1a523e41b8e9d53d0cedbfbac2e93c649a2e9502cb26c014d0980c"},
- {file = "charset_normalizer-3.3.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:ecd26be9f112c4f96718290c10f4caea6cc798459a3a76636b817a0ed7874e42"},
- {file = "charset_normalizer-3.3.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:3f70fd716855cd3b855316b226a1ac8bdb3caf4f7ea96edcccc6f484217c9597"},
- {file = "charset_normalizer-3.3.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:17a866d61259c7de1bdadef418a37755050ddb4b922df8b356503234fff7932c"},
- {file = "charset_normalizer-3.3.1-cp38-cp38-win32.whl", hash = "sha256:548eefad783ed787b38cb6f9a574bd8664468cc76d1538215d510a3cd41406cb"},
- {file = "charset_normalizer-3.3.1-cp38-cp38-win_amd64.whl", hash = "sha256:45f053a0ece92c734d874861ffe6e3cc92150e32136dd59ab1fb070575189c97"},
- {file = "charset_normalizer-3.3.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:bc791ec3fd0c4309a753f95bb6c749ef0d8ea3aea91f07ee1cf06b7b02118f2f"},
- {file = "charset_normalizer-3.3.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0c8c61fb505c7dad1d251c284e712d4e0372cef3b067f7ddf82a7fa82e1e9a93"},
- {file = "charset_normalizer-3.3.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2c092be3885a1b7899cd85ce24acedc1034199d6fca1483fa2c3a35c86e43041"},
- {file = "charset_normalizer-3.3.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c2000c54c395d9e5e44c99dc7c20a64dc371f777faf8bae4919ad3e99ce5253e"},
- {file = "charset_normalizer-3.3.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4cb50a0335382aac15c31b61d8531bc9bb657cfd848b1d7158009472189f3d62"},
- {file = "charset_normalizer-3.3.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c30187840d36d0ba2893bc3271a36a517a717f9fd383a98e2697ee890a37c273"},
- {file = "charset_normalizer-3.3.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe81b35c33772e56f4b6cf62cf4aedc1762ef7162a31e6ac7fe5e40d0149eb67"},
- {file = "charset_normalizer-3.3.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d0bf89afcbcf4d1bb2652f6580e5e55a840fdf87384f6063c4a4f0c95e378656"},
- {file = "charset_normalizer-3.3.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:06cf46bdff72f58645434d467bf5228080801298fbba19fe268a01b4534467f5"},
- {file = "charset_normalizer-3.3.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:3c66df3f41abee950d6638adc7eac4730a306b022570f71dd0bd6ba53503ab57"},
- {file = "charset_normalizer-3.3.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:cd805513198304026bd379d1d516afbf6c3c13f4382134a2c526b8b854da1c2e"},
- {file = "charset_normalizer-3.3.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:9505dc359edb6a330efcd2be825fdb73ee3e628d9010597aa1aee5aa63442e97"},
- {file = "charset_normalizer-3.3.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:31445f38053476a0c4e6d12b047b08ced81e2c7c712e5a1ad97bc913256f91b2"},
- {file = "charset_normalizer-3.3.1-cp39-cp39-win32.whl", hash = "sha256:bd28b31730f0e982ace8663d108e01199098432a30a4c410d06fe08fdb9e93f4"},
- {file = "charset_normalizer-3.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:555fe186da0068d3354cdf4bbcbc609b0ecae4d04c921cc13e209eece7720727"},
- {file = "charset_normalizer-3.3.1-py3-none-any.whl", hash = "sha256:800561453acdecedaac137bf09cd719c7a440b6800ec182f077bb8e7025fb708"},
+ {file = "charset-normalizer-3.3.2.tar.gz", hash = "sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-win32.whl", hash = "sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-win32.whl", hash = "sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-win32.whl", hash = "sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001"},
+ {file = "charset_normalizer-3.3.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:95f2a5796329323b8f0512e09dbb7a1860c46a39da62ecb2324f116fa8fdc85c"},
+ {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c002b4ffc0be611f0d9da932eb0f704fe2602a9a949d1f738e4c34c75b0863d5"},
+ {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a981a536974bbc7a512cf44ed14938cf01030a99e9b3a06dd59578882f06f985"},
+ {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3287761bc4ee9e33561a7e058c72ac0938c4f57fe49a09eae428fd88aafe7bb6"},
+ {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42cb296636fcc8b0644486d15c12376cb9fa75443e00fb25de0b8602e64c1714"},
+ {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a55554a2fa0d408816b3b5cedf0045f4b8e1a6065aec45849de2d6f3f8e9786"},
+ {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:c083af607d2515612056a31f0a8d9e0fcb5876b7bfc0abad3ecd275bc4ebc2d5"},
+ {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:87d1351268731db79e0f8e745d92493ee2841c974128ef629dc518b937d9194c"},
+ {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:bd8f7df7d12c2db9fab40bdd87a7c09b1530128315d047a086fa3ae3435cb3a8"},
+ {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:c180f51afb394e165eafe4ac2936a14bee3eb10debc9d9e4db8958fe36afe711"},
+ {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8c622a5fe39a48f78944a87d4fb8a53ee07344641b0562c540d840748571b811"},
+ {file = "charset_normalizer-3.3.2-cp37-cp37m-win32.whl", hash = "sha256:db364eca23f876da6f9e16c9da0df51aa4f104a972735574842618b8c6d999d4"},
+ {file = "charset_normalizer-3.3.2-cp37-cp37m-win_amd64.whl", hash = "sha256:86216b5cee4b06df986d214f664305142d9c76df9b6512be2738aa72a2048f99"},
+ {file = "charset_normalizer-3.3.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:6463effa3186ea09411d50efc7d85360b38d5f09b870c48e4600f63af490e56a"},
+ {file = "charset_normalizer-3.3.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6c4caeef8fa63d06bd437cd4bdcf3ffefe6738fb1b25951440d80dc7df8c03ac"},
+ {file = "charset_normalizer-3.3.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:37e55c8e51c236f95b033f6fb391d7d7970ba5fe7ff453dad675e88cf303377a"},
+ {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb69256e180cb6c8a894fee62b3afebae785babc1ee98b81cdf68bbca1987f33"},
+ {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ae5f4161f18c61806f411a13b0310bea87f987c7d2ecdbdaad0e94eb2e404238"},
+ {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b2b0a0c0517616b6869869f8c581d4eb2dd83a4d79e0ebcb7d373ef9956aeb0a"},
+ {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:45485e01ff4d3630ec0d9617310448a8702f70e9c01906b0d0118bdf9d124cf2"},
+ {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eb00ed941194665c332bf8e078baf037d6c35d7c4f3102ea2d4f16ca94a26dc8"},
+ {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:2127566c664442652f024c837091890cb1942c30937add288223dc895793f898"},
+ {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a50aebfa173e157099939b17f18600f72f84eed3049e743b68ad15bd69b6bf99"},
+ {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:4d0d1650369165a14e14e1e47b372cfcb31d6ab44e6e33cb2d4e57265290044d"},
+ {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:923c0c831b7cfcb071580d3f46c4baf50f174be571576556269530f4bbd79d04"},
+ {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:06a81e93cd441c56a9b65d8e1d043daeb97a3d0856d177d5c90ba85acb3db087"},
+ {file = "charset_normalizer-3.3.2-cp38-cp38-win32.whl", hash = "sha256:6ef1d82a3af9d3eecdba2321dc1b3c238245d890843e040e41e470ffa64c3e25"},
+ {file = "charset_normalizer-3.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:eb8821e09e916165e160797a6c17edda0679379a4be5c716c260e836e122f54b"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c235ebd9baae02f1b77bcea61bce332cb4331dc3617d254df3323aa01ab47bd4"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5b4c145409bef602a690e7cfad0a15a55c13320ff7a3ad7ca59c13bb8ba4d45d"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:68d1f8a9e9e37c1223b656399be5d6b448dea850bed7d0f87a8311f1ff3dabb0"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:22afcb9f253dac0696b5a4be4a1c0f8762f8239e21b99680099abd9b2b1b2269"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e27ad930a842b4c5eb8ac0016b0a54f5aebbe679340c26101df33424142c143c"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1f79682fbe303db92bc2b1136016a38a42e835d932bab5b3b1bfcfbf0640e519"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:122c7fa62b130ed55f8f285bfd56d5f4b4a5b503609d181f9ad85e55c89f4185"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d0eccceffcb53201b5bfebb52600a5fb483a20b61da9dbc885f8b103cbe7598c"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:9f96df6923e21816da7e0ad3fd47dd8f94b2a5ce594e00677c0013018b813458"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:7f04c839ed0b6b98b1a7501a002144b76c18fb1c1850c8b98d458ac269e26ed2"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:34d1c8da1e78d2e001f363791c98a272bb734000fcef47a491c1e3b0505657a8"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ff8fa367d09b717b2a17a052544193ad76cd49979c805768879cb63d9ca50561"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-win32.whl", hash = "sha256:aed38f6e4fb3f5d6bf81bfa990a07806be9d83cf7bacef998ab1a9bd660a581f"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:b01b88d45a6fcb69667cd6d2f7a9aeb4bf53760d7fc536bf679ec94fe9f3ff3d"},
+ {file = "charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc"},
]
[[package]]
name = "circuitsvis"
-version = "1.43.1"
+version = "1.43.2"
description = "Mechanistic Interpretability Visualizations"
optional = false
python-versions = ">=3.8"
files = [
- {file = "circuitsvis-1.43.1-py3-none-any.whl", hash = "sha256:096138020986f79f1493c0ad8e94107a19b5d19cd5771b4401e231e993267019"},
- {file = "circuitsvis-1.43.1.tar.gz", hash = "sha256:5d730b9ee4c256cdf9c9da598343e7c8cd1eceeacfa8385969c008fa7123d6bc"},
+ {file = "circuitsvis-1.43.2-py3-none-any.whl", hash = "sha256:1128fde5de8b738dd3c932d0b0ec4ee5556387b4405592fdf37f617e647183fb"},
+ {file = "circuitsvis-1.43.2.tar.gz", hash = "sha256:388c1a6ea1bcf308da51fa6f67be761483ba361321d2e111f4c28faaea458287"},
]
[package.dependencies]
@@ -699,82 +708,80 @@ files = [
[[package]]
name = "comm"
-version = "0.1.4"
+version = "0.2.2"
description = "Jupyter Python Comm implementation, for usage in ipykernel, xeus-python etc."
optional = false
-python-versions = ">=3.6"
+python-versions = ">=3.8"
files = [
- {file = "comm-0.1.4-py3-none-any.whl", hash = "sha256:6d52794cba11b36ed9860999cd10fd02d6b2eac177068fdd585e1e2f8a96e67a"},
- {file = "comm-0.1.4.tar.gz", hash = "sha256:354e40a59c9dd6db50c5cc6b4acc887d82e9603787f83b68c01a80a923984d15"},
+ {file = "comm-0.2.2-py3-none-any.whl", hash = "sha256:e6fb86cb70ff661ee8c9c14e7d36d6de3b4066f1441be4063df9c5009f0a64d3"},
+ {file = "comm-0.2.2.tar.gz", hash = "sha256:3fd7a84065306e07bea1773df6eb8282de51ba82f77c72f9c85716ab11fe980e"},
]
[package.dependencies]
traitlets = ">=4"
[package.extras]
-lint = ["black (>=22.6.0)", "mdformat (>0.7)", "mdformat-gfm (>=0.3.5)", "ruff (>=0.0.156)"]
test = ["pytest"]
-typing = ["mypy (>=0.990)"]
[[package]]
name = "coverage"
-version = "7.3.2"
+version = "7.4.4"
description = "Code coverage measurement for Python"
optional = false
python-versions = ">=3.8"
files = [
- {file = "coverage-7.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d872145f3a3231a5f20fd48500274d7df222e291d90baa2026cc5152b7ce86bf"},
- {file = "coverage-7.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:310b3bb9c91ea66d59c53fa4989f57d2436e08f18fb2f421a1b0b6b8cc7fffda"},
- {file = "coverage-7.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f47d39359e2c3779c5331fc740cf4bce6d9d680a7b4b4ead97056a0ae07cb49a"},
- {file = "coverage-7.3.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aa72dbaf2c2068404b9870d93436e6d23addd8bbe9295f49cbca83f6e278179c"},
- {file = "coverage-7.3.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:beaa5c1b4777f03fc63dfd2a6bd820f73f036bfb10e925fce067b00a340d0f3f"},
- {file = "coverage-7.3.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:dbc1b46b92186cc8074fee9d9fbb97a9dd06c6cbbef391c2f59d80eabdf0faa6"},
- {file = "coverage-7.3.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:315a989e861031334d7bee1f9113c8770472db2ac484e5b8c3173428360a9148"},
- {file = "coverage-7.3.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d1bc430677773397f64a5c88cb522ea43175ff16f8bfcc89d467d974cb2274f9"},
- {file = "coverage-7.3.2-cp310-cp310-win32.whl", hash = "sha256:a889ae02f43aa45032afe364c8ae84ad3c54828c2faa44f3bfcafecb5c96b02f"},
- {file = "coverage-7.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:c0ba320de3fb8c6ec16e0be17ee1d3d69adcda99406c43c0409cb5c41788a611"},
- {file = "coverage-7.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ac8c802fa29843a72d32ec56d0ca792ad15a302b28ca6203389afe21f8fa062c"},
- {file = "coverage-7.3.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:89a937174104339e3a3ffcf9f446c00e3a806c28b1841c63edb2b369310fd074"},
- {file = "coverage-7.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e267e9e2b574a176ddb983399dec325a80dbe161f1a32715c780b5d14b5f583a"},
- {file = "coverage-7.3.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2443cbda35df0d35dcfb9bf8f3c02c57c1d6111169e3c85fc1fcc05e0c9f39a3"},
- {file = "coverage-7.3.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4175e10cc8dda0265653e8714b3174430b07c1dca8957f4966cbd6c2b1b8065a"},
- {file = "coverage-7.3.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0cbf38419fb1a347aaf63481c00f0bdc86889d9fbf3f25109cf96c26b403fda1"},
- {file = "coverage-7.3.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:5c913b556a116b8d5f6ef834038ba983834d887d82187c8f73dec21049abd65c"},
- {file = "coverage-7.3.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1981f785239e4e39e6444c63a98da3a1db8e971cb9ceb50a945ba6296b43f312"},
- {file = "coverage-7.3.2-cp311-cp311-win32.whl", hash = "sha256:43668cabd5ca8258f5954f27a3aaf78757e6acf13c17604d89648ecc0cc66640"},
- {file = "coverage-7.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:e10c39c0452bf6e694511c901426d6b5ac005acc0f78ff265dbe36bf81f808a2"},
- {file = "coverage-7.3.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:4cbae1051ab791debecc4a5dcc4a1ff45fc27b91b9aee165c8a27514dd160836"},
- {file = "coverage-7.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:12d15ab5833a997716d76f2ac1e4b4d536814fc213c85ca72756c19e5a6b3d63"},
- {file = "coverage-7.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c7bba973ebee5e56fe9251300c00f1579652587a9f4a5ed8404b15a0471f216"},
- {file = "coverage-7.3.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fe494faa90ce6381770746077243231e0b83ff3f17069d748f645617cefe19d4"},
- {file = "coverage-7.3.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6e9589bd04d0461a417562649522575d8752904d35c12907d8c9dfeba588faf"},
- {file = "coverage-7.3.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d51ac2a26f71da1b57f2dc81d0e108b6ab177e7d30e774db90675467c847bbdf"},
- {file = "coverage-7.3.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:99b89d9f76070237975b315b3d5f4d6956ae354a4c92ac2388a5695516e47c84"},
- {file = "coverage-7.3.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:fa28e909776dc69efb6ed975a63691bc8172b64ff357e663a1bb06ff3c9b589a"},
- {file = "coverage-7.3.2-cp312-cp312-win32.whl", hash = "sha256:289fe43bf45a575e3ab10b26d7b6f2ddb9ee2dba447499f5401cfb5ecb8196bb"},
- {file = "coverage-7.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:7dbc3ed60e8659bc59b6b304b43ff9c3ed858da2839c78b804973f613d3e92ed"},
- {file = "coverage-7.3.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f94b734214ea6a36fe16e96a70d941af80ff3bfd716c141300d95ebc85339738"},
- {file = "coverage-7.3.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:af3d828d2c1cbae52d34bdbb22fcd94d1ce715d95f1a012354a75e5913f1bda2"},
- {file = "coverage-7.3.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:630b13e3036e13c7adc480ca42fa7afc2a5d938081d28e20903cf7fd687872e2"},
- {file = "coverage-7.3.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c9eacf273e885b02a0273bb3a2170f30e2d53a6d53b72dbe02d6701b5296101c"},
- {file = "coverage-7.3.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8f17966e861ff97305e0801134e69db33b143bbfb36436efb9cfff6ec7b2fd9"},
- {file = "coverage-7.3.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b4275802d16882cf9c8b3d057a0839acb07ee9379fa2749eca54efbce1535b82"},
- {file = "coverage-7.3.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:72c0cfa5250f483181e677ebc97133ea1ab3eb68645e494775deb6a7f6f83901"},
- {file = "coverage-7.3.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:cb536f0dcd14149425996821a168f6e269d7dcd2c273a8bff8201e79f5104e76"},
- {file = "coverage-7.3.2-cp38-cp38-win32.whl", hash = "sha256:307adb8bd3abe389a471e649038a71b4eb13bfd6b7dd9a129fa856f5c695cf92"},
- {file = "coverage-7.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:88ed2c30a49ea81ea3b7f172e0269c182a44c236eb394718f976239892c0a27a"},
- {file = "coverage-7.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b631c92dfe601adf8f5ebc7fc13ced6bb6e9609b19d9a8cd59fa47c4186ad1ce"},
- {file = "coverage-7.3.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d3d9df4051c4a7d13036524b66ecf7a7537d14c18a384043f30a303b146164e9"},
- {file = "coverage-7.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5f7363d3b6a1119ef05015959ca24a9afc0ea8a02c687fe7e2d557705375c01f"},
- {file = "coverage-7.3.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2f11cc3c967a09d3695d2a6f03fb3e6236622b93be7a4b5dc09166a861be6d25"},
- {file = "coverage-7.3.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:149de1d2401ae4655c436a3dced6dd153f4c3309f599c3d4bd97ab172eaf02d9"},
- {file = "coverage-7.3.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:3a4006916aa6fee7cd38db3bfc95aa9c54ebb4ffbfc47c677c8bba949ceba0a6"},
- {file = "coverage-7.3.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:9028a3871280110d6e1aa2df1afd5ef003bab5fb1ef421d6dc748ae1c8ef2ebc"},
- {file = "coverage-7.3.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:9f805d62aec8eb92bab5b61c0f07329275b6f41c97d80e847b03eb894f38d083"},
- {file = "coverage-7.3.2-cp39-cp39-win32.whl", hash = "sha256:d1c88ec1a7ff4ebca0219f5b1ef863451d828cccf889c173e1253aa84b1e07ce"},
- {file = "coverage-7.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:b4767da59464bb593c07afceaddea61b154136300881844768037fd5e859353f"},
- {file = "coverage-7.3.2-pp38.pp39.pp310-none-any.whl", hash = "sha256:ae97af89f0fbf373400970c0a21eef5aa941ffeed90aee43650b81f7d7f47637"},
- {file = "coverage-7.3.2.tar.gz", hash = "sha256:be32ad29341b0170e795ca590e1c07e81fc061cb5b10c74ce7203491484404ef"},
+ {file = "coverage-7.4.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0be5efd5127542ef31f165de269f77560d6cdef525fffa446de6f7e9186cfb2"},
+ {file = "coverage-7.4.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ccd341521be3d1b3daeb41960ae94a5e87abe2f46f17224ba5d6f2b8398016cf"},
+ {file = "coverage-7.4.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09fa497a8ab37784fbb20ab699c246053ac294d13fc7eb40ec007a5043ec91f8"},
+ {file = "coverage-7.4.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b1a93009cb80730c9bca5d6d4665494b725b6e8e157c1cb7f2db5b4b122ea562"},
+ {file = "coverage-7.4.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:690db6517f09336559dc0b5f55342df62370a48f5469fabf502db2c6d1cffcd2"},
+ {file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:09c3255458533cb76ef55da8cc49ffab9e33f083739c8bd4f58e79fecfe288f7"},
+ {file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:8ce1415194b4a6bd0cdcc3a1dfbf58b63f910dcb7330fe15bdff542c56949f87"},
+ {file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b91cbc4b195444e7e258ba27ac33769c41b94967919f10037e6355e998af255c"},
+ {file = "coverage-7.4.4-cp310-cp310-win32.whl", hash = "sha256:598825b51b81c808cb6f078dcb972f96af96b078faa47af7dfcdf282835baa8d"},
+ {file = "coverage-7.4.4-cp310-cp310-win_amd64.whl", hash = "sha256:09ef9199ed6653989ebbcaacc9b62b514bb63ea2f90256e71fea3ed74bd8ff6f"},
+ {file = "coverage-7.4.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0f9f50e7ef2a71e2fae92774c99170eb8304e3fdf9c8c3c7ae9bab3e7229c5cf"},
+ {file = "coverage-7.4.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:623512f8ba53c422fcfb2ce68362c97945095b864cda94a92edbaf5994201083"},
+ {file = "coverage-7.4.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0513b9508b93da4e1716744ef6ebc507aff016ba115ffe8ecff744d1322a7b63"},
+ {file = "coverage-7.4.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40209e141059b9370a2657c9b15607815359ab3ef9918f0196b6fccce8d3230f"},
+ {file = "coverage-7.4.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a2b2b78c78293782fd3767d53e6474582f62443d0504b1554370bde86cc8227"},
+ {file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:73bfb9c09951125d06ee473bed216e2c3742f530fc5acc1383883125de76d9cd"},
+ {file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:1f384c3cc76aeedce208643697fb3e8437604b512255de6d18dae3f27655a384"},
+ {file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:54eb8d1bf7cacfbf2a3186019bcf01d11c666bd495ed18717162f7eb1e9dd00b"},
+ {file = "coverage-7.4.4-cp311-cp311-win32.whl", hash = "sha256:cac99918c7bba15302a2d81f0312c08054a3359eaa1929c7e4b26ebe41e9b286"},
+ {file = "coverage-7.4.4-cp311-cp311-win_amd64.whl", hash = "sha256:b14706df8b2de49869ae03a5ccbc211f4041750cd4a66f698df89d44f4bd30ec"},
+ {file = "coverage-7.4.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:201bef2eea65e0e9c56343115ba3814e896afe6d36ffd37bab783261db430f76"},
+ {file = "coverage-7.4.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:41c9c5f3de16b903b610d09650e5e27adbfa7f500302718c9ffd1c12cf9d6818"},
+ {file = "coverage-7.4.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d898fe162d26929b5960e4e138651f7427048e72c853607f2b200909794ed978"},
+ {file = "coverage-7.4.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3ea79bb50e805cd6ac058dfa3b5c8f6c040cb87fe83de10845857f5535d1db70"},
+ {file = "coverage-7.4.4-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce4b94265ca988c3f8e479e741693d143026632672e3ff924f25fab50518dd51"},
+ {file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:00838a35b882694afda09f85e469c96367daa3f3f2b097d846a7216993d37f4c"},
+ {file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:fdfafb32984684eb03c2d83e1e51f64f0906b11e64482df3c5db936ce3839d48"},
+ {file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:69eb372f7e2ece89f14751fbcbe470295d73ed41ecd37ca36ed2eb47512a6ab9"},
+ {file = "coverage-7.4.4-cp312-cp312-win32.whl", hash = "sha256:137eb07173141545e07403cca94ab625cc1cc6bc4c1e97b6e3846270e7e1fea0"},
+ {file = "coverage-7.4.4-cp312-cp312-win_amd64.whl", hash = "sha256:d71eec7d83298f1af3326ce0ff1d0ea83c7cb98f72b577097f9083b20bdaf05e"},
+ {file = "coverage-7.4.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d5ae728ff3b5401cc320d792866987e7e7e880e6ebd24433b70a33b643bb0384"},
+ {file = "coverage-7.4.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cc4f1358cb0c78edef3ed237ef2c86056206bb8d9140e73b6b89fbcfcbdd40e1"},
+ {file = "coverage-7.4.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8130a2aa2acb8788e0b56938786c33c7c98562697bf9f4c7d6e8e5e3a0501e4a"},
+ {file = "coverage-7.4.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cf271892d13e43bc2b51e6908ec9a6a5094a4df1d8af0bfc360088ee6c684409"},
+ {file = "coverage-7.4.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4cdc86d54b5da0df6d3d3a2f0b710949286094c3a6700c21e9015932b81447e"},
+ {file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:ae71e7ddb7a413dd60052e90528f2f65270aad4b509563af6d03d53e979feafd"},
+ {file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:38dd60d7bf242c4ed5b38e094baf6401faa114fc09e9e6632374388a404f98e7"},
+ {file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:aa5b1c1bfc28384f1f53b69a023d789f72b2e0ab1b3787aae16992a7ca21056c"},
+ {file = "coverage-7.4.4-cp38-cp38-win32.whl", hash = "sha256:dfa8fe35a0bb90382837b238fff375de15f0dcdb9ae68ff85f7a63649c98527e"},
+ {file = "coverage-7.4.4-cp38-cp38-win_amd64.whl", hash = "sha256:b2991665420a803495e0b90a79233c1433d6ed77ef282e8e152a324bbbc5e0c8"},
+ {file = "coverage-7.4.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3b799445b9f7ee8bf299cfaed6f5b226c0037b74886a4e11515e569b36fe310d"},
+ {file = "coverage-7.4.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b4d33f418f46362995f1e9d4f3a35a1b6322cb959c31d88ae56b0298e1c22357"},
+ {file = "coverage-7.4.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aadacf9a2f407a4688d700e4ebab33a7e2e408f2ca04dbf4aef17585389eff3e"},
+ {file = "coverage-7.4.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7c95949560050d04d46b919301826525597f07b33beba6187d04fa64d47ac82e"},
+ {file = "coverage-7.4.4-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff7687ca3d7028d8a5f0ebae95a6e4827c5616b31a4ee1192bdfde697db110d4"},
+ {file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5fc1de20b2d4a061b3df27ab9b7c7111e9a710f10dc2b84d33a4ab25065994ec"},
+ {file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:c74880fc64d4958159fbd537a091d2a585448a8f8508bf248d72112723974cbd"},
+ {file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:742a76a12aa45b44d236815d282b03cfb1de3b4323f3e4ec933acfae08e54ade"},
+ {file = "coverage-7.4.4-cp39-cp39-win32.whl", hash = "sha256:d89d7b2974cae412400e88f35d86af72208e1ede1a541954af5d944a8ba46c57"},
+ {file = "coverage-7.4.4-cp39-cp39-win_amd64.whl", hash = "sha256:9ca28a302acb19b6af89e90f33ee3e1906961f94b54ea37de6737b7ca9d8827c"},
+ {file = "coverage-7.4.4-pp38.pp39.pp310-none-any.whl", hash = "sha256:b2c5edc4ac10a7ef6605a966c58929ec6c1bd0917fb8c15cb3363f65aa40e677"},
+ {file = "coverage-7.4.4.tar.gz", hash = "sha256:c901df83d097649e257e803be22592aedfd5182f07b3cc87d640bbb9afd50f49"},
]
[package.dependencies]
@@ -785,71 +792,77 @@ toml = ["tomli"]
[[package]]
name = "datasets"
-version = "2.14.6"
+version = "2.18.0"
description = "HuggingFace community-driven open-source library of datasets"
optional = false
python-versions = ">=3.8.0"
files = [
- {file = "datasets-2.14.6-py3-none-any.whl", hash = "sha256:4de857ffce21cfc847236745c69f102e33cd1f0fa8398e7be9964525fd4cd5db"},
- {file = "datasets-2.14.6.tar.gz", hash = "sha256:97ebbace8ec7af11434a87d1215379927f8fee2beab2c4a674003756ecfe920c"},
+ {file = "datasets-2.18.0-py3-none-any.whl", hash = "sha256:f1bbf0e2896917a914de01cbd37075b14deea3837af87ad0d9f697388ccaeb50"},
+ {file = "datasets-2.18.0.tar.gz", hash = "sha256:cdf8b8c6abf7316377ba4f49f9589a4c74556d6b481afd0abd2284f3d69185cb"},
]
[package.dependencies]
aiohttp = "*"
-dill = ">=0.3.0,<0.3.8"
-fsspec = {version = ">=2023.1.0,<=2023.10.0", extras = ["http"]}
-huggingface-hub = ">=0.14.0,<1.0.0"
+dill = ">=0.3.0,<0.3.9"
+filelock = "*"
+fsspec = {version = ">=2023.1.0,<=2024.2.0", extras = ["http"]}
+huggingface-hub = ">=0.19.4"
multiprocess = "*"
numpy = ">=1.17"
packaging = "*"
pandas = "*"
-pyarrow = ">=8.0.0"
+pyarrow = ">=12.0.0"
+pyarrow-hotfix = "*"
pyyaml = ">=5.1"
requests = ">=2.19.0"
tqdm = ">=4.62.1"
xxhash = "*"
[package.extras]
-apache-beam = ["apache-beam (>=2.26.0,<2.44.0)"]
+apache-beam = ["apache-beam (>=2.26.0)"]
audio = ["librosa", "soundfile (>=0.12.1)"]
benchmarks = ["tensorflow (==2.12.0)", "torch (==2.0.1)", "transformers (==4.30.1)"]
-dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "black (>=23.1,<24.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "pyyaml (>=5.3.1)", "rarfile (>=4.0)", "ruff (>=0.0.241)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy (<2.0.0)", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "transformers", "zstandard"]
+dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
docs = ["s3fs", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos", "torch", "transformers"]
-jax = ["jax (>=0.2.8,!=0.3.2,<=0.3.25)", "jaxlib (>=0.1.65,<=0.3.25)"]
+jax = ["jax (>=0.3.14)", "jaxlib (>=0.3.14)"]
metrics-tests = ["Werkzeug (>=1.0.1)", "accelerate", "bert-score (>=0.3.6)", "jiwer", "langdetect", "mauve-text", "nltk", "requests-file (>=1.5.1)", "rouge-score", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "spacy (>=3.0.0)", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "typer (<0.5.0)"]
-quality = ["black (>=23.1,<24.0)", "pyyaml (>=5.3.1)", "ruff (>=0.0.241)"]
+quality = ["ruff (>=0.3.0)"]
s3 = ["s3fs"]
tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos"]
tensorflow-gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"]
-tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy (<2.0.0)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "transformers", "zstandard"]
+tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
torch = ["torch"]
vision = ["Pillow (>=6.2.1)"]
[[package]]
name = "debugpy"
-version = "1.8.0"
+version = "1.8.1"
description = "An implementation of the Debug Adapter Protocol for Python"
optional = false
python-versions = ">=3.8"
files = [
- {file = "debugpy-1.8.0-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:7fb95ca78f7ac43393cd0e0f2b6deda438ec7c5e47fa5d38553340897d2fbdfb"},
- {file = "debugpy-1.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef9ab7df0b9a42ed9c878afd3eaaff471fce3fa73df96022e1f5c9f8f8c87ada"},
- {file = "debugpy-1.8.0-cp310-cp310-win32.whl", hash = "sha256:a8b7a2fd27cd9f3553ac112f356ad4ca93338feadd8910277aff71ab24d8775f"},
- {file = "debugpy-1.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:5d9de202f5d42e62f932507ee8b21e30d49aae7e46d5b1dd5c908db1d7068637"},
- {file = "debugpy-1.8.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:ef54404365fae8d45cf450d0544ee40cefbcb9cb85ea7afe89a963c27028261e"},
- {file = "debugpy-1.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:60009b132c91951354f54363f8ebdf7457aeb150e84abba5ae251b8e9f29a8a6"},
- {file = "debugpy-1.8.0-cp311-cp311-win32.whl", hash = "sha256:8cd0197141eb9e8a4566794550cfdcdb8b3db0818bdf8c49a8e8f8053e56e38b"},
- {file = "debugpy-1.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:a64093656c4c64dc6a438e11d59369875d200bd5abb8f9b26c1f5f723622e153"},
- {file = "debugpy-1.8.0-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:b05a6b503ed520ad58c8dc682749113d2fd9f41ffd45daec16e558ca884008cd"},
- {file = "debugpy-1.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c6fb41c98ec51dd010d7ed650accfd07a87fe5e93eca9d5f584d0578f28f35f"},
- {file = "debugpy-1.8.0-cp38-cp38-win32.whl", hash = "sha256:46ab6780159eeabb43c1495d9c84cf85d62975e48b6ec21ee10c95767c0590aa"},
- {file = "debugpy-1.8.0-cp38-cp38-win_amd64.whl", hash = "sha256:bdc5ef99d14b9c0fcb35351b4fbfc06ac0ee576aeab6b2511702e5a648a2e595"},
- {file = "debugpy-1.8.0-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:61eab4a4c8b6125d41a34bad4e5fe3d2cc145caecd63c3fe953be4cc53e65bf8"},
- {file = "debugpy-1.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:125b9a637e013f9faac0a3d6a82bd17c8b5d2c875fb6b7e2772c5aba6d082332"},
- {file = "debugpy-1.8.0-cp39-cp39-win32.whl", hash = "sha256:57161629133113c97b387382045649a2b985a348f0c9366e22217c87b68b73c6"},
- {file = "debugpy-1.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:e3412f9faa9ade82aa64a50b602544efcba848c91384e9f93497a458767e6926"},
- {file = "debugpy-1.8.0-py2.py3-none-any.whl", hash = "sha256:9c9b0ac1ce2a42888199df1a1906e45e6f3c9555497643a85e0bf2406e3ffbc4"},
- {file = "debugpy-1.8.0.zip", hash = "sha256:12af2c55b419521e33d5fb21bd022df0b5eb267c3e178f1d374a63a2a6bdccd0"},
+ {file = "debugpy-1.8.1-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:3bda0f1e943d386cc7a0e71bfa59f4137909e2ed947fb3946c506e113000f741"},
+ {file = "debugpy-1.8.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dda73bf69ea479c8577a0448f8c707691152e6c4de7f0c4dec5a4bc11dee516e"},
+ {file = "debugpy-1.8.1-cp310-cp310-win32.whl", hash = "sha256:3a79c6f62adef994b2dbe9fc2cc9cc3864a23575b6e387339ab739873bea53d0"},
+ {file = "debugpy-1.8.1-cp310-cp310-win_amd64.whl", hash = "sha256:7eb7bd2b56ea3bedb009616d9e2f64aab8fc7000d481faec3cd26c98a964bcdd"},
+ {file = "debugpy-1.8.1-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:016a9fcfc2c6b57f939673c874310d8581d51a0fe0858e7fac4e240c5eb743cb"},
+ {file = "debugpy-1.8.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd97ed11a4c7f6d042d320ce03d83b20c3fb40da892f994bc041bbc415d7a099"},
+ {file = "debugpy-1.8.1-cp311-cp311-win32.whl", hash = "sha256:0de56aba8249c28a300bdb0672a9b94785074eb82eb672db66c8144fff673146"},
+ {file = "debugpy-1.8.1-cp311-cp311-win_amd64.whl", hash = "sha256:1a9fe0829c2b854757b4fd0a338d93bc17249a3bf69ecf765c61d4c522bb92a8"},
+ {file = "debugpy-1.8.1-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:3ebb70ba1a6524d19fa7bb122f44b74170c447d5746a503e36adc244a20ac539"},
+ {file = "debugpy-1.8.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2e658a9630f27534e63922ebf655a6ab60c370f4d2fc5c02a5b19baf4410ace"},
+ {file = "debugpy-1.8.1-cp312-cp312-win32.whl", hash = "sha256:caad2846e21188797a1f17fc09c31b84c7c3c23baf2516fed5b40b378515bbf0"},
+ {file = "debugpy-1.8.1-cp312-cp312-win_amd64.whl", hash = "sha256:edcc9f58ec0fd121a25bc950d4578df47428d72e1a0d66c07403b04eb93bcf98"},
+ {file = "debugpy-1.8.1-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:7a3afa222f6fd3d9dfecd52729bc2e12c93e22a7491405a0ecbf9e1d32d45b39"},
+ {file = "debugpy-1.8.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d915a18f0597ef685e88bb35e5d7ab968964b7befefe1aaea1eb5b2640b586c7"},
+ {file = "debugpy-1.8.1-cp38-cp38-win32.whl", hash = "sha256:92116039b5500633cc8d44ecc187abe2dfa9b90f7a82bbf81d079fcdd506bae9"},
+ {file = "debugpy-1.8.1-cp38-cp38-win_amd64.whl", hash = "sha256:e38beb7992b5afd9d5244e96ad5fa9135e94993b0c551ceebf3fe1a5d9beb234"},
+ {file = "debugpy-1.8.1-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:bfb20cb57486c8e4793d41996652e5a6a885b4d9175dd369045dad59eaacea42"},
+ {file = "debugpy-1.8.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:efd3fdd3f67a7e576dd869c184c5dd71d9aaa36ded271939da352880c012e703"},
+ {file = "debugpy-1.8.1-cp39-cp39-win32.whl", hash = "sha256:58911e8521ca0c785ac7a0539f1e77e0ce2df753f786188f382229278b4cdf23"},
+ {file = "debugpy-1.8.1-cp39-cp39-win_amd64.whl", hash = "sha256:6df9aa9599eb05ca179fb0b810282255202a66835c6efb1d112d21ecb830ddd3"},
+ {file = "debugpy-1.8.1-py2.py3-none-any.whl", hash = "sha256:28acbe2241222b87e255260c76741e1fbf04fdc3b6d094fcf57b6c6f75ce1242"},
+ {file = "debugpy-1.8.1.zip", hash = "sha256:f696d6be15be87aef621917585f9bb94b1dc9e8aced570db1b8a6fc14e8f9b42"},
]
[[package]]
@@ -876,17 +889,18 @@ files = [
[[package]]
name = "dill"
-version = "0.3.7"
+version = "0.3.8"
description = "serialize all of Python"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
files = [
- {file = "dill-0.3.7-py3-none-any.whl", hash = "sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e"},
- {file = "dill-0.3.7.tar.gz", hash = "sha256:cc1c8b182eb3013e24bd475ff2e9295af86c1a38eb1aff128dac8962a9ce3c03"},
+ {file = "dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7"},
+ {file = "dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca"},
]
[package.extras]
graph = ["objgraph (>=1.7.2)"]
+profile = ["gprof2dot (>=2022.7.29)"]
[[package]]
name = "docker-pycreds"
@@ -926,13 +940,13 @@ files = [
[[package]]
name = "exceptiongroup"
-version = "1.1.3"
+version = "1.2.0"
description = "Backport of PEP 654 (exception groups)"
optional = false
python-versions = ">=3.7"
files = [
- {file = "exceptiongroup-1.1.3-py3-none-any.whl", hash = "sha256:343280667a4585d195ca1cf9cef84a4e178c4b6cf2274caef9859782b567d5e3"},
- {file = "exceptiongroup-1.1.3.tar.gz", hash = "sha256:097acd85d473d75af5bb98e41b61ff7fe35efe6675e4f9370ec6ec5126d160e9"},
+ {file = "exceptiongroup-1.2.0-py3-none-any.whl", hash = "sha256:4bfd3996ac73b41e9b9628b04e079f193850720ea5945fc96a08633c66912f14"},
+ {file = "exceptiongroup-1.2.0.tar.gz", hash = "sha256:91f5c769735f051a4290d52edd0858999b57e5876e9f85937691bd4c9fa3ed68"},
]
[package.extras]
@@ -940,13 +954,13 @@ test = ["pytest (>=6)"]
[[package]]
name = "executing"
-version = "2.0.0"
+version = "2.0.1"
description = "Get the currently executing AST node of a frame, and other information"
optional = false
-python-versions = "*"
+python-versions = ">=3.5"
files = [
- {file = "executing-2.0.0-py2.py3-none-any.whl", hash = "sha256:06df6183df67389625f4e763921c6cf978944721abf3e714000200aab95b0657"},
- {file = "executing-2.0.0.tar.gz", hash = "sha256:0ff053696fdeef426cda5bd18eacd94f82c91f49823a2e9090124212ceea9b08"},
+ {file = "executing-2.0.1-py2.py3-none-any.whl", hash = "sha256:eac49ca94516ccc753f9fb5ce82603156e590b27525a8bc32cce8ae302eb61bc"},
+ {file = "executing-2.0.1.tar.gz", hash = "sha256:35afe2ce3affba8ee97f2d69927fa823b08b472b7b994e36a52a964b93d16147"},
]
[package.extras]
@@ -965,13 +979,13 @@ files = [
[[package]]
name = "fastjsonschema"
-version = "2.18.1"
+version = "2.19.1"
description = "Fastest Python implementation of JSON schema"
optional = false
python-versions = "*"
files = [
- {file = "fastjsonschema-2.18.1-py3-none-any.whl", hash = "sha256:aec6a19e9f66e9810ab371cc913ad5f4e9e479b63a7072a2cd060a9369e329a8"},
- {file = "fastjsonschema-2.18.1.tar.gz", hash = "sha256:06dc8680d937628e993fa0cd278f196d20449a1adc087640710846b324d422ea"},
+ {file = "fastjsonschema-2.19.1-py3-none-any.whl", hash = "sha256:3672b47bc94178c9f23dbb654bf47440155d4db9df5f7bc47643315f9c405cd0"},
+ {file = "fastjsonschema-2.19.1.tar.gz", hash = "sha256:e3126a94bdc4623d3de4485f8d468a12f02a67921315ddc87836d6e456dc789d"},
]
[package.extras]
@@ -979,19 +993,19 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc
[[package]]
name = "filelock"
-version = "3.12.4"
+version = "3.13.3"
description = "A platform independent file lock."
optional = false
python-versions = ">=3.8"
files = [
- {file = "filelock-3.12.4-py3-none-any.whl", hash = "sha256:08c21d87ded6e2b9da6728c3dff51baf1dcecf973b768ef35bcbc3447edb9ad4"},
- {file = "filelock-3.12.4.tar.gz", hash = "sha256:2e6f249f1f3654291606e046b09f1fd5eac39b360664c27f5aad072012f8bcbd"},
+ {file = "filelock-3.13.3-py3-none-any.whl", hash = "sha256:5ffa845303983e7a0b7ae17636509bc97997d58afeafa72fb141a17b152284cb"},
+ {file = "filelock-3.13.3.tar.gz", hash = "sha256:a79895a25bbefdf55d1a2a0a80968f7dbb28edcd6d4234a0afb3f37ecde4b546"},
]
[package.extras]
-docs = ["furo (>=2023.7.26)", "sphinx (>=7.1.2)", "sphinx-autodoc-typehints (>=1.24)"]
-testing = ["covdefaults (>=2.3)", "coverage (>=7.3)", "diff-cover (>=7.7)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)", "pytest-timeout (>=2.1)"]
-typing = ["typing-extensions (>=4.7.1)"]
+docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"]
+testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"]
+typing = ["typing-extensions (>=4.8)"]
[[package]]
name = "fqdn"
@@ -1006,88 +1020,103 @@ files = [
[[package]]
name = "frozenlist"
-version = "1.4.0"
+version = "1.4.1"
description = "A list-like structure which implements collections.abc.MutableSequence"
optional = false
python-versions = ">=3.8"
files = [
- {file = "frozenlist-1.4.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:764226ceef3125e53ea2cb275000e309c0aa5464d43bd72abd661e27fffc26ab"},
- {file = "frozenlist-1.4.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d6484756b12f40003c6128bfcc3fa9f0d49a687e171186c2d85ec82e3758c559"},
- {file = "frozenlist-1.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9ac08e601308e41eb533f232dbf6b7e4cea762f9f84f6357136eed926c15d12c"},
- {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d081f13b095d74b67d550de04df1c756831f3b83dc9881c38985834387487f1b"},
- {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:71932b597f9895f011f47f17d6428252fc728ba2ae6024e13c3398a087c2cdea"},
- {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:981b9ab5a0a3178ff413bca62526bb784249421c24ad7381e39d67981be2c326"},
- {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e41f3de4df3e80de75845d3e743b3f1c4c8613c3997a912dbf0229fc61a8b963"},
- {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6918d49b1f90821e93069682c06ffde41829c346c66b721e65a5c62b4bab0300"},
- {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0e5c8764c7829343d919cc2dfc587a8db01c4f70a4ebbc49abde5d4b158b007b"},
- {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:8d0edd6b1c7fb94922bf569c9b092ee187a83f03fb1a63076e7774b60f9481a8"},
- {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e29cda763f752553fa14c68fb2195150bfab22b352572cb36c43c47bedba70eb"},
- {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:0c7c1b47859ee2cac3846fde1c1dc0f15da6cec5a0e5c72d101e0f83dcb67ff9"},
- {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:901289d524fdd571be1c7be054f48b1f88ce8dddcbdf1ec698b27d4b8b9e5d62"},
- {file = "frozenlist-1.4.0-cp310-cp310-win32.whl", hash = "sha256:1a0848b52815006ea6596c395f87449f693dc419061cc21e970f139d466dc0a0"},
- {file = "frozenlist-1.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:b206646d176a007466358aa21d85cd8600a415c67c9bd15403336c331a10d956"},
- {file = "frozenlist-1.4.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:de343e75f40e972bae1ef6090267f8260c1446a1695e77096db6cfa25e759a95"},
- {file = "frozenlist-1.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ad2a9eb6d9839ae241701d0918f54c51365a51407fd80f6b8289e2dfca977cc3"},
- {file = "frozenlist-1.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bd7bd3b3830247580de99c99ea2a01416dfc3c34471ca1298bccabf86d0ff4dc"},
- {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bdf1847068c362f16b353163391210269e4f0569a3c166bc6a9f74ccbfc7e839"},
- {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:38461d02d66de17455072c9ba981d35f1d2a73024bee7790ac2f9e361ef1cd0c"},
- {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5a32087d720c608f42caed0ef36d2b3ea61a9d09ee59a5142d6070da9041b8f"},
- {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dd65632acaf0d47608190a71bfe46b209719bf2beb59507db08ccdbe712f969b"},
- {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:261b9f5d17cac914531331ff1b1d452125bf5daa05faf73b71d935485b0c510b"},
- {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:b89ac9768b82205936771f8d2eb3ce88503b1556324c9f903e7156669f521472"},
- {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:008eb8b31b3ea6896da16c38c1b136cb9fec9e249e77f6211d479db79a4eaf01"},
- {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e74b0506fa5aa5598ac6a975a12aa8928cbb58e1f5ac8360792ef15de1aa848f"},
- {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:490132667476f6781b4c9458298b0c1cddf237488abd228b0b3650e5ecba7467"},
- {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:76d4711f6f6d08551a7e9ef28c722f4a50dd0fc204c56b4bcd95c6cc05ce6fbb"},
- {file = "frozenlist-1.4.0-cp311-cp311-win32.whl", hash = "sha256:a02eb8ab2b8f200179b5f62b59757685ae9987996ae549ccf30f983f40602431"},
- {file = "frozenlist-1.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:515e1abc578dd3b275d6a5114030b1330ba044ffba03f94091842852f806f1c1"},
- {file = "frozenlist-1.4.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:f0ed05f5079c708fe74bf9027e95125334b6978bf07fd5ab923e9e55e5fbb9d3"},
- {file = "frozenlist-1.4.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ca265542ca427bf97aed183c1676e2a9c66942e822b14dc6e5f42e038f92a503"},
- {file = "frozenlist-1.4.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:491e014f5c43656da08958808588cc6c016847b4360e327a62cb308c791bd2d9"},
- {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:17ae5cd0f333f94f2e03aaf140bb762c64783935cc764ff9c82dff626089bebf"},
- {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1e78fb68cf9c1a6aa4a9a12e960a5c9dfbdb89b3695197aa7064705662515de2"},
- {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5655a942f5f5d2c9ed93d72148226d75369b4f6952680211972a33e59b1dfdc"},
- {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c11b0746f5d946fecf750428a95f3e9ebe792c1ee3b1e96eeba145dc631a9672"},
- {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e66d2a64d44d50d2543405fb183a21f76b3b5fd16f130f5c99187c3fb4e64919"},
- {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:88f7bc0fcca81f985f78dd0fa68d2c75abf8272b1f5c323ea4a01a4d7a614efc"},
- {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:5833593c25ac59ede40ed4de6d67eb42928cca97f26feea219f21d0ed0959b79"},
- {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:fec520865f42e5c7f050c2a79038897b1c7d1595e907a9e08e3353293ffc948e"},
- {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:b826d97e4276750beca7c8f0f1a4938892697a6bcd8ec8217b3312dad6982781"},
- {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:ceb6ec0a10c65540421e20ebd29083c50e6d1143278746a4ef6bcf6153171eb8"},
- {file = "frozenlist-1.4.0-cp38-cp38-win32.whl", hash = "sha256:2b8bcf994563466db019fab287ff390fffbfdb4f905fc77bc1c1d604b1c689cc"},
- {file = "frozenlist-1.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:a6c8097e01886188e5be3e6b14e94ab365f384736aa1fca6a0b9e35bd4a30bc7"},
- {file = "frozenlist-1.4.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:6c38721585f285203e4b4132a352eb3daa19121a035f3182e08e437cface44bf"},
- {file = "frozenlist-1.4.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a0c6da9aee33ff0b1a451e867da0c1f47408112b3391dd43133838339e410963"},
- {file = "frozenlist-1.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:93ea75c050c5bb3d98016b4ba2497851eadf0ac154d88a67d7a6816206f6fa7f"},
- {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f61e2dc5ad442c52b4887f1fdc112f97caeff4d9e6ebe78879364ac59f1663e1"},
- {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa384489fefeb62321b238e64c07ef48398fe80f9e1e6afeff22e140e0850eef"},
- {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:10ff5faaa22786315ef57097a279b833ecab1a0bfb07d604c9cbb1c4cdc2ed87"},
- {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:007df07a6e3eb3e33e9a1fe6a9db7af152bbd8a185f9aaa6ece10a3529e3e1c6"},
- {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f4f399d28478d1f604c2ff9119907af9726aed73680e5ed1ca634d377abb087"},
- {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c5374b80521d3d3f2ec5572e05adc94601985cc526fb276d0c8574a6d749f1b3"},
- {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:ce31ae3e19f3c902de379cf1323d90c649425b86de7bbdf82871b8a2a0615f3d"},
- {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:7211ef110a9194b6042449431e08c4d80c0481e5891e58d429df5899690511c2"},
- {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:556de4430ce324c836789fa4560ca62d1591d2538b8ceb0b4f68fb7b2384a27a"},
- {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:7645a8e814a3ee34a89c4a372011dcd817964ce8cb273c8ed6119d706e9613e3"},
- {file = "frozenlist-1.4.0-cp39-cp39-win32.whl", hash = "sha256:19488c57c12d4e8095a922f328df3f179c820c212940a498623ed39160bc3c2f"},
- {file = "frozenlist-1.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:6221d84d463fb110bdd7619b69cb43878a11d51cbb9394ae3105d082d5199167"},
- {file = "frozenlist-1.4.0.tar.gz", hash = "sha256:09163bdf0b2907454042edb19f887c6d33806adc71fbd54afc14908bfdc22251"},
+ {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f9aa1878d1083b276b0196f2dfbe00c9b7e752475ed3b682025ff20c1c1f51ac"},
+ {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:29acab3f66f0f24674b7dc4736477bcd4bc3ad4b896f5f45379a67bce8b96868"},
+ {file = "frozenlist-1.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:74fb4bee6880b529a0c6560885fce4dc95936920f9f20f53d99a213f7bf66776"},
+ {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:590344787a90ae57d62511dd7c736ed56b428f04cd8c161fcc5e7232c130c69a"},
+ {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:068b63f23b17df8569b7fdca5517edef76171cf3897eb68beb01341131fbd2ad"},
+ {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c849d495bf5154cd8da18a9eb15db127d4dba2968d88831aff6f0331ea9bd4c"},
+ {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9750cc7fe1ae3b1611bb8cfc3f9ec11d532244235d75901fb6b8e42ce9229dfe"},
+ {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9b2de4cf0cdd5bd2dee4c4f63a653c61d2408055ab77b151c1957f221cabf2a"},
+ {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0633c8d5337cb5c77acbccc6357ac49a1770b8c487e5b3505c57b949b4b82e98"},
+ {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:27657df69e8801be6c3638054e202a135c7f299267f1a55ed3a598934f6c0d75"},
+ {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:f9a3ea26252bd92f570600098783d1371354d89d5f6b7dfd87359d669f2109b5"},
+ {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:4f57dab5fe3407b6c0c1cc907ac98e8a189f9e418f3b6e54d65a718aaafe3950"},
+ {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e02a0e11cf6597299b9f3bbd3f93d79217cb90cfd1411aec33848b13f5c656cc"},
+ {file = "frozenlist-1.4.1-cp310-cp310-win32.whl", hash = "sha256:a828c57f00f729620a442881cc60e57cfcec6842ba38e1b19fd3e47ac0ff8dc1"},
+ {file = "frozenlist-1.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:f56e2333dda1fe0f909e7cc59f021eba0d2307bc6f012a1ccf2beca6ba362439"},
+ {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a0cb6f11204443f27a1628b0e460f37fb30f624be6051d490fa7d7e26d4af3d0"},
+ {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b46c8ae3a8f1f41a0d2ef350c0b6e65822d80772fe46b653ab6b6274f61d4a49"},
+ {file = "frozenlist-1.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fde5bd59ab5357e3853313127f4d3565fc7dad314a74d7b5d43c22c6a5ed2ced"},
+ {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:722e1124aec435320ae01ee3ac7bec11a5d47f25d0ed6328f2273d287bc3abb0"},
+ {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2471c201b70d58a0f0c1f91261542a03d9a5e088ed3dc6c160d614c01649c106"},
+ {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c757a9dd70d72b076d6f68efdbb9bc943665ae954dad2801b874c8c69e185068"},
+ {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f146e0911cb2f1da549fc58fc7bcd2b836a44b79ef871980d605ec392ff6b0d2"},
+ {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f9c515e7914626b2a2e1e311794b4c35720a0be87af52b79ff8e1429fc25f19"},
+ {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c302220494f5c1ebeb0912ea782bcd5e2f8308037b3c7553fad0e48ebad6ad82"},
+ {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:442acde1e068288a4ba7acfe05f5f343e19fac87bfc96d89eb886b0363e977ec"},
+ {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:1b280e6507ea8a4fa0c0a7150b4e526a8d113989e28eaaef946cc77ffd7efc0a"},
+ {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:fe1a06da377e3a1062ae5fe0926e12b84eceb8a50b350ddca72dc85015873f74"},
+ {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:db9e724bebd621d9beca794f2a4ff1d26eed5965b004a97f1f1685a173b869c2"},
+ {file = "frozenlist-1.4.1-cp311-cp311-win32.whl", hash = "sha256:e774d53b1a477a67838a904131c4b0eef6b3d8a651f8b138b04f748fccfefe17"},
+ {file = "frozenlist-1.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:fb3c2db03683b5767dedb5769b8a40ebb47d6f7f45b1b3e3b4b51ec8ad9d9825"},
+ {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:1979bc0aeb89b33b588c51c54ab0161791149f2461ea7c7c946d95d5f93b56ae"},
+ {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cc7b01b3754ea68a62bd77ce6020afaffb44a590c2289089289363472d13aedb"},
+ {file = "frozenlist-1.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c9c92be9fd329ac801cc420e08452b70e7aeab94ea4233a4804f0915c14eba9b"},
+ {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3894db91f5a489fc8fa6a9991820f368f0b3cbdb9cd8849547ccfab3392d86"},
+ {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ba60bb19387e13597fb059f32cd4d59445d7b18b69a745b8f8e5db0346f33480"},
+ {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8aefbba5f69d42246543407ed2461db31006b0f76c4e32dfd6f42215a2c41d09"},
+ {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:780d3a35680ced9ce682fbcf4cb9c2bad3136eeff760ab33707b71db84664e3a"},
+ {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9acbb16f06fe7f52f441bb6f413ebae6c37baa6ef9edd49cdd567216da8600cd"},
+ {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:23b701e65c7b36e4bf15546a89279bd4d8675faabc287d06bbcfac7d3c33e1e6"},
+ {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3e0153a805a98f5ada7e09826255ba99fb4f7524bb81bf6b47fb702666484ae1"},
+ {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:dd9b1baec094d91bf36ec729445f7769d0d0cf6b64d04d86e45baf89e2b9059b"},
+ {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:1a4471094e146b6790f61b98616ab8e44f72661879cc63fa1049d13ef711e71e"},
+ {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5667ed53d68d91920defdf4035d1cdaa3c3121dc0b113255124bcfada1cfa1b8"},
+ {file = "frozenlist-1.4.1-cp312-cp312-win32.whl", hash = "sha256:beee944ae828747fd7cb216a70f120767fc9f4f00bacae8543c14a6831673f89"},
+ {file = "frozenlist-1.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:64536573d0a2cb6e625cf309984e2d873979709f2cf22839bf2d61790b448ad5"},
+ {file = "frozenlist-1.4.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:20b51fa3f588ff2fe658663db52a41a4f7aa6c04f6201449c6c7c476bd255c0d"},
+ {file = "frozenlist-1.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:410478a0c562d1a5bcc2f7ea448359fcb050ed48b3c6f6f4f18c313a9bdb1826"},
+ {file = "frozenlist-1.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c6321c9efe29975232da3bd0af0ad216800a47e93d763ce64f291917a381b8eb"},
+ {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48f6a4533887e189dae092f1cf981f2e3885175f7a0f33c91fb5b7b682b6bab6"},
+ {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6eb73fa5426ea69ee0e012fb59cdc76a15b1283d6e32e4f8dc4482ec67d1194d"},
+ {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fbeb989b5cc29e8daf7f976b421c220f1b8c731cbf22b9130d8815418ea45887"},
+ {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:32453c1de775c889eb4e22f1197fe3bdfe457d16476ea407472b9442e6295f7a"},
+ {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:693945278a31f2086d9bf3df0fe8254bbeaef1fe71e1351c3bd730aa7d31c41b"},
+ {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:1d0ce09d36d53bbbe566fe296965b23b961764c0bcf3ce2fa45f463745c04701"},
+ {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:3a670dc61eb0d0eb7080890c13de3066790f9049b47b0de04007090807c776b0"},
+ {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:dca69045298ce5c11fd539682cff879cc1e664c245d1c64da929813e54241d11"},
+ {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:a06339f38e9ed3a64e4c4e43aec7f59084033647f908e4259d279a52d3757d09"},
+ {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b7f2f9f912dca3934c1baec2e4585a674ef16fe00218d833856408c48d5beee7"},
+ {file = "frozenlist-1.4.1-cp38-cp38-win32.whl", hash = "sha256:e7004be74cbb7d9f34553a5ce5fb08be14fb33bc86f332fb71cbe5216362a497"},
+ {file = "frozenlist-1.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:5a7d70357e7cee13f470c7883a063aae5fe209a493c57d86eb7f5a6f910fae09"},
+ {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:bfa4a17e17ce9abf47a74ae02f32d014c5e9404b6d9ac7f729e01562bbee601e"},
+ {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b7e3ed87d4138356775346e6845cccbe66cd9e207f3cd11d2f0b9fd13681359d"},
+ {file = "frozenlist-1.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c99169d4ff810155ca50b4da3b075cbde79752443117d89429595c2e8e37fed8"},
+ {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edb678da49d9f72c9f6c609fbe41a5dfb9a9282f9e6a2253d5a91e0fc382d7c0"},
+ {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6db4667b187a6742b33afbbaf05a7bc551ffcf1ced0000a571aedbb4aa42fc7b"},
+ {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55fdc093b5a3cb41d420884cdaf37a1e74c3c37a31f46e66286d9145d2063bd0"},
+ {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82e8211d69a4f4bc360ea22cd6555f8e61a1bd211d1d5d39d3d228b48c83a897"},
+ {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89aa2c2eeb20957be2d950b85974b30a01a762f3308cd02bb15e1ad632e22dc7"},
+ {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9d3e0c25a2350080e9319724dede4f31f43a6c9779be48021a7f4ebde8b2d742"},
+ {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7268252af60904bf52c26173cbadc3a071cece75f873705419c8681f24d3edea"},
+ {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:0c250a29735d4f15321007fb02865f0e6b6a41a6b88f1f523ca1596ab5f50bd5"},
+ {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:96ec70beabbd3b10e8bfe52616a13561e58fe84c0101dd031dc78f250d5128b9"},
+ {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:23b2d7679b73fe0e5a4560b672a39f98dfc6f60df63823b0a9970525325b95f6"},
+ {file = "frozenlist-1.4.1-cp39-cp39-win32.whl", hash = "sha256:a7496bfe1da7fb1a4e1cc23bb67c58fab69311cc7d32b5a99c2007b4b2a0e932"},
+ {file = "frozenlist-1.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:e6a20a581f9ce92d389a8c7d7c3dd47c81fd5d6e655c8dddf341e14aa48659d0"},
+ {file = "frozenlist-1.4.1-py3-none-any.whl", hash = "sha256:04ced3e6a46b4cfffe20f9ae482818e34eba9b5fb0ce4056e4cc9b6e212d09b7"},
+ {file = "frozenlist-1.4.1.tar.gz", hash = "sha256:c037a86e8513059a2613aaba4d817bb90b9d9b6b69aace3ce9c877e8c8ed402b"},
]
[[package]]
name = "fsspec"
-version = "2023.10.0"
+version = "2024.2.0"
description = "File-system specification"
optional = false
python-versions = ">=3.8"
files = [
- {file = "fsspec-2023.10.0-py3-none-any.whl", hash = "sha256:346a8f024efeb749d2a5fca7ba8854474b1ff9af7c3faaf636a4548781136529"},
- {file = "fsspec-2023.10.0.tar.gz", hash = "sha256:330c66757591df346ad3091a53bd907e15348c2ba17d63fd54f5c39c4457d2a5"},
+ {file = "fsspec-2024.2.0-py3-none-any.whl", hash = "sha256:817f969556fa5916bc682e02ca2045f96ff7f586d45110fcb76022063ad2c7d8"},
+ {file = "fsspec-2024.2.0.tar.gz", hash = "sha256:b6ad1a679f760dda52b1168c859d01b7b80648ea6f7f7c7f5a8a91dc3f3ecb84"},
]
[package.dependencies]
aiohttp = {version = "<4.0.0a0 || >4.0.0a0,<4.0.0a1 || >4.0.0a1", optional = true, markers = "extra == \"http\""}
-requests = {version = "*", optional = true, markers = "extra == \"http\""}
[package.extras]
abfs = ["adlfs"]
@@ -1104,7 +1133,7 @@ github = ["requests"]
gs = ["gcsfs"]
gui = ["panel"]
hdfs = ["pyarrow (>=1)"]
-http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "requests"]
+http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)"]
libarchive = ["libarchive-c"]
oci = ["ocifs"]
s3 = ["s3fs"]
@@ -1146,35 +1175,92 @@ smmap = ">=3.0.1,<6"
[[package]]
name = "gitpython"
-version = "3.1.40"
+version = "3.1.43"
description = "GitPython is a Python library used to interact with Git repositories"
optional = false
python-versions = ">=3.7"
files = [
- {file = "GitPython-3.1.40-py3-none-any.whl", hash = "sha256:cf14627d5a8049ffbf49915732e5eddbe8134c3bdb9d476e6182b676fc573f8a"},
- {file = "GitPython-3.1.40.tar.gz", hash = "sha256:22b126e9ffb671fdd0c129796343a02bf67bf2994b35449ffc9321aa755e18a4"},
+ {file = "GitPython-3.1.43-py3-none-any.whl", hash = "sha256:eec7ec56b92aad751f9912a73404bc02ba212a23adb2c7098ee668417051a1ff"},
+ {file = "GitPython-3.1.43.tar.gz", hash = "sha256:35f314a9f878467f5453cc1fee295c3e18e52f1b99f10f6cf5b1682e968a9e7c"},
]
[package.dependencies]
gitdb = ">=4.0.1,<5"
[package.extras]
-test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest", "pytest-cov", "pytest-instafail", "pytest-subtests", "pytest-sugar"]
+doc = ["sphinx (==4.3.2)", "sphinx-autodoc-typehints", "sphinx-rtd-theme", "sphinxcontrib-applehelp (>=1.0.2,<=1.0.4)", "sphinxcontrib-devhelp (==1.0.2)", "sphinxcontrib-htmlhelp (>=2.0.0,<=2.0.1)", "sphinxcontrib-qthelp (==1.0.3)", "sphinxcontrib-serializinghtml (==1.1.5)"]
+test = ["coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "typing-extensions"]
+
+[[package]]
+name = "h11"
+version = "0.14.0"
+description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"},
+ {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"},
+]
+
+[[package]]
+name = "httpcore"
+version = "1.0.5"
+description = "A minimal low-level HTTP client."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "httpcore-1.0.5-py3-none-any.whl", hash = "sha256:421f18bac248b25d310f3cacd198d55b8e6125c107797b609ff9b7a6ba7991b5"},
+ {file = "httpcore-1.0.5.tar.gz", hash = "sha256:34a38e2f9291467ee3b44e89dd52615370e152954ba21721378a87b2960f7a61"},
+]
+
+[package.dependencies]
+certifi = "*"
+h11 = ">=0.13,<0.15"
+
+[package.extras]
+asyncio = ["anyio (>=4.0,<5.0)"]
+http2 = ["h2 (>=3,<5)"]
+socks = ["socksio (==1.*)"]
+trio = ["trio (>=0.22.0,<0.26.0)"]
+
+[[package]]
+name = "httpx"
+version = "0.27.0"
+description = "The next generation HTTP client."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "httpx-0.27.0-py3-none-any.whl", hash = "sha256:71d5465162c13681bff01ad59b2cc68dd838ea1f10e51574bac27103f00c91a5"},
+ {file = "httpx-0.27.0.tar.gz", hash = "sha256:a0cb88a46f32dc874e04ee956e4c2764aba2aa228f650b06788ba6bda2962ab5"},
+]
+
+[package.dependencies]
+anyio = "*"
+certifi = "*"
+httpcore = "==1.*"
+idna = "*"
+sniffio = "*"
+
+[package.extras]
+brotli = ["brotli", "brotlicffi"]
+cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
+http2 = ["h2 (>=3,<5)"]
+socks = ["socksio (==1.*)"]
[[package]]
name = "huggingface-hub"
-version = "0.17.3"
+version = "0.22.2"
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
optional = false
python-versions = ">=3.8.0"
files = [
- {file = "huggingface_hub-0.17.3-py3-none-any.whl", hash = "sha256:545eb3665f6ac587add946e73984148f2ea5c7877eac2e845549730570c1933a"},
- {file = "huggingface_hub-0.17.3.tar.gz", hash = "sha256:40439632b211311f788964602bf8b0d9d6b7a2314fba4e8d67b2ce3ecea0e3fd"},
+ {file = "huggingface_hub-0.22.2-py3-none-any.whl", hash = "sha256:3429e25f38ccb834d310804a3b711e7e4953db5a9e420cc147a5e194ca90fd17"},
+ {file = "huggingface_hub-0.22.2.tar.gz", hash = "sha256:32e9a9a6843c92f253ff9ca16b9985def4d80a93fb357af5353f770ef74a81be"},
]
[package.dependencies]
filelock = "*"
-fsspec = "*"
+fsspec = ">=2023.5.0"
packaging = ">=20.9"
pyyaml = ">=5.1"
requests = "*"
@@ -1182,27 +1268,28 @@ tqdm = ">=4.42.1"
typing-extensions = ">=3.7.4.3"
[package.extras]
-all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "black (==23.7)", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (<2.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "urllib3 (<2.0)"]
+all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
cli = ["InquirerPy (==0.3.4)"]
-dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "black (==23.7)", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (<2.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "urllib3 (<2.0)"]
-docs = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "black (==23.7)", "gradio", "hf-doc-builder", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (<2.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "urllib3 (<2.0)", "watchdog"]
+dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"]
-inference = ["aiohttp", "pydantic (<2.0)"]
-quality = ["black (==23.7)", "mypy (==1.5.1)", "ruff (>=0.0.241)"]
+hf-transfer = ["hf-transfer (>=0.1.4)"]
+inference = ["aiohttp", "minijinja (>=1.0)"]
+quality = ["mypy (==1.5.1)", "ruff (>=0.3.0)"]
tensorflow = ["graphviz", "pydot", "tensorflow"]
-testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "numpy", "pydantic (<2.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"]
-torch = ["torch"]
-typing = ["pydantic (<2.0)", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"]
+tensorflow-testing = ["keras (<3.0)", "tensorflow"]
+testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "minijinja (>=1.0)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"]
+torch = ["safetensors", "torch"]
+typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"]
[[package]]
name = "idna"
-version = "3.4"
+version = "3.6"
description = "Internationalized Domain Names in Applications (IDNA)"
optional = false
python-versions = ">=3.5"
files = [
- {file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"},
- {file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"},
+ {file = "idna-3.6-py3-none-any.whl", hash = "sha256:c05567e9c24a6b9faaa835c4821bad0590fbb9d5779e7caa6e1cc4978e7eb24f"},
+ {file = "idna-3.6.tar.gz", hash = "sha256:9ecdbbd083b06798ae1e86adcbfe8ab1479cf864e4ee30fe4e46a003d12491ca"},
]
[[package]]
@@ -1218,32 +1305,32 @@ files = [
[[package]]
name = "importlib-metadata"
-version = "6.8.0"
+version = "7.1.0"
description = "Read metadata from Python packages"
optional = false
python-versions = ">=3.8"
files = [
- {file = "importlib_metadata-6.8.0-py3-none-any.whl", hash = "sha256:3ebb78df84a805d7698245025b975d9d67053cd94c79245ba4b3eb694abe68bb"},
- {file = "importlib_metadata-6.8.0.tar.gz", hash = "sha256:dbace7892d8c0c4ac1ad096662232f831d4e64f4c4545bd53016a3e9d4654743"},
+ {file = "importlib_metadata-7.1.0-py3-none-any.whl", hash = "sha256:30962b96c0c223483ed6cc7280e7f0199feb01a0e40cfae4d4450fc6fab1f570"},
+ {file = "importlib_metadata-7.1.0.tar.gz", hash = "sha256:b78938b926ee8d5f020fc4772d487045805a55ddbad2ecf21c6d60938dc7fcd2"},
]
[package.dependencies]
zipp = ">=0.5"
[package.extras]
-docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
+docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
perf = ["ipython"]
-testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"]
+testing = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"]
[[package]]
name = "importlib-resources"
-version = "6.1.0"
+version = "6.4.0"
description = "Read resources from Python packages"
optional = false
python-versions = ">=3.8"
files = [
- {file = "importlib_resources-6.1.0-py3-none-any.whl", hash = "sha256:aa50258bbfa56d4e33fbd8aa3ef48ded10d1735f11532b8df95388cc6bdb7e83"},
- {file = "importlib_resources-6.1.0.tar.gz", hash = "sha256:9d48dcccc213325e810fd723e7fbb45ccb39f6cf5c31f00cf2b965f5f10f3cb9"},
+ {file = "importlib_resources-6.4.0-py3-none-any.whl", hash = "sha256:50d10f043df931902d4194ea07ec57960f66a80449ff867bfe782b4c486ba78c"},
+ {file = "importlib_resources-6.4.0.tar.gz", hash = "sha256:cdb2b453b8046ca4e3798eb1d84f3cce1446a0e8e7b5ef4efb600f19fc398145"},
]
[package.dependencies]
@@ -1251,7 +1338,7 @@ zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""}
[package.extras]
docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"]
-testing = ["pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-ruff", "zipp (>=3.17)"]
+testing = ["jaraco.test (>=5.4)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)", "zipp (>=3.17)"]
[[package]]
name = "iniconfig"
@@ -1266,13 +1353,13 @@ files = [
[[package]]
name = "ipykernel"
-version = "6.26.0"
+version = "6.29.4"
description = "IPython Kernel for Jupyter"
optional = false
python-versions = ">=3.8"
files = [
- {file = "ipykernel-6.26.0-py3-none-any.whl", hash = "sha256:3ba3dc97424b87b31bb46586b5167b3161b32d7820b9201a9e698c71e271602c"},
- {file = "ipykernel-6.26.0.tar.gz", hash = "sha256:553856658eb8430bbe9653ea041a41bff63e9606fc4628873fc92a6cf3abd404"},
+ {file = "ipykernel-6.29.4-py3-none-any.whl", hash = "sha256:1181e653d95c6808039c509ef8e67c4126b3b3af7781496c7cbfb5ed938a27da"},
+ {file = "ipykernel-6.29.4.tar.gz", hash = "sha256:3d44070060f9475ac2092b760123fadf105d2e2493c24848b6691a7c4f42af5c"},
]
[package.dependencies]
@@ -1286,7 +1373,7 @@ matplotlib-inline = ">=0.1"
nest-asyncio = "*"
packaging = "*"
psutil = "*"
-pyzmq = ">=20"
+pyzmq = ">=24"
tornado = ">=6.1"
traitlets = ">=5.4.0"
@@ -1295,7 +1382,7 @@ cov = ["coverage[toml]", "curio", "matplotlib", "pytest-cov", "trio"]
docs = ["myst-parser", "pydata-sphinx-theme", "sphinx", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling", "trio"]
pyqt5 = ["pyqt5"]
pyside6 = ["pyside6"]
-test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio", "pytest-cov", "pytest-timeout"]
+test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio (>=0.23.5)", "pytest-cov", "pytest-timeout"]
[[package]]
name = "ipython"
@@ -1336,34 +1423,23 @@ qtconsole = ["qtconsole"]
test = ["pytest (<7.1)", "pytest-asyncio", "testpath"]
test-extra = ["curio", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.21)", "pandas", "pytest (<7.1)", "pytest-asyncio", "testpath", "trio"]
-[[package]]
-name = "ipython-genutils"
-version = "0.2.0"
-description = "Vestigial utilities from IPython"
-optional = false
-python-versions = "*"
-files = [
- {file = "ipython_genutils-0.2.0-py2.py3-none-any.whl", hash = "sha256:72dd37233799e619666c9f639a9da83c34013a73e8bbc79a7a6348d93c61fab8"},
- {file = "ipython_genutils-0.2.0.tar.gz", hash = "sha256:eb2e116e75ecef9d4d228fdc66af54269afa26ab4463042e33785b887c628ba8"},
-]
-
[[package]]
name = "ipywidgets"
-version = "8.1.1"
+version = "8.1.2"
description = "Jupyter interactive widgets"
optional = false
python-versions = ">=3.7"
files = [
- {file = "ipywidgets-8.1.1-py3-none-any.whl", hash = "sha256:2b88d728656aea3bbfd05d32c747cfd0078f9d7e159cf982433b58ad717eed7f"},
- {file = "ipywidgets-8.1.1.tar.gz", hash = "sha256:40211efb556adec6fa450ccc2a77d59ca44a060f4f9f136833df59c9f538e6e8"},
+ {file = "ipywidgets-8.1.2-py3-none-any.whl", hash = "sha256:bbe43850d79fb5e906b14801d6c01402857996864d1e5b6fa62dd2ee35559f60"},
+ {file = "ipywidgets-8.1.2.tar.gz", hash = "sha256:d0b9b41e49bae926a866e613a39b0f0097745d2b9f1f3dd406641b4a57ec42c9"},
]
[package.dependencies]
comm = ">=0.1.3"
ipython = ">=6.1.0"
-jupyterlab-widgets = ">=3.0.9,<3.1.0"
+jupyterlab-widgets = ">=3.0.10,<3.1.0"
traitlets = ">=4.3.1"
-widgetsnbextension = ">=4.0.9,<4.1.0"
+widgetsnbextension = ">=4.0.10,<4.1.0"
[package.extras]
test = ["ipykernel", "jsonschema", "pytest (>=3.6.0)", "pytest-cov", "pytz"]
@@ -1435,13 +1511,13 @@ testing = ["Django", "attrs", "colorama", "docopt", "pytest (<7.0.0)"]
[[package]]
name = "jinja2"
-version = "3.1.2"
+version = "3.1.3"
description = "A very fast and expressive template engine."
optional = false
python-versions = ">=3.7"
files = [
- {file = "Jinja2-3.1.2-py3-none-any.whl", hash = "sha256:6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61"},
- {file = "Jinja2-3.1.2.tar.gz", hash = "sha256:31351a702a408a9e7595a8fc6150fc3f43bb6bf7e319770cbc0db9df9437e852"},
+ {file = "Jinja2-3.1.3-py3-none-any.whl", hash = "sha256:7d6d50dd97d52cbc355597bd845fabfbac3f551e1f99619e39a35ce8c370b5fa"},
+ {file = "Jinja2-3.1.3.tar.gz", hash = "sha256:ac8bd6544d4bb2c9792bf3a159e80bba8fda7f07e81bc3aed565432d5925ba90"},
]
[package.dependencies]
@@ -1452,18 +1528,15 @@ i18n = ["Babel (>=2.7)"]
[[package]]
name = "json5"
-version = "0.9.14"
+version = "0.9.24"
description = "A Python implementation of the JSON5 data format."
optional = false
-python-versions = "*"
+python-versions = ">=3.8"
files = [
- {file = "json5-0.9.14-py2.py3-none-any.whl", hash = "sha256:740c7f1b9e584a468dbb2939d8d458db3427f2c93ae2139d05f47e453eae964f"},
- {file = "json5-0.9.14.tar.gz", hash = "sha256:9ed66c3a6ca3510a976a9ef9b8c0787de24802724ab1860bc0153c7fdd589b02"},
+ {file = "json5-0.9.24-py3-none-any.whl", hash = "sha256:4ca101fd5c7cb47960c055ef8f4d0e31e15a7c6c48c3b6f1473fc83b6c462a13"},
+ {file = "json5-0.9.24.tar.gz", hash = "sha256:0c638399421da959a20952782800e5c1a78c14e08e1dc9738fa10d8ec14d58c8"},
]
-[package.extras]
-dev = ["hypothesis"]
-
[[package]]
name = "jsonpointer"
version = "2.4"
@@ -1477,13 +1550,13 @@ files = [
[[package]]
name = "jsonschema"
-version = "4.19.1"
+version = "4.21.1"
description = "An implementation of JSON Schema validation for Python"
optional = false
python-versions = ">=3.8"
files = [
- {file = "jsonschema-4.19.1-py3-none-any.whl", hash = "sha256:cd5f1f9ed9444e554b38ba003af06c0a8c2868131e56bfbef0550fb450c0330e"},
- {file = "jsonschema-4.19.1.tar.gz", hash = "sha256:ec84cc37cfa703ef7cd4928db24f9cb31428a5d0fa77747b8b51a847458e0bbf"},
+ {file = "jsonschema-4.21.1-py3-none-any.whl", hash = "sha256:7996507afae316306f9e2290407761157c6f78002dcf7419acb99822143d1c6f"},
+ {file = "jsonschema-4.21.1.tar.gz", hash = "sha256:85727c00279f5fa6bedbe6238d2aa6403bedd8b4864ab11207d07df3cc1b2ee5"},
]
[package.dependencies]
@@ -1508,18 +1581,18 @@ format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-
[[package]]
name = "jsonschema-specifications"
-version = "2023.7.1"
+version = "2023.12.1"
description = "The JSON Schema meta-schemas and vocabularies, exposed as a Registry"
optional = false
python-versions = ">=3.8"
files = [
- {file = "jsonschema_specifications-2023.7.1-py3-none-any.whl", hash = "sha256:05adf340b659828a004220a9613be00fa3f223f2b82002e273dee62fd50524b1"},
- {file = "jsonschema_specifications-2023.7.1.tar.gz", hash = "sha256:c91a50404e88a1f6ba40636778e2ee08f6e24c5613fe4c53ac24578a5a7f72bb"},
+ {file = "jsonschema_specifications-2023.12.1-py3-none-any.whl", hash = "sha256:87e4fdf3a94858b8a2ba2778d9ba57d8a9cafca7c7489c46ba0d30a8bc6a9c3c"},
+ {file = "jsonschema_specifications-2023.12.1.tar.gz", hash = "sha256:48a76787b3e70f5ed53f1160d2b81f586e4ca6d1548c5de7085d1682674764cc"},
]
[package.dependencies]
importlib-resources = {version = ">=1.4.0", markers = "python_version < \"3.9\""}
-referencing = ">=0.28.0"
+referencing = ">=0.31.0"
[[package]]
name = "jupyter"
@@ -1543,13 +1616,13 @@ qtconsole = "*"
[[package]]
name = "jupyter-client"
-version = "8.5.0"
+version = "8.6.1"
description = "Jupyter protocol implementation and client libraries"
optional = false
python-versions = ">=3.8"
files = [
- {file = "jupyter_client-8.5.0-py3-none-any.whl", hash = "sha256:c3877aac7257ec68d79b5c622ce986bd2a992ca42f6ddc9b4dd1da50e89f7028"},
- {file = "jupyter_client-8.5.0.tar.gz", hash = "sha256:e8754066510ce456358df363f97eae64b50860f30dc1fe8c6771440db3be9a63"},
+ {file = "jupyter_client-8.6.1-py3-none-any.whl", hash = "sha256:3b7bd22f058434e3b9a7ea4b1500ed47de2713872288c0d511d19926f99b459f"},
+ {file = "jupyter_client-8.6.1.tar.gz", hash = "sha256:e842515e2bab8e19186d89fdfea7abd15e39dd581f94e399f00e2af5a1652d3f"},
]
[package.dependencies]
@@ -1590,13 +1663,13 @@ test = ["flaky", "pexpect", "pytest"]
[[package]]
name = "jupyter-core"
-version = "5.4.0"
+version = "5.7.2"
description = "Jupyter core package. A base package on which Jupyter projects rely."
optional = false
python-versions = ">=3.8"
files = [
- {file = "jupyter_core-5.4.0-py3-none-any.whl", hash = "sha256:66e252f675ac04dcf2feb6ed4afb3cd7f68cf92f483607522dc251f32d471571"},
- {file = "jupyter_core-5.4.0.tar.gz", hash = "sha256:e4b98344bb94ee2e3e6c4519a97d001656009f9cb2b7f2baf15b3c205770011d"},
+ {file = "jupyter_core-5.7.2-py3-none-any.whl", hash = "sha256:4f7315d2f6b4bcf2e3e7cb6e46772eba760ae459cd1f59d29eb57b0a01bd7409"},
+ {file = "jupyter_core-5.7.2.tar.gz", hash = "sha256:aa5f8d32bbf6b431ac830496da7392035d6f61b4f54872f15c4bd2a9c3f536d9"},
]
[package.dependencies]
@@ -1605,18 +1678,18 @@ pywin32 = {version = ">=300", markers = "sys_platform == \"win32\" and platform_
traitlets = ">=5.3"
[package.extras]
-docs = ["myst-parser", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling", "traitlets"]
-test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"]
+docs = ["myst-parser", "pydata-sphinx-theme", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling", "traitlets"]
+test = ["ipykernel", "pre-commit", "pytest (<8)", "pytest-cov", "pytest-timeout"]
[[package]]
name = "jupyter-events"
-version = "0.8.0"
+version = "0.10.0"
description = "Jupyter Event System library"
optional = false
python-versions = ">=3.8"
files = [
- {file = "jupyter_events-0.8.0-py3-none-any.whl", hash = "sha256:81f07375c7673ff298bfb9302b4a981864ec64edaed75ca0fe6f850b9b045525"},
- {file = "jupyter_events-0.8.0.tar.gz", hash = "sha256:fda08f0defce5e16930542ce60634ba48e010830d50073c3dfd235759cee77bf"},
+ {file = "jupyter_events-0.10.0-py3-none-any.whl", hash = "sha256:4b72130875e59d57716d327ea70d3ebc3af1944d3717e5a498b8a06c6c159960"},
+ {file = "jupyter_events-0.10.0.tar.gz", hash = "sha256:670b8229d3cc882ec782144ed22e0d29e1c2d639263f92ca8383e66682845e22"},
]
[package.dependencies]
@@ -1635,13 +1708,13 @@ test = ["click", "pre-commit", "pytest (>=7.0)", "pytest-asyncio (>=0.19.0)", "p
[[package]]
name = "jupyter-lsp"
-version = "2.2.0"
+version = "2.2.4"
description = "Multi-Language Server WebSocket proxy for Jupyter Notebook/Lab server"
optional = false
python-versions = ">=3.8"
files = [
- {file = "jupyter-lsp-2.2.0.tar.gz", hash = "sha256:8ebbcb533adb41e5d635eb8fe82956b0aafbf0fd443b6c4bfa906edeeb8635a1"},
- {file = "jupyter_lsp-2.2.0-py3-none-any.whl", hash = "sha256:9e06b8b4f7dd50300b70dd1a78c0c3b0c3d8fa68e0f2d8a5d1fbab62072aca3f"},
+ {file = "jupyter-lsp-2.2.4.tar.gz", hash = "sha256:5e50033149344065348e688608f3c6d654ef06d9856b67655bd7b6bac9ee2d59"},
+ {file = "jupyter_lsp-2.2.4-py3-none-any.whl", hash = "sha256:da61cb63a16b6dff5eac55c2699cc36eac975645adee02c41bdfc03bf4802e77"},
]
[package.dependencies]
@@ -1650,13 +1723,13 @@ jupyter-server = ">=1.1.2"
[[package]]
name = "jupyter-server"
-version = "2.9.1"
+version = "2.13.0"
description = "The backendāi.e. core services, APIs, and REST endpointsāto Jupyter web applications."
optional = false
python-versions = ">=3.8"
files = [
- {file = "jupyter_server-2.9.1-py3-none-any.whl", hash = "sha256:21ad1a3d455d5a79ce4bef5201925cd17510c17898cf9d54e3ccfb6b12734948"},
- {file = "jupyter_server-2.9.1.tar.gz", hash = "sha256:9ba71be4b9c16e479e4c50c929f8ac4b1015baf90237a08681397a98c76c7e5e"},
+ {file = "jupyter_server-2.13.0-py3-none-any.whl", hash = "sha256:77b2b49c3831fbbfbdb5048cef4350d12946191f833a24e5f83e5f8f4803e97b"},
+ {file = "jupyter_server-2.13.0.tar.gz", hash = "sha256:c80bfb049ea20053c3d9641c2add4848b38073bf79f1729cea1faed32fc1c78e"},
]
[package.dependencies]
@@ -1665,7 +1738,7 @@ argon2-cffi = "*"
jinja2 = "*"
jupyter-client = ">=7.4.4"
jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0"
-jupyter-events = ">=0.6.0"
+jupyter-events = ">=0.9.0"
jupyter-server-terminals = "*"
nbconvert = ">=6.4.4"
nbformat = ">=5.3.0"
@@ -1682,17 +1755,17 @@ websocket-client = "*"
[package.extras]
docs = ["ipykernel", "jinja2", "jupyter-client", "jupyter-server", "myst-parser", "nbformat", "prometheus-client", "pydata-sphinx-theme", "send2trash", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-openapi (>=0.8.0)", "sphinxcontrib-spelling", "sphinxemoji", "tornado", "typing-extensions"]
-test = ["flaky", "ipykernel", "pre-commit", "pytest (>=7.0)", "pytest-console-scripts", "pytest-jupyter[server] (>=0.4)", "pytest-timeout", "requests"]
+test = ["flaky", "ipykernel", "pre-commit", "pytest (>=7.0)", "pytest-console-scripts", "pytest-jupyter[server] (>=0.7)", "pytest-timeout", "requests"]
[[package]]
name = "jupyter-server-terminals"
-version = "0.4.4"
+version = "0.5.3"
description = "A Jupyter Server Extension Providing Terminals."
optional = false
python-versions = ">=3.8"
files = [
- {file = "jupyter_server_terminals-0.4.4-py3-none-any.whl", hash = "sha256:75779164661cec02a8758a5311e18bb8eb70c4e86c6b699403100f1585a12a36"},
- {file = "jupyter_server_terminals-0.4.4.tar.gz", hash = "sha256:57ab779797c25a7ba68e97bcfb5d7740f2b5e8a83b5e8102b10438041a7eac5d"},
+ {file = "jupyter_server_terminals-0.5.3-py3-none-any.whl", hash = "sha256:41ee0d7dc0ebf2809c668e0fc726dfaf258fcd3e769568996ca731b6194ae9aa"},
+ {file = "jupyter_server_terminals-0.5.3.tar.gz", hash = "sha256:5ae0295167220e9ace0edcfdb212afd2b01ee8d179fe6f23c899590e9b8a5269"},
]
[package.dependencies]
@@ -1700,22 +1773,23 @@ pywinpty = {version = ">=2.0.3", markers = "os_name == \"nt\""}
terminado = ">=0.8.3"
[package.extras]
-docs = ["jinja2", "jupyter-server", "mistune (<3.0)", "myst-parser", "nbformat", "packaging", "pydata-sphinx-theme", "sphinxcontrib-github-alt", "sphinxcontrib-openapi", "sphinxcontrib-spelling", "sphinxemoji", "tornado"]
-test = ["coverage", "jupyter-server (>=2.0.0)", "pytest (>=7.0)", "pytest-cov", "pytest-jupyter[server] (>=0.5.3)", "pytest-timeout"]
+docs = ["jinja2", "jupyter-server", "mistune (<4.0)", "myst-parser", "nbformat", "packaging", "pydata-sphinx-theme", "sphinxcontrib-github-alt", "sphinxcontrib-openapi", "sphinxcontrib-spelling", "sphinxemoji", "tornado"]
+test = ["jupyter-server (>=2.0.0)", "pytest (>=7.0)", "pytest-jupyter[server] (>=0.5.3)", "pytest-timeout"]
[[package]]
name = "jupyterlab"
-version = "4.0.7"
+version = "4.1.5"
description = "JupyterLab computational environment"
optional = false
python-versions = ">=3.8"
files = [
- {file = "jupyterlab-4.0.7-py3-none-any.whl", hash = "sha256:08683045117cc495531fdb39c22ababb9aaac6977a45e67cfad20046564c9c7c"},
- {file = "jupyterlab-4.0.7.tar.gz", hash = "sha256:48792efd9f962b2bcda1f87d72168ff122c288b1d97d32109e4a11b33dc862be"},
+ {file = "jupyterlab-4.1.5-py3-none-any.whl", hash = "sha256:3bc843382a25e1ab7bc31d9e39295a9f0463626692b7995597709c0ab236ab2c"},
+ {file = "jupyterlab-4.1.5.tar.gz", hash = "sha256:c9ad75290cb10bfaff3624bf3fbb852319b4cce4c456613f8ebbaa98d03524db"},
]
[package.dependencies]
async-lru = ">=1.0.0"
+httpx = ">=0.25.0"
importlib-metadata = {version = ">=4.8.3", markers = "python_version < \"3.10\""}
importlib-resources = {version = ">=1.4", markers = "python_version < \"3.9\""}
ipykernel = "*"
@@ -1731,31 +1805,31 @@ tornado = ">=6.2.0"
traitlets = "*"
[package.extras]
-dev = ["black[jupyter] (==23.7.0)", "build", "bump2version", "coverage", "hatch", "pre-commit", "pytest-cov", "ruff (==0.0.286)"]
-docs = ["jsx-lexer", "myst-parser", "pydata-sphinx-theme (>=0.13.0)", "pytest", "pytest-check-links", "pytest-tornasync", "sphinx (>=1.8,<7.2.0)", "sphinx-copybutton"]
-docs-screenshots = ["altair (==5.0.1)", "ipython (==8.14.0)", "ipywidgets (==8.0.6)", "jupyterlab-geojson (==3.4.0)", "jupyterlab-language-pack-zh-cn (==4.0.post0)", "matplotlib (==3.7.1)", "nbconvert (>=7.0.0)", "pandas (==2.0.2)", "scipy (==1.10.1)", "vega-datasets (==0.9.0)"]
+dev = ["build", "bump2version", "coverage", "hatch", "pre-commit", "pytest-cov", "ruff (==0.2.0)"]
+docs = ["jsx-lexer", "myst-parser", "pydata-sphinx-theme (>=0.13.0)", "pytest", "pytest-check-links", "pytest-jupyter", "sphinx (>=1.8,<7.3.0)", "sphinx-copybutton"]
+docs-screenshots = ["altair (==5.2.0)", "ipython (==8.16.1)", "ipywidgets (==8.1.1)", "jupyterlab-geojson (==3.4.0)", "jupyterlab-language-pack-zh-cn (==4.0.post6)", "matplotlib (==3.8.2)", "nbconvert (>=7.0.0)", "pandas (==2.2.0)", "scipy (==1.12.0)", "vega-datasets (==0.9.0)"]
test = ["coverage", "pytest (>=7.0)", "pytest-check-links (>=0.7)", "pytest-console-scripts", "pytest-cov", "pytest-jupyter (>=0.5.3)", "pytest-timeout", "pytest-tornasync", "requests", "requests-cache", "virtualenv"]
[[package]]
name = "jupyterlab-pygments"
-version = "0.2.2"
+version = "0.3.0"
description = "Pygments theme using JupyterLab CSS variables"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
files = [
- {file = "jupyterlab_pygments-0.2.2-py2.py3-none-any.whl", hash = "sha256:2405800db07c9f770863bcf8049a529c3dd4d3e28536638bd7c1c01d2748309f"},
- {file = "jupyterlab_pygments-0.2.2.tar.gz", hash = "sha256:7405d7fde60819d905a9fa8ce89e4cd830e318cdad22a0030f7a901da705585d"},
+ {file = "jupyterlab_pygments-0.3.0-py3-none-any.whl", hash = "sha256:841a89020971da1d8693f1a99997aefc5dc424bb1b251fd6322462a1b8842780"},
+ {file = "jupyterlab_pygments-0.3.0.tar.gz", hash = "sha256:721aca4d9029252b11cfa9d185e5b5af4d54772bb8072f9b7036f4170054d35d"},
]
[[package]]
name = "jupyterlab-server"
-version = "2.25.0"
+version = "2.25.4"
description = "A set of server components for JupyterLab and JupyterLab like applications."
optional = false
python-versions = ">=3.8"
files = [
- {file = "jupyterlab_server-2.25.0-py3-none-any.whl", hash = "sha256:c9f67a98b295c5dee87f41551b0558374e45d449f3edca153dd722140630dcb2"},
- {file = "jupyterlab_server-2.25.0.tar.gz", hash = "sha256:77c2f1f282d610f95e496e20d5bf1d2a7706826dfb7b18f3378ae2870d272fb7"},
+ {file = "jupyterlab_server-2.25.4-py3-none-any.whl", hash = "sha256:eb645ecc8f9b24bac5decc7803b6d5363250e16ec5af814e516bc2c54dd88081"},
+ {file = "jupyterlab_server-2.25.4.tar.gz", hash = "sha256:2098198e1e82e0db982440f9b5136175d73bea2cd42a6480aa6fd502cb23c4f9"},
]
[package.dependencies]
@@ -1771,17 +1845,17 @@ requests = ">=2.31"
[package.extras]
docs = ["autodoc-traits", "jinja2 (<3.2.0)", "mistune (<4)", "myst-parser", "pydata-sphinx-theme", "sphinx", "sphinx-copybutton", "sphinxcontrib-openapi (>0.8)"]
openapi = ["openapi-core (>=0.18.0,<0.19.0)", "ruamel-yaml"]
-test = ["hatch", "ipykernel", "openapi-core (>=0.18.0,<0.19.0)", "openapi-spec-validator (>=0.6.0,<0.7.0)", "pytest (>=7.0)", "pytest-console-scripts", "pytest-cov", "pytest-jupyter[server] (>=0.6.2)", "pytest-timeout", "requests-mock", "ruamel-yaml", "sphinxcontrib-spelling", "strict-rfc3339", "werkzeug"]
+test = ["hatch", "ipykernel", "openapi-core (>=0.18.0,<0.19.0)", "openapi-spec-validator (>=0.6.0,<0.8.0)", "pytest (>=7.0,<8)", "pytest-console-scripts", "pytest-cov", "pytest-jupyter[server] (>=0.6.2)", "pytest-timeout", "requests-mock", "ruamel-yaml", "sphinxcontrib-spelling", "strict-rfc3339", "werkzeug"]
[[package]]
name = "jupyterlab-widgets"
-version = "3.0.9"
+version = "3.0.10"
description = "Jupyter interactive widgets for JupyterLab"
optional = false
python-versions = ">=3.7"
files = [
- {file = "jupyterlab_widgets-3.0.9-py3-none-any.whl", hash = "sha256:3cf5bdf5b897bf3bccf1c11873aa4afd776d7430200f765e0686bd352487b58d"},
- {file = "jupyterlab_widgets-3.0.9.tar.gz", hash = "sha256:6005a4e974c7beee84060fdfba341a3218495046de8ae3ec64888e5fe19fdb4c"},
+ {file = "jupyterlab_widgets-3.0.10-py3-none-any.whl", hash = "sha256:dd61f3ae7a5a7f80299e14585ce6cf3d6925a96c9103c978eda293197730cb64"},
+ {file = "jupyterlab_widgets-3.0.10.tar.gz", hash = "sha256:04f2ac04976727e4f9d0fa91cdc2f1ab860f965e504c29dbd6a65c882c9d04c0"},
]
[[package]]
@@ -1873,61 +1947,71 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
[[package]]
name = "markupsafe"
-version = "2.1.3"
+version = "2.1.5"
description = "Safely add untrusted strings to HTML/XML markup."
optional = false
python-versions = ">=3.7"
files = [
- {file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cd0f502fe016460680cd20aaa5a76d241d6f35a1c3350c474bac1273803893fa"},
- {file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e09031c87a1e51556fdcb46e5bd4f59dfb743061cf93c4d6831bf894f125eb57"},
- {file = "MarkupSafe-2.1.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:68e78619a61ecf91e76aa3e6e8e33fc4894a2bebe93410754bd28fce0a8a4f9f"},
- {file = "MarkupSafe-2.1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65c1a9bcdadc6c28eecee2c119465aebff8f7a584dd719facdd9e825ec61ab52"},
- {file = "MarkupSafe-2.1.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:525808b8019e36eb524b8c68acdd63a37e75714eac50e988180b169d64480a00"},
- {file = "MarkupSafe-2.1.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:962f82a3086483f5e5f64dbad880d31038b698494799b097bc59c2edf392fce6"},
- {file = "MarkupSafe-2.1.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:aa7bd130efab1c280bed0f45501b7c8795f9fdbeb02e965371bbef3523627779"},
- {file = "MarkupSafe-2.1.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c9c804664ebe8f83a211cace637506669e7890fec1b4195b505c214e50dd4eb7"},
- {file = "MarkupSafe-2.1.3-cp310-cp310-win32.whl", hash = "sha256:10bbfe99883db80bdbaff2dcf681dfc6533a614f700da1287707e8a5d78a8431"},
- {file = "MarkupSafe-2.1.3-cp310-cp310-win_amd64.whl", hash = "sha256:1577735524cdad32f9f694208aa75e422adba74f1baee7551620e43a3141f559"},
- {file = "MarkupSafe-2.1.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ad9e82fb8f09ade1c3e1b996a6337afac2b8b9e365f926f5a61aacc71adc5b3c"},
- {file = "MarkupSafe-2.1.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3c0fae6c3be832a0a0473ac912810b2877c8cb9d76ca48de1ed31e1c68386575"},
- {file = "MarkupSafe-2.1.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b076b6226fb84157e3f7c971a47ff3a679d837cf338547532ab866c57930dbee"},
- {file = "MarkupSafe-2.1.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bfce63a9e7834b12b87c64d6b155fdd9b3b96191b6bd334bf37db7ff1fe457f2"},
- {file = "MarkupSafe-2.1.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:338ae27d6b8745585f87218a3f23f1512dbf52c26c28e322dbe54bcede54ccb9"},
- {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e4dd52d80b8c83fdce44e12478ad2e85c64ea965e75d66dbeafb0a3e77308fcc"},
- {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:df0be2b576a7abbf737b1575f048c23fb1d769f267ec4358296f31c2479db8f9"},
- {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"},
- {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"},
- {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"},
- {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"},
- {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"},
- {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"},
- {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca379055a47383d02a5400cb0d110cef0a776fc644cda797db0c5696cfd7e18e"},
- {file = "MarkupSafe-2.1.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:b7ff0f54cb4ff66dd38bebd335a38e2c22c41a8ee45aa608efc890ac3e3931bc"},
- {file = "MarkupSafe-2.1.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c011a4149cfbcf9f03994ec2edffcb8b1dc2d2aede7ca243746df97a5d41ce48"},
- {file = "MarkupSafe-2.1.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:56d9f2ecac662ca1611d183feb03a3fa4406469dafe241673d521dd5ae92a155"},
- {file = "MarkupSafe-2.1.3-cp37-cp37m-win32.whl", hash = "sha256:8758846a7e80910096950b67071243da3e5a20ed2546e6392603c096778d48e0"},
- {file = "MarkupSafe-2.1.3-cp37-cp37m-win_amd64.whl", hash = "sha256:787003c0ddb00500e49a10f2844fac87aa6ce977b90b0feaaf9de23c22508b24"},
- {file = "MarkupSafe-2.1.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:2ef12179d3a291be237280175b542c07a36e7f60718296278d8593d21ca937d4"},
- {file = "MarkupSafe-2.1.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2c1b19b3aaacc6e57b7e25710ff571c24d6c3613a45e905b1fde04d691b98ee0"},
- {file = "MarkupSafe-2.1.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8afafd99945ead6e075b973fefa56379c5b5c53fd8937dad92c662da5d8fd5ee"},
- {file = "MarkupSafe-2.1.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c41976a29d078bb235fea9b2ecd3da465df42a562910f9022f1a03107bd02be"},
- {file = "MarkupSafe-2.1.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d080e0a5eb2529460b30190fcfcc4199bd7f827663f858a226a81bc27beaa97e"},
- {file = "MarkupSafe-2.1.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:69c0f17e9f5a7afdf2cc9fb2d1ce6aabdb3bafb7f38017c0b77862bcec2bbad8"},
- {file = "MarkupSafe-2.1.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:504b320cd4b7eff6f968eddf81127112db685e81f7e36e75f9f84f0df46041c3"},
- {file = "MarkupSafe-2.1.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:42de32b22b6b804f42c5d98be4f7e5e977ecdd9ee9b660fda1a3edf03b11792d"},
- {file = "MarkupSafe-2.1.3-cp38-cp38-win32.whl", hash = "sha256:ceb01949af7121f9fc39f7d27f91be8546f3fb112c608bc4029aef0bab86a2a5"},
- {file = "MarkupSafe-2.1.3-cp38-cp38-win_amd64.whl", hash = "sha256:1b40069d487e7edb2676d3fbdb2b0829ffa2cd63a2ec26c4938b2d34391b4ecc"},
- {file = "MarkupSafe-2.1.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:8023faf4e01efadfa183e863fefde0046de576c6f14659e8782065bcece22198"},
- {file = "MarkupSafe-2.1.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6b2b56950d93e41f33b4223ead100ea0fe11f8e6ee5f641eb753ce4b77a7042b"},
- {file = "MarkupSafe-2.1.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9dcdfd0eaf283af041973bff14a2e143b8bd64e069f4c383416ecd79a81aab58"},
- {file = "MarkupSafe-2.1.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:05fb21170423db021895e1ea1e1f3ab3adb85d1c2333cbc2310f2a26bc77272e"},
- {file = "MarkupSafe-2.1.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:282c2cb35b5b673bbcadb33a585408104df04f14b2d9b01d4c345a3b92861c2c"},
- {file = "MarkupSafe-2.1.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ab4a0df41e7c16a1392727727e7998a467472d0ad65f3ad5e6e765015df08636"},
- {file = "MarkupSafe-2.1.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7ef3cb2ebbf91e330e3bb937efada0edd9003683db6b57bb108c4001f37a02ea"},
- {file = "MarkupSafe-2.1.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:0a4e4a1aff6c7ac4cd55792abf96c915634c2b97e3cc1c7129578aa68ebd754e"},
- {file = "MarkupSafe-2.1.3-cp39-cp39-win32.whl", hash = "sha256:fec21693218efe39aa7f8599346e90c705afa52c5b31ae019b2e57e8f6542bb2"},
- {file = "MarkupSafe-2.1.3-cp39-cp39-win_amd64.whl", hash = "sha256:3fd4abcb888d15a94f32b75d8fd18ee162ca0c064f35b11134be77050296d6ba"},
- {file = "MarkupSafe-2.1.3.tar.gz", hash = "sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad"},
+ {file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc"},
+ {file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:72b6be590cc35924b02c78ef34b467da4ba07e4e0f0454a2c5907f473fc50ce5"},
+ {file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e61659ba32cf2cf1481e575d0462554625196a1f2fc06a1c777d3f48e8865d46"},
+ {file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2174c595a0d73a3080ca3257b40096db99799265e1c27cc5a610743acd86d62f"},
+ {file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ae2ad8ae6ebee9d2d94b17fb62763125f3f374c25618198f40cbb8b525411900"},
+ {file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:075202fa5b72c86ad32dc7d0b56024ebdbcf2048c0ba09f1cde31bfdd57bcfff"},
+ {file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:598e3276b64aff0e7b3451b72e94fa3c238d452e7ddcd893c3ab324717456bad"},
+ {file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fce659a462a1be54d2ffcacea5e3ba2d74daa74f30f5f143fe0c58636e355fdd"},
+ {file = "MarkupSafe-2.1.5-cp310-cp310-win32.whl", hash = "sha256:d9fad5155d72433c921b782e58892377c44bd6252b5af2f67f16b194987338a4"},
+ {file = "MarkupSafe-2.1.5-cp310-cp310-win_amd64.whl", hash = "sha256:bf50cd79a75d181c9181df03572cdce0fbb75cc353bc350712073108cba98de5"},
+ {file = "MarkupSafe-2.1.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:629ddd2ca402ae6dbedfceeba9c46d5f7b2a61d9749597d4307f943ef198fc1f"},
+ {file = "MarkupSafe-2.1.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5b7b716f97b52c5a14bffdf688f971b2d5ef4029127f1ad7a513973cfd818df2"},
+ {file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ec585f69cec0aa07d945b20805be741395e28ac1627333b1c5b0105962ffced"},
+ {file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b91c037585eba9095565a3556f611e3cbfaa42ca1e865f7b8015fe5c7336d5a5"},
+ {file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7502934a33b54030eaf1194c21c692a534196063db72176b0c4028e140f8f32c"},
+ {file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0e397ac966fdf721b2c528cf028494e86172b4feba51d65f81ffd65c63798f3f"},
+ {file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:c061bb86a71b42465156a3ee7bd58c8c2ceacdbeb95d05a99893e08b8467359a"},
+ {file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3a57fdd7ce31c7ff06cdfbf31dafa96cc533c21e443d57f5b1ecc6cdc668ec7f"},
+ {file = "MarkupSafe-2.1.5-cp311-cp311-win32.whl", hash = "sha256:397081c1a0bfb5124355710fe79478cdbeb39626492b15d399526ae53422b906"},
+ {file = "MarkupSafe-2.1.5-cp311-cp311-win_amd64.whl", hash = "sha256:2b7c57a4dfc4f16f7142221afe5ba4e093e09e728ca65c51f5620c9aaeb9a617"},
+ {file = "MarkupSafe-2.1.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:8dec4936e9c3100156f8a2dc89c4b88d5c435175ff03413b443469c7c8c5f4d1"},
+ {file = "MarkupSafe-2.1.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:3c6b973f22eb18a789b1460b4b91bf04ae3f0c4234a0a6aa6b0a92f6f7b951d4"},
+ {file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac07bad82163452a6884fe8fa0963fb98c2346ba78d779ec06bd7a6262132aee"},
+ {file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5dfb42c4604dddc8e4305050aa6deb084540643ed5804d7455b5df8fe16f5e5"},
+ {file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ea3d8a3d18833cf4304cd2fc9cbb1efe188ca9b5efef2bdac7adc20594a0e46b"},
+ {file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d050b3361367a06d752db6ead6e7edeb0009be66bc3bae0ee9d97fb326badc2a"},
+ {file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bec0a414d016ac1a18862a519e54b2fd0fc8bbfd6890376898a6c0891dd82e9f"},
+ {file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:58c98fee265677f63a4385256a6d7683ab1832f3ddd1e66fe948d5880c21a169"},
+ {file = "MarkupSafe-2.1.5-cp312-cp312-win32.whl", hash = "sha256:8590b4ae07a35970728874632fed7bd57b26b0102df2d2b233b6d9d82f6c62ad"},
+ {file = "MarkupSafe-2.1.5-cp312-cp312-win_amd64.whl", hash = "sha256:823b65d8706e32ad2df51ed89496147a42a2a6e01c13cfb6ffb8b1e92bc910bb"},
+ {file = "MarkupSafe-2.1.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c8b29db45f8fe46ad280a7294f5c3ec36dbac9491f2d1c17345be8e69cc5928f"},
+ {file = "MarkupSafe-2.1.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec6a563cff360b50eed26f13adc43e61bc0c04d94b8be985e6fb24b81f6dcfdf"},
+ {file = "MarkupSafe-2.1.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a549b9c31bec33820e885335b451286e2969a2d9e24879f83fe904a5ce59d70a"},
+ {file = "MarkupSafe-2.1.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4f11aa001c540f62c6166c7726f71f7573b52c68c31f014c25cc7901deea0b52"},
+ {file = "MarkupSafe-2.1.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:7b2e5a267c855eea6b4283940daa6e88a285f5f2a67f2220203786dfa59b37e9"},
+ {file = "MarkupSafe-2.1.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:2d2d793e36e230fd32babe143b04cec8a8b3eb8a3122d2aceb4a371e6b09b8df"},
+ {file = "MarkupSafe-2.1.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:ce409136744f6521e39fd8e2a24c53fa18ad67aa5bc7c2cf83645cce5b5c4e50"},
+ {file = "MarkupSafe-2.1.5-cp37-cp37m-win32.whl", hash = "sha256:4096e9de5c6fdf43fb4f04c26fb114f61ef0bf2e5604b6ee3019d51b69e8c371"},
+ {file = "MarkupSafe-2.1.5-cp37-cp37m-win_amd64.whl", hash = "sha256:4275d846e41ecefa46e2015117a9f491e57a71ddd59bbead77e904dc02b1bed2"},
+ {file = "MarkupSafe-2.1.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:656f7526c69fac7f600bd1f400991cc282b417d17539a1b228617081106feb4a"},
+ {file = "MarkupSafe-2.1.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:97cafb1f3cbcd3fd2b6fbfb99ae11cdb14deea0736fc2b0952ee177f2b813a46"},
+ {file = "MarkupSafe-2.1.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f3fbcb7ef1f16e48246f704ab79d79da8a46891e2da03f8783a5b6fa41a9532"},
+ {file = "MarkupSafe-2.1.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa9db3f79de01457b03d4f01b34cf91bc0048eb2c3846ff26f66687c2f6d16ab"},
+ {file = "MarkupSafe-2.1.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ffee1f21e5ef0d712f9033568f8344d5da8cc2869dbd08d87c84656e6a2d2f68"},
+ {file = "MarkupSafe-2.1.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5dedb4db619ba5a2787a94d877bc8ffc0566f92a01c0ef214865e54ecc9ee5e0"},
+ {file = "MarkupSafe-2.1.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:30b600cf0a7ac9234b2638fbc0fb6158ba5bdcdf46aeb631ead21248b9affbc4"},
+ {file = "MarkupSafe-2.1.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8dd717634f5a044f860435c1d8c16a270ddf0ef8588d4887037c5028b859b0c3"},
+ {file = "MarkupSafe-2.1.5-cp38-cp38-win32.whl", hash = "sha256:daa4ee5a243f0f20d528d939d06670a298dd39b1ad5f8a72a4275124a7819eff"},
+ {file = "MarkupSafe-2.1.5-cp38-cp38-win_amd64.whl", hash = "sha256:619bc166c4f2de5caa5a633b8b7326fbe98e0ccbfacabd87268a2b15ff73a029"},
+ {file = "MarkupSafe-2.1.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7a68b554d356a91cce1236aa7682dc01df0edba8d043fd1ce607c49dd3c1edcf"},
+ {file = "MarkupSafe-2.1.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:db0b55e0f3cc0be60c1f19efdde9a637c32740486004f20d1cff53c3c0ece4d2"},
+ {file = "MarkupSafe-2.1.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e53af139f8579a6d5f7b76549125f0d94d7e630761a2111bc431fd820e163b8"},
+ {file = "MarkupSafe-2.1.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:17b950fccb810b3293638215058e432159d2b71005c74371d784862b7e4683f3"},
+ {file = "MarkupSafe-2.1.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4c31f53cdae6ecfa91a77820e8b151dba54ab528ba65dfd235c80b086d68a465"},
+ {file = "MarkupSafe-2.1.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:bff1b4290a66b490a2f4719358c0cdcd9bafb6b8f061e45c7a2460866bf50c2e"},
+ {file = "MarkupSafe-2.1.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:bc1667f8b83f48511b94671e0e441401371dfd0f0a795c7daa4a3cd1dde55bea"},
+ {file = "MarkupSafe-2.1.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5049256f536511ee3f7e1b3f87d1d1209d327e818e6ae1365e8653d7e3abb6a6"},
+ {file = "MarkupSafe-2.1.5-cp39-cp39-win32.whl", hash = "sha256:00e046b6dd71aa03a41079792f8473dc494d564611a8f89bbbd7cb93295ebdcf"},
+ {file = "MarkupSafe-2.1.5-cp39-cp39-win_amd64.whl", hash = "sha256:fa173ec60341d6bb97a89f5ea19c85c5643c1e7dedebc22f5181eb73573142c5"},
+ {file = "MarkupSafe-2.1.5.tar.gz", hash = "sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b"},
]
[[package]]
@@ -2004,149 +2088,161 @@ tests = ["pytest (>=4.6)"]
[[package]]
name = "multidict"
-version = "6.0.4"
+version = "6.0.5"
description = "multidict implementation"
optional = false
python-versions = ">=3.7"
files = [
- {file = "multidict-6.0.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b1a97283e0c85772d613878028fec909f003993e1007eafa715b24b377cb9b8"},
- {file = "multidict-6.0.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:eeb6dcc05e911516ae3d1f207d4b0520d07f54484c49dfc294d6e7d63b734171"},
- {file = "multidict-6.0.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d6d635d5209b82a3492508cf5b365f3446afb65ae7ebd755e70e18f287b0adf7"},
- {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c048099e4c9e9d615545e2001d3d8a4380bd403e1a0578734e0d31703d1b0c0b"},
- {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ea20853c6dbbb53ed34cb4d080382169b6f4554d394015f1bef35e881bf83547"},
- {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:16d232d4e5396c2efbbf4f6d4df89bfa905eb0d4dc5b3549d872ab898451f569"},
- {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36c63aaa167f6c6b04ef2c85704e93af16c11d20de1d133e39de6a0e84582a93"},
- {file = "multidict-6.0.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:64bdf1086b6043bf519869678f5f2757f473dee970d7abf6da91ec00acb9cb98"},
- {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:43644e38f42e3af682690876cff722d301ac585c5b9e1eacc013b7a3f7b696a0"},
- {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:7582a1d1030e15422262de9f58711774e02fa80df0d1578995c76214f6954988"},
- {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:ddff9c4e225a63a5afab9dd15590432c22e8057e1a9a13d28ed128ecf047bbdc"},
- {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:ee2a1ece51b9b9e7752e742cfb661d2a29e7bcdba2d27e66e28a99f1890e4fa0"},
- {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a2e4369eb3d47d2034032a26c7a80fcb21a2cb22e1173d761a162f11e562caa5"},
- {file = "multidict-6.0.4-cp310-cp310-win32.whl", hash = "sha256:574b7eae1ab267e5f8285f0fe881f17efe4b98c39a40858247720935b893bba8"},
- {file = "multidict-6.0.4-cp310-cp310-win_amd64.whl", hash = "sha256:4dcbb0906e38440fa3e325df2359ac6cb043df8e58c965bb45f4e406ecb162cc"},
- {file = "multidict-6.0.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:0dfad7a5a1e39c53ed00d2dd0c2e36aed4650936dc18fd9a1826a5ae1cad6f03"},
- {file = "multidict-6.0.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:64da238a09d6039e3bd39bb3aee9c21a5e34f28bfa5aa22518581f910ff94af3"},
- {file = "multidict-6.0.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ff959bee35038c4624250473988b24f846cbeb2c6639de3602c073f10410ceba"},
- {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:01a3a55bd90018c9c080fbb0b9f4891db37d148a0a18722b42f94694f8b6d4c9"},
- {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c5cb09abb18c1ea940fb99360ea0396f34d46566f157122c92dfa069d3e0e982"},
- {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:666daae833559deb2d609afa4490b85830ab0dfca811a98b70a205621a6109fe"},
- {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11bdf3f5e1518b24530b8241529d2050014c884cf18b6fc69c0c2b30ca248710"},
- {file = "multidict-6.0.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7d18748f2d30f94f498e852c67d61261c643b349b9d2a581131725595c45ec6c"},
- {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:458f37be2d9e4c95e2d8866a851663cbc76e865b78395090786f6cd9b3bbf4f4"},
- {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:b1a2eeedcead3a41694130495593a559a668f382eee0727352b9a41e1c45759a"},
- {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:7d6ae9d593ef8641544d6263c7fa6408cc90370c8cb2bbb65f8d43e5b0351d9c"},
- {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:5979b5632c3e3534e42ca6ff856bb24b2e3071b37861c2c727ce220d80eee9ed"},
- {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:dcfe792765fab89c365123c81046ad4103fcabbc4f56d1c1997e6715e8015461"},
- {file = "multidict-6.0.4-cp311-cp311-win32.whl", hash = "sha256:3601a3cece3819534b11d4efc1eb76047488fddd0c85a3948099d5da4d504636"},
- {file = "multidict-6.0.4-cp311-cp311-win_amd64.whl", hash = "sha256:81a4f0b34bd92df3da93315c6a59034df95866014ac08535fc819f043bfd51f0"},
- {file = "multidict-6.0.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:67040058f37a2a51ed8ea8f6b0e6ee5bd78ca67f169ce6122f3e2ec80dfe9b78"},
- {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:853888594621e6604c978ce2a0444a1e6e70c8d253ab65ba11657659dcc9100f"},
- {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:39ff62e7d0f26c248b15e364517a72932a611a9b75f35b45be078d81bdb86603"},
- {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:af048912e045a2dc732847d33821a9d84ba553f5c5f028adbd364dd4765092ac"},
- {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1e8b901e607795ec06c9e42530788c45ac21ef3aaa11dbd0c69de543bfb79a9"},
- {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62501642008a8b9871ddfccbf83e4222cf8ac0d5aeedf73da36153ef2ec222d2"},
- {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:99b76c052e9f1bc0721f7541e5e8c05db3941eb9ebe7b8553c625ef88d6eefde"},
- {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:509eac6cf09c794aa27bcacfd4d62c885cce62bef7b2c3e8b2e49d365b5003fe"},
- {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:21a12c4eb6ddc9952c415f24eef97e3e55ba3af61f67c7bc388dcdec1404a067"},
- {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:5cad9430ab3e2e4fa4a2ef4450f548768400a2ac635841bc2a56a2052cdbeb87"},
- {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:ab55edc2e84460694295f401215f4a58597f8f7c9466faec545093045476327d"},
- {file = "multidict-6.0.4-cp37-cp37m-win32.whl", hash = "sha256:5a4dcf02b908c3b8b17a45fb0f15b695bf117a67b76b7ad18b73cf8e92608775"},
- {file = "multidict-6.0.4-cp37-cp37m-win_amd64.whl", hash = "sha256:6ed5f161328b7df384d71b07317f4d8656434e34591f20552c7bcef27b0ab88e"},
- {file = "multidict-6.0.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5fc1b16f586f049820c5c5b17bb4ee7583092fa0d1c4e28b5239181ff9532e0c"},
- {file = "multidict-6.0.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1502e24330eb681bdaa3eb70d6358e818e8e8f908a22a1851dfd4e15bc2f8161"},
- {file = "multidict-6.0.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b692f419760c0e65d060959df05f2a531945af31fda0c8a3b3195d4efd06de11"},
- {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45e1ecb0379bfaab5eef059f50115b54571acfbe422a14f668fc8c27ba410e7e"},
- {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ddd3915998d93fbcd2566ddf9cf62cdb35c9e093075f862935573d265cf8f65d"},
- {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:59d43b61c59d82f2effb39a93c48b845efe23a3852d201ed2d24ba830d0b4cf2"},
- {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc8e1d0c705233c5dd0c5e6460fbad7827d5d36f310a0fadfd45cc3029762258"},
- {file = "multidict-6.0.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d6aa0418fcc838522256761b3415822626f866758ee0bc6632c9486b179d0b52"},
- {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6748717bb10339c4760c1e63da040f5f29f5ed6e59d76daee30305894069a660"},
- {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:4d1a3d7ef5e96b1c9e92f973e43aa5e5b96c659c9bc3124acbbd81b0b9c8a951"},
- {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:4372381634485bec7e46718edc71528024fcdc6f835baefe517b34a33c731d60"},
- {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:fc35cb4676846ef752816d5be2193a1e8367b4c1397b74a565a9d0389c433a1d"},
- {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:4b9d9e4e2b37daddb5c23ea33a3417901fa7c7b3dee2d855f63ee67a0b21e5b1"},
- {file = "multidict-6.0.4-cp38-cp38-win32.whl", hash = "sha256:e41b7e2b59679edfa309e8db64fdf22399eec4b0b24694e1b2104fb789207779"},
- {file = "multidict-6.0.4-cp38-cp38-win_amd64.whl", hash = "sha256:d6c254ba6e45d8e72739281ebc46ea5eb5f101234f3ce171f0e9f5cc86991480"},
- {file = "multidict-6.0.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:16ab77bbeb596e14212e7bab8429f24c1579234a3a462105cda4a66904998664"},
- {file = "multidict-6.0.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bc779e9e6f7fda81b3f9aa58e3a6091d49ad528b11ed19f6621408806204ad35"},
- {file = "multidict-6.0.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4ceef517eca3e03c1cceb22030a3e39cb399ac86bff4e426d4fc6ae49052cc60"},
- {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:281af09f488903fde97923c7744bb001a9b23b039a909460d0f14edc7bf59706"},
- {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:52f2dffc8acaba9a2f27174c41c9e57f60b907bb9f096b36b1a1f3be71c6284d"},
- {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b41156839806aecb3641f3208c0dafd3ac7775b9c4c422d82ee2a45c34ba81ca"},
- {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5e3fc56f88cc98ef8139255cf8cd63eb2c586531e43310ff859d6bb3a6b51f1"},
- {file = "multidict-6.0.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8316a77808c501004802f9beebde51c9f857054a0c871bd6da8280e718444449"},
- {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f70b98cd94886b49d91170ef23ec5c0e8ebb6f242d734ed7ed677b24d50c82cf"},
- {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:bf6774e60d67a9efe02b3616fee22441d86fab4c6d335f9d2051d19d90a40063"},
- {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:e69924bfcdda39b722ef4d9aa762b2dd38e4632b3641b1d9a57ca9cd18f2f83a"},
- {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:6b181d8c23da913d4ff585afd1155a0e1194c0b50c54fcfe286f70cdaf2b7176"},
- {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:52509b5be062d9eafc8170e53026fbc54cf3b32759a23d07fd935fb04fc22d95"},
- {file = "multidict-6.0.4-cp39-cp39-win32.whl", hash = "sha256:27c523fbfbdfd19c6867af7346332b62b586eed663887392cff78d614f9ec313"},
- {file = "multidict-6.0.4-cp39-cp39-win_amd64.whl", hash = "sha256:33029f5734336aa0d4c0384525da0387ef89148dc7191aae00ca5fb23d7aafc2"},
- {file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"},
+ {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:228b644ae063c10e7f324ab1ab6b548bdf6f8b47f3ec234fef1093bc2735e5f9"},
+ {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:896ebdcf62683551312c30e20614305f53125750803b614e9e6ce74a96232604"},
+ {file = "multidict-6.0.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:411bf8515f3be9813d06004cac41ccf7d1cd46dfe233705933dd163b60e37600"},
+ {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d147090048129ce3c453f0292e7697d333db95e52616b3793922945804a433c"},
+ {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:215ed703caf15f578dca76ee6f6b21b7603791ae090fbf1ef9d865571039ade5"},
+ {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c6390cf87ff6234643428991b7359b5f59cc15155695deb4eda5c777d2b880f"},
+ {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21fd81c4ebdb4f214161be351eb5bcf385426bf023041da2fd9e60681f3cebae"},
+ {file = "multidict-6.0.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3cc2ad10255f903656017363cd59436f2111443a76f996584d1077e43ee51182"},
+ {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6939c95381e003f54cd4c5516740faba40cf5ad3eeff460c3ad1d3e0ea2549bf"},
+ {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:220dd781e3f7af2c2c1053da9fa96d9cf3072ca58f057f4c5adaaa1cab8fc442"},
+ {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:766c8f7511df26d9f11cd3a8be623e59cca73d44643abab3f8c8c07620524e4a"},
+ {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:fe5d7785250541f7f5019ab9cba2c71169dc7d74d0f45253f8313f436458a4ef"},
+ {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c1c1496e73051918fcd4f58ff2e0f2f3066d1c76a0c6aeffd9b45d53243702cc"},
+ {file = "multidict-6.0.5-cp310-cp310-win32.whl", hash = "sha256:7afcdd1fc07befad18ec4523a782cde4e93e0a2bf71239894b8d61ee578c1319"},
+ {file = "multidict-6.0.5-cp310-cp310-win_amd64.whl", hash = "sha256:99f60d34c048c5c2fabc766108c103612344c46e35d4ed9ae0673d33c8fb26e8"},
+ {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f285e862d2f153a70586579c15c44656f888806ed0e5b56b64489afe4a2dbfba"},
+ {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:53689bb4e102200a4fafa9de9c7c3c212ab40a7ab2c8e474491914d2305f187e"},
+ {file = "multidict-6.0.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:612d1156111ae11d14afaf3a0669ebf6c170dbb735e510a7438ffe2369a847fd"},
+ {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7be7047bd08accdb7487737631d25735c9a04327911de89ff1b26b81745bd4e3"},
+ {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de170c7b4fe6859beb8926e84f7d7d6c693dfe8e27372ce3b76f01c46e489fcf"},
+ {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:04bde7a7b3de05732a4eb39c94574db1ec99abb56162d6c520ad26f83267de29"},
+ {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85f67aed7bb647f93e7520633d8f51d3cbc6ab96957c71272b286b2f30dc70ed"},
+ {file = "multidict-6.0.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:425bf820055005bfc8aa9a0b99ccb52cc2f4070153e34b701acc98d201693733"},
+ {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d3eb1ceec286eba8220c26f3b0096cf189aea7057b6e7b7a2e60ed36b373b77f"},
+ {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:7901c05ead4b3fb75113fb1dd33eb1253c6d3ee37ce93305acd9d38e0b5f21a4"},
+ {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e0e79d91e71b9867c73323a3444724d496c037e578a0e1755ae159ba14f4f3d1"},
+ {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:29bfeb0dff5cb5fdab2023a7a9947b3b4af63e9c47cae2a10ad58394b517fddc"},
+ {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e030047e85cbcedbfc073f71836d62dd5dadfbe7531cae27789ff66bc551bd5e"},
+ {file = "multidict-6.0.5-cp311-cp311-win32.whl", hash = "sha256:2f4848aa3baa109e6ab81fe2006c77ed4d3cd1e0ac2c1fbddb7b1277c168788c"},
+ {file = "multidict-6.0.5-cp311-cp311-win_amd64.whl", hash = "sha256:2faa5ae9376faba05f630d7e5e6be05be22913782b927b19d12b8145968a85ea"},
+ {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:51d035609b86722963404f711db441cf7134f1889107fb171a970c9701f92e1e"},
+ {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cbebcd5bcaf1eaf302617c114aa67569dd3f090dd0ce8ba9e35e9985b41ac35b"},
+ {file = "multidict-6.0.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ffc42c922dbfddb4a4c3b438eb056828719f07608af27d163191cb3e3aa6cc5"},
+ {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ceb3b7e6a0135e092de86110c5a74e46bda4bd4fbfeeb3a3bcec79c0f861e450"},
+ {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:79660376075cfd4b2c80f295528aa6beb2058fd289f4c9252f986751a4cd0496"},
+ {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e4428b29611e989719874670fd152b6625500ad6c686d464e99f5aaeeaca175a"},
+ {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d84a5c3a5f7ce6db1f999fb9438f686bc2e09d38143f2d93d8406ed2dd6b9226"},
+ {file = "multidict-6.0.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:76c0de87358b192de7ea9649beb392f107dcad9ad27276324c24c91774ca5271"},
+ {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:79a6d2ba910adb2cbafc95dad936f8b9386e77c84c35bc0add315b856d7c3abb"},
+ {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:92d16a3e275e38293623ebf639c471d3e03bb20b8ebb845237e0d3664914caef"},
+ {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:fb616be3538599e797a2017cccca78e354c767165e8858ab5116813146041a24"},
+ {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:14c2976aa9038c2629efa2c148022ed5eb4cb939e15ec7aace7ca932f48f9ba6"},
+ {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:435a0984199d81ca178b9ae2c26ec3d49692d20ee29bc4c11a2a8d4514c67eda"},
+ {file = "multidict-6.0.5-cp312-cp312-win32.whl", hash = "sha256:9fe7b0653ba3d9d65cbe7698cca585bf0f8c83dbbcc710db9c90f478e175f2d5"},
+ {file = "multidict-6.0.5-cp312-cp312-win_amd64.whl", hash = "sha256:01265f5e40f5a17f8241d52656ed27192be03bfa8764d88e8220141d1e4b3556"},
+ {file = "multidict-6.0.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:19fe01cea168585ba0f678cad6f58133db2aa14eccaf22f88e4a6dccadfad8b3"},
+ {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6bf7a982604375a8d49b6cc1b781c1747f243d91b81035a9b43a2126c04766f5"},
+ {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:107c0cdefe028703fb5dafe640a409cb146d44a6ae201e55b35a4af8e95457dd"},
+ {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:403c0911cd5d5791605808b942c88a8155c2592e05332d2bf78f18697a5fa15e"},
+ {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aeaf541ddbad8311a87dd695ed9642401131ea39ad7bc8cf3ef3967fd093b626"},
+ {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4972624066095e52b569e02b5ca97dbd7a7ddd4294bf4e7247d52635630dd83"},
+ {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d946b0a9eb8aaa590df1fe082cee553ceab173e6cb5b03239716338629c50c7a"},
+ {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:b55358304d7a73d7bdf5de62494aaf70bd33015831ffd98bc498b433dfe5b10c"},
+ {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:a3145cb08d8625b2d3fee1b2d596a8766352979c9bffe5d7833e0503d0f0b5e5"},
+ {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:d65f25da8e248202bd47445cec78e0025c0fe7582b23ec69c3b27a640dd7a8e3"},
+ {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:c9bf56195c6bbd293340ea82eafd0071cb3d450c703d2c93afb89f93b8386ccc"},
+ {file = "multidict-6.0.5-cp37-cp37m-win32.whl", hash = "sha256:69db76c09796b313331bb7048229e3bee7928eb62bab5e071e9f7fcc4879caee"},
+ {file = "multidict-6.0.5-cp37-cp37m-win_amd64.whl", hash = "sha256:fce28b3c8a81b6b36dfac9feb1de115bab619b3c13905b419ec71d03a3fc1423"},
+ {file = "multidict-6.0.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:76f067f5121dcecf0d63a67f29080b26c43c71a98b10c701b0677e4a065fbd54"},
+ {file = "multidict-6.0.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b82cc8ace10ab5bd93235dfaab2021c70637005e1ac787031f4d1da63d493c1d"},
+ {file = "multidict-6.0.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5cb241881eefd96b46f89b1a056187ea8e9ba14ab88ba632e68d7a2ecb7aadf7"},
+ {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8e94e6912639a02ce173341ff62cc1201232ab86b8a8fcc05572741a5dc7d93"},
+ {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:09a892e4a9fb47331da06948690ae38eaa2426de97b4ccbfafbdcbe5c8f37ff8"},
+ {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55205d03e8a598cfc688c71ca8ea5f66447164efff8869517f175ea632c7cb7b"},
+ {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37b15024f864916b4951adb95d3a80c9431299080341ab9544ed148091b53f50"},
+ {file = "multidict-6.0.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2a1dee728b52b33eebff5072817176c172050d44d67befd681609b4746e1c2e"},
+ {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:edd08e6f2f1a390bf137080507e44ccc086353c8e98c657e666c017718561b89"},
+ {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:60d698e8179a42ec85172d12f50b1668254628425a6bd611aba022257cac1386"},
+ {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:3d25f19500588cbc47dc19081d78131c32637c25804df8414463ec908631e453"},
+ {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:4cc0ef8b962ac7a5e62b9e826bd0cd5040e7d401bc45a6835910ed699037a461"},
+ {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:eca2e9d0cc5a889850e9bbd68e98314ada174ff6ccd1129500103df7a94a7a44"},
+ {file = "multidict-6.0.5-cp38-cp38-win32.whl", hash = "sha256:4a6a4f196f08c58c59e0b8ef8ec441d12aee4125a7d4f4fef000ccb22f8d7241"},
+ {file = "multidict-6.0.5-cp38-cp38-win_amd64.whl", hash = "sha256:0275e35209c27a3f7951e1ce7aaf93ce0d163b28948444bec61dd7badc6d3f8c"},
+ {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e7be68734bd8c9a513f2b0cfd508802d6609da068f40dc57d4e3494cefc92929"},
+ {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1d9ea7a7e779d7a3561aade7d596649fbecfa5c08a7674b11b423783217933f9"},
+ {file = "multidict-6.0.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ea1456df2a27c73ce51120fa2f519f1bea2f4a03a917f4a43c8707cf4cbbae1a"},
+ {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf590b134eb70629e350691ecca88eac3e3b8b3c86992042fb82e3cb1830d5e1"},
+ {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5c0631926c4f58e9a5ccce555ad7747d9a9f8b10619621f22f9635f069f6233e"},
+ {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dce1c6912ab9ff5f179eaf6efe7365c1f425ed690b03341911bf4939ef2f3046"},
+ {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0868d64af83169e4d4152ec612637a543f7a336e4a307b119e98042e852ad9c"},
+ {file = "multidict-6.0.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:141b43360bfd3bdd75f15ed811850763555a251e38b2405967f8e25fb43f7d40"},
+ {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:7df704ca8cf4a073334e0427ae2345323613e4df18cc224f647f251e5e75a527"},
+ {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:6214c5a5571802c33f80e6c84713b2c79e024995b9c5897f794b43e714daeec9"},
+ {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:cd6c8fca38178e12c00418de737aef1261576bd1b6e8c6134d3e729a4e858b38"},
+ {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:e02021f87a5b6932fa6ce916ca004c4d441509d33bbdbeca70d05dff5e9d2479"},
+ {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ebd8d160f91a764652d3e51ce0d2956b38efe37c9231cd82cfc0bed2e40b581c"},
+ {file = "multidict-6.0.5-cp39-cp39-win32.whl", hash = "sha256:04da1bb8c8dbadf2a18a452639771951c662c5ad03aefe4884775454be322c9b"},
+ {file = "multidict-6.0.5-cp39-cp39-win_amd64.whl", hash = "sha256:d6f6d4f185481c9669b9447bf9d9cf3b95a0e9df9d169bbc17e363b7d5487755"},
+ {file = "multidict-6.0.5-py3-none-any.whl", hash = "sha256:0d63c74e3d7ab26de115c49bffc92cc77ed23395303d496eae515d4204a625e7"},
+ {file = "multidict-6.0.5.tar.gz", hash = "sha256:f7e301075edaf50500f0b341543c41194d8df3ae5caf4702f2095f3ca73dd8da"},
]
[[package]]
name = "multiprocess"
-version = "0.70.15"
+version = "0.70.16"
description = "better multiprocessing and multithreading in Python"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
files = [
- {file = "multiprocess-0.70.15-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:aa36c7ed16f508091438687fe9baa393a7a8e206731d321e443745e743a0d4e5"},
- {file = "multiprocess-0.70.15-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:20e024018c46d0d1602024c613007ac948f9754659e3853b0aa705e83f6931d8"},
- {file = "multiprocess-0.70.15-pp37-pypy37_pp73-manylinux_2_24_i686.whl", hash = "sha256:e576062981c91f0fe8a463c3d52506e598dfc51320a8dd8d78b987dfca91c5db"},
- {file = "multiprocess-0.70.15-pp37-pypy37_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:e73f497e6696a0f5433ada2b3d599ae733b87a6e8b008e387c62ac9127add177"},
- {file = "multiprocess-0.70.15-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:73db2e7b32dcc7f9b0f075c2ffa45c90b6729d3f1805f27e88534c8d321a1be5"},
- {file = "multiprocess-0.70.15-pp38-pypy38_pp73-manylinux_2_24_i686.whl", hash = "sha256:4271647bd8a49c28ecd6eb56a7fdbd3c212c45529ad5303b40b3c65fc6928e5f"},
- {file = "multiprocess-0.70.15-pp38-pypy38_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:cf981fb998d6ec3208cb14f0cf2e9e80216e834f5d51fd09ebc937c32b960902"},
- {file = "multiprocess-0.70.15-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:18f9f2c7063346d1617bd1684fdcae8d33380ae96b99427260f562e1a1228b67"},
- {file = "multiprocess-0.70.15-pp39-pypy39_pp73-manylinux_2_24_i686.whl", hash = "sha256:0eac53214d664c49a34695e5824872db4006b1a465edd7459a251809c3773370"},
- {file = "multiprocess-0.70.15-pp39-pypy39_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:1a51dd34096db47fb21fa2b839e615b051d51b97af9a67afbcdaa67186b44883"},
- {file = "multiprocess-0.70.15-py310-none-any.whl", hash = "sha256:7dd58e33235e83cf09d625e55cffd7b0f0eede7ee9223cdd666a87624f60c21a"},
- {file = "multiprocess-0.70.15-py311-none-any.whl", hash = "sha256:134f89053d82c9ed3b73edd3a2531eb791e602d4f4156fc92a79259590bd9670"},
- {file = "multiprocess-0.70.15-py37-none-any.whl", hash = "sha256:f7d4a1629bccb433114c3b4885f69eccc200994323c80f6feee73b0edc9199c5"},
- {file = "multiprocess-0.70.15-py38-none-any.whl", hash = "sha256:bee9afba476c91f9ebee7beeee0601face9eff67d822e893f9a893725fbd6316"},
- {file = "multiprocess-0.70.15-py39-none-any.whl", hash = "sha256:3e0953f5d52b4c76f1c973eaf8214554d146f2be5decb48e928e55c7a2d19338"},
- {file = "multiprocess-0.70.15.tar.gz", hash = "sha256:f20eed3036c0ef477b07a4177cf7c1ba520d9a2677870a4f47fe026f0cd6787e"},
+ {file = "multiprocess-0.70.16-pp310-pypy310_pp73-macosx_10_13_x86_64.whl", hash = "sha256:476887be10e2f59ff183c006af746cb6f1fd0eadcfd4ef49e605cbe2659920ee"},
+ {file = "multiprocess-0.70.16-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d951bed82c8f73929ac82c61f01a7b5ce8f3e5ef40f5b52553b4f547ce2b08ec"},
+ {file = "multiprocess-0.70.16-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:37b55f71c07e2d741374998c043b9520b626a8dddc8b3129222ca4f1a06ef67a"},
+ {file = "multiprocess-0.70.16-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:ba8c31889abf4511c7308a8c52bb4a30b9d590e7f58523302ba00237702ca054"},
+ {file = "multiprocess-0.70.16-pp39-pypy39_pp73-macosx_10_13_x86_64.whl", hash = "sha256:0dfd078c306e08d46d7a8d06fb120313d87aa43af60d66da43ffff40b44d2f41"},
+ {file = "multiprocess-0.70.16-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e7b9d0f307cd9bd50851afaac0dba2cb6c44449efff697df7c7645f7d3f2be3a"},
+ {file = "multiprocess-0.70.16-py310-none-any.whl", hash = "sha256:c4a9944c67bd49f823687463660a2d6daae94c289adff97e0f9d696ba6371d02"},
+ {file = "multiprocess-0.70.16-py311-none-any.whl", hash = "sha256:af4cabb0dac72abfb1e794fa7855c325fd2b55a10a44628a3c1ad3311c04127a"},
+ {file = "multiprocess-0.70.16-py312-none-any.whl", hash = "sha256:fc0544c531920dde3b00c29863377f87e1632601092ea2daca74e4beb40faa2e"},
+ {file = "multiprocess-0.70.16-py38-none-any.whl", hash = "sha256:a71d82033454891091a226dfc319d0cfa8019a4e888ef9ca910372a446de4435"},
+ {file = "multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3"},
+ {file = "multiprocess-0.70.16.tar.gz", hash = "sha256:161af703d4652a0e1410be6abccecde4a7ddffd19341be0a7011b94aeb171ac1"},
]
[package.dependencies]
-dill = ">=0.3.7"
+dill = ">=0.3.8"
[[package]]
name = "mypy"
-version = "1.6.1"
+version = "1.9.0"
description = "Optional static typing for Python"
optional = false
python-versions = ">=3.8"
files = [
- {file = "mypy-1.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e5012e5cc2ac628177eaac0e83d622b2dd499e28253d4107a08ecc59ede3fc2c"},
- {file = "mypy-1.6.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d8fbb68711905f8912e5af474ca8b78d077447d8f3918997fecbf26943ff3cbb"},
- {file = "mypy-1.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21a1ad938fee7d2d96ca666c77b7c494c3c5bd88dff792220e1afbebb2925b5e"},
- {file = "mypy-1.6.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b96ae2c1279d1065413965c607712006205a9ac541895004a1e0d4f281f2ff9f"},
- {file = "mypy-1.6.1-cp310-cp310-win_amd64.whl", hash = "sha256:40b1844d2e8b232ed92e50a4bd11c48d2daa351f9deee6c194b83bf03e418b0c"},
- {file = "mypy-1.6.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:81af8adaa5e3099469e7623436881eff6b3b06db5ef75e6f5b6d4871263547e5"},
- {file = "mypy-1.6.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8c223fa57cb154c7eab5156856c231c3f5eace1e0bed9b32a24696b7ba3c3245"},
- {file = "mypy-1.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8032e00ce71c3ceb93eeba63963b864bf635a18f6c0c12da6c13c450eedb183"},
- {file = "mypy-1.6.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4c46b51de523817a0045b150ed11b56f9fff55f12b9edd0f3ed35b15a2809de0"},
- {file = "mypy-1.6.1-cp311-cp311-win_amd64.whl", hash = "sha256:19f905bcfd9e167159b3d63ecd8cb5e696151c3e59a1742e79bc3bcb540c42c7"},
- {file = "mypy-1.6.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:82e469518d3e9a321912955cc702d418773a2fd1e91c651280a1bda10622f02f"},
- {file = "mypy-1.6.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d4473c22cc296425bbbce7e9429588e76e05bc7342da359d6520b6427bf76660"},
- {file = "mypy-1.6.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59a0d7d24dfb26729e0a068639a6ce3500e31d6655df8557156c51c1cb874ce7"},
- {file = "mypy-1.6.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:cfd13d47b29ed3bbaafaff7d8b21e90d827631afda134836962011acb5904b71"},
- {file = "mypy-1.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:eb4f18589d196a4cbe5290b435d135dee96567e07c2b2d43b5c4621b6501531a"},
- {file = "mypy-1.6.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:41697773aa0bf53ff917aa077e2cde7aa50254f28750f9b88884acea38a16169"},
- {file = "mypy-1.6.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7274b0c57737bd3476d2229c6389b2ec9eefeb090bbaf77777e9d6b1b5a9d143"},
- {file = "mypy-1.6.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bbaf4662e498c8c2e352da5f5bca5ab29d378895fa2d980630656178bd607c46"},
- {file = "mypy-1.6.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:bb8ccb4724f7d8601938571bf3f24da0da791fe2db7be3d9e79849cb64e0ae85"},
- {file = "mypy-1.6.1-cp38-cp38-win_amd64.whl", hash = "sha256:68351911e85145f582b5aa6cd9ad666c8958bcae897a1bfda8f4940472463c45"},
- {file = "mypy-1.6.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:49ae115da099dcc0922a7a895c1eec82c1518109ea5c162ed50e3b3594c71208"},
- {file = "mypy-1.6.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8b27958f8c76bed8edaa63da0739d76e4e9ad4ed325c814f9b3851425582a3cd"},
- {file = "mypy-1.6.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:925cd6a3b7b55dfba252b7c4561892311c5358c6b5a601847015a1ad4eb7d332"},
- {file = "mypy-1.6.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8f57e6b6927a49550da3d122f0cb983d400f843a8a82e65b3b380d3d7259468f"},
- {file = "mypy-1.6.1-cp39-cp39-win_amd64.whl", hash = "sha256:a43ef1c8ddfdb9575691720b6352761f3f53d85f1b57d7745701041053deff30"},
- {file = "mypy-1.6.1-py3-none-any.whl", hash = "sha256:4cbe68ef919c28ea561165206a2dcb68591c50f3bcf777932323bc208d949cf1"},
- {file = "mypy-1.6.1.tar.gz", hash = "sha256:4d01c00d09a0be62a4ca3f933e315455bde83f37f892ba4b08ce92f3cf44bcc1"},
+ {file = "mypy-1.9.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f8a67616990062232ee4c3952f41c779afac41405806042a8126fe96e098419f"},
+ {file = "mypy-1.9.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d357423fa57a489e8c47b7c85dfb96698caba13d66e086b412298a1a0ea3b0ed"},
+ {file = "mypy-1.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49c87c15aed320de9b438ae7b00c1ac91cd393c1b854c2ce538e2a72d55df150"},
+ {file = "mypy-1.9.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:48533cdd345c3c2e5ef48ba3b0d3880b257b423e7995dada04248725c6f77374"},
+ {file = "mypy-1.9.0-cp310-cp310-win_amd64.whl", hash = "sha256:4d3dbd346cfec7cb98e6cbb6e0f3c23618af826316188d587d1c1bc34f0ede03"},
+ {file = "mypy-1.9.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:653265f9a2784db65bfca694d1edd23093ce49740b2244cde583aeb134c008f3"},
+ {file = "mypy-1.9.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3a3c007ff3ee90f69cf0a15cbcdf0995749569b86b6d2f327af01fd1b8aee9dc"},
+ {file = "mypy-1.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2418488264eb41f69cc64a69a745fad4a8f86649af4b1041a4c64ee61fc61129"},
+ {file = "mypy-1.9.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:68edad3dc7d70f2f17ae4c6c1b9471a56138ca22722487eebacfd1eb5321d612"},
+ {file = "mypy-1.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:85ca5fcc24f0b4aeedc1d02f93707bccc04733f21d41c88334c5482219b1ccb3"},
+ {file = "mypy-1.9.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aceb1db093b04db5cd390821464504111b8ec3e351eb85afd1433490163d60cd"},
+ {file = "mypy-1.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0235391f1c6f6ce487b23b9dbd1327b4ec33bb93934aa986efe8a9563d9349e6"},
+ {file = "mypy-1.9.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d4d5ddc13421ba3e2e082a6c2d74c2ddb3979c39b582dacd53dd5d9431237185"},
+ {file = "mypy-1.9.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:190da1ee69b427d7efa8aa0d5e5ccd67a4fb04038c380237a0d96829cb157913"},
+ {file = "mypy-1.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:fe28657de3bfec596bbeef01cb219833ad9d38dd5393fc649f4b366840baefe6"},
+ {file = "mypy-1.9.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e54396d70be04b34f31d2edf3362c1edd023246c82f1730bbf8768c28db5361b"},
+ {file = "mypy-1.9.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5e6061f44f2313b94f920e91b204ec600982961e07a17e0f6cd83371cb23f5c2"},
+ {file = "mypy-1.9.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:81a10926e5473c5fc3da8abb04119a1f5811a236dc3a38d92015cb1e6ba4cb9e"},
+ {file = "mypy-1.9.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b685154e22e4e9199fc95f298661deea28aaede5ae16ccc8cbb1045e716b3e04"},
+ {file = "mypy-1.9.0-cp38-cp38-win_amd64.whl", hash = "sha256:5d741d3fc7c4da608764073089e5f58ef6352bedc223ff58f2f038c2c4698a89"},
+ {file = "mypy-1.9.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:587ce887f75dd9700252a3abbc9c97bbe165a4a630597845c61279cf32dfbf02"},
+ {file = "mypy-1.9.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f88566144752999351725ac623471661c9d1cd8caa0134ff98cceeea181789f4"},
+ {file = "mypy-1.9.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61758fabd58ce4b0720ae1e2fea5cfd4431591d6d590b197775329264f86311d"},
+ {file = "mypy-1.9.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e49499be624dead83927e70c756970a0bc8240e9f769389cdf5714b0784ca6bf"},
+ {file = "mypy-1.9.0-cp39-cp39-win_amd64.whl", hash = "sha256:571741dc4194b4f82d344b15e8837e8c5fcc462d66d076748142327626a1b6e9"},
+ {file = "mypy-1.9.0-py3-none-any.whl", hash = "sha256:a260627a570559181a9ea5de61ac6297aa5af202f06fd7ab093ce74e7181e43e"},
+ {file = "mypy-1.9.0.tar.gz", hash = "sha256:3cc5da0127e6a478cddd906068496a97a7618a21ce9b54bde5bf7e539c7af974"},
]
[package.dependencies]
@@ -2157,6 +2253,7 @@ typing-extensions = ">=4.1.0"
[package.extras]
dmypy = ["psutil (>=4.0)"]
install-types = ["pip"]
+mypyc = ["setuptools (>=50)"]
reports = ["lxml"]
[[package]]
@@ -2198,13 +2295,13 @@ testing-docutils = ["pygments", "pytest (>=7,<8)", "pytest-param-files (>=0.3.4,
[[package]]
name = "nbclient"
-version = "0.8.0"
+version = "0.10.0"
description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor."
optional = false
python-versions = ">=3.8.0"
files = [
- {file = "nbclient-0.8.0-py3-none-any.whl", hash = "sha256:25e861299e5303a0477568557c4045eccc7a34c17fc08e7959558707b9ebe548"},
- {file = "nbclient-0.8.0.tar.gz", hash = "sha256:f9b179cd4b2d7bca965f900a2ebf0db4a12ebff2f36a711cb66861e4ae158e55"},
+ {file = "nbclient-0.10.0-py3-none-any.whl", hash = "sha256:f13e3529332a1f1f81d82a53210322476a168bb7090a0289c795fe9cc11c9d3f"},
+ {file = "nbclient-0.10.0.tar.gz", hash = "sha256:4b3f1b7dba531e498449c4db4f53da339c91d449dc11e9af3a43b4eb5c5abb09"},
]
[package.dependencies]
@@ -2216,17 +2313,17 @@ traitlets = ">=5.4"
[package.extras]
dev = ["pre-commit"]
docs = ["autodoc-traits", "mock", "moto", "myst-parser", "nbclient[test]", "sphinx (>=1.7)", "sphinx-book-theme", "sphinxcontrib-spelling"]
-test = ["flaky", "ipykernel (>=6.19.3)", "ipython", "ipywidgets", "nbconvert (>=7.0.0)", "pytest (>=7.0)", "pytest-asyncio", "pytest-cov (>=4.0)", "testpath", "xmltodict"]
+test = ["flaky", "ipykernel (>=6.19.3)", "ipython", "ipywidgets", "nbconvert (>=7.0.0)", "pytest (>=7.0,<8)", "pytest-asyncio", "pytest-cov (>=4.0)", "testpath", "xmltodict"]
[[package]]
name = "nbconvert"
-version = "7.9.2"
-description = "Converting Jupyter Notebooks"
+version = "7.16.3"
+description = "Converting Jupyter Notebooks (.ipynb files) to other formats. Output formats include asciidoc, html, latex, markdown, pdf, py, rst, script. nbconvert can be used both as a Python library (`import nbconvert`) or as a command line tool (invoked as `jupyter nbconvert ...`)."
optional = false
python-versions = ">=3.8"
files = [
- {file = "nbconvert-7.9.2-py3-none-any.whl", hash = "sha256:39fe4b8bdd1b0104fdd86fc8a43a9077ba64c720bda4c6132690d917a0a154ee"},
- {file = "nbconvert-7.9.2.tar.gz", hash = "sha256:e56cc7588acc4f93e2bb5a34ec69028e4941797b2bfaf6462f18a41d1cc258c9"},
+ {file = "nbconvert-7.16.3-py3-none-any.whl", hash = "sha256:ddeff14beeeedf3dd0bc506623e41e4507e551736de59df69a91f86700292b3b"},
+ {file = "nbconvert-7.16.3.tar.gz", hash = "sha256:a6733b78ce3d47c3f85e504998495b07e6ea9cf9bf6ec1c98dda63ec6ad19142"},
]
[package.dependencies]
@@ -2253,24 +2350,24 @@ docs = ["ipykernel", "ipython", "myst-parser", "nbsphinx (>=0.2.12)", "pydata-sp
qtpdf = ["nbconvert[qtpng]"]
qtpng = ["pyqtwebengine (>=5.15)"]
serve = ["tornado (>=6.1)"]
-test = ["flaky", "ipykernel", "ipywidgets (>=7)", "pytest", "pytest-dependency"]
+test = ["flaky", "ipykernel", "ipywidgets (>=7.5)", "pytest (>=7)"]
webpdf = ["playwright"]
[[package]]
name = "nbformat"
-version = "5.9.2"
+version = "5.10.4"
description = "The Jupyter Notebook format"
optional = false
python-versions = ">=3.8"
files = [
- {file = "nbformat-5.9.2-py3-none-any.whl", hash = "sha256:1c5172d786a41b82bcfd0c23f9e6b6f072e8fb49c39250219e4acfff1efe89e9"},
- {file = "nbformat-5.9.2.tar.gz", hash = "sha256:5f98b5ba1997dff175e77e0c17d5c10a96eaed2cbd1de3533d1fc35d5e111192"},
+ {file = "nbformat-5.10.4-py3-none-any.whl", hash = "sha256:3b48d6c8fbca4b299bf3982ea7db1af21580e4fec269ad087b9e81588891200b"},
+ {file = "nbformat-5.10.4.tar.gz", hash = "sha256:322168b14f937a5d11362988ecac2a4952d3d8e3a2cbeb2319584631226d5b3a"},
]
[package.dependencies]
-fastjsonschema = "*"
+fastjsonschema = ">=2.15"
jsonschema = ">=2.6"
-jupyter-core = "*"
+jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0"
traitlets = ">=5.1"
[package.extras]
@@ -2316,13 +2413,13 @@ pytest = ">=2.8"
[[package]]
name = "nest-asyncio"
-version = "1.5.8"
+version = "1.6.0"
description = "Patch asyncio to allow nested event loops"
optional = false
python-versions = ">=3.5"
files = [
- {file = "nest_asyncio-1.5.8-py3-none-any.whl", hash = "sha256:accda7a339a70599cb08f9dd09a67e0c2ef8d8d6f4c07f96ab203f2ae254e48d"},
- {file = "nest_asyncio-1.5.8.tar.gz", hash = "sha256:25aa2ca0d2a5b5531956b9e273b45cf664cae2b145101d73b86b199978d48fdb"},
+ {file = "nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c"},
+ {file = "nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe"},
]
[[package]]
@@ -2345,18 +2442,18 @@ test = ["codecov (>=2.1)", "pytest (>=7.2)", "pytest-cov (>=4.0)"]
[[package]]
name = "notebook"
-version = "7.0.6"
+version = "7.1.2"
description = "Jupyter Notebook - A web-based notebook environment for interactive computing"
optional = false
python-versions = ">=3.8"
files = [
- {file = "notebook-7.0.6-py3-none-any.whl", hash = "sha256:0fe8f67102fea3744fedf652e4c15339390902ca70c5a31c4f547fa23da697cc"},
- {file = "notebook-7.0.6.tar.gz", hash = "sha256:ec6113b06529019f7f287819af06c97a2baf7a95ac21a8f6e32192898e9f9a58"},
+ {file = "notebook-7.1.2-py3-none-any.whl", hash = "sha256:fc6c24b9aef18d0cd57157c9c47e95833b9b0bdc599652639acf0bdb61dc7d5f"},
+ {file = "notebook-7.1.2.tar.gz", hash = "sha256:efc2c80043909e0faa17fce9e9b37c059c03af0ec99a4d4db84cb21d9d2e936a"},
]
[package.dependencies]
jupyter-server = ">=2.4.0,<3"
-jupyterlab = ">=4.0.2,<5"
+jupyterlab = ">=4.1.1,<4.2"
jupyterlab-server = ">=2.22.1,<3"
notebook-shim = ">=0.2,<0.3"
tornado = ">=6.2.0"
@@ -2368,13 +2465,13 @@ test = ["importlib-resources (>=5.0)", "ipykernel", "jupyter-server[test] (>=2.4
[[package]]
name = "notebook-shim"
-version = "0.2.3"
+version = "0.2.4"
description = "A shim layer for notebook traits and config"
optional = false
python-versions = ">=3.7"
files = [
- {file = "notebook_shim-0.2.3-py3-none-any.whl", hash = "sha256:a83496a43341c1674b093bfcebf0fe8e74cbe7eda5fd2bbc56f8e39e1486c0c7"},
- {file = "notebook_shim-0.2.3.tar.gz", hash = "sha256:f69388ac283ae008cd506dda10d0288b09a017d822d5e8c7129a152cbd3ce7e9"},
+ {file = "notebook_shim-0.2.4-py3-none-any.whl", hash = "sha256:411a5be4e9dc882a074ccbcae671eda64cceb068767e9a3419096986560e1cef"},
+ {file = "notebook_shim-0.2.4.tar.gz", hash = "sha256:b4b2cfa1b65d98307ca24361f5b30fe785b53c3fd07b7a47e89acb5e6ac638cb"},
]
[package.dependencies]
@@ -2422,43 +2519,47 @@ files = [
[[package]]
name = "numpy"
-version = "1.26.1"
+version = "1.26.4"
description = "Fundamental package for array computing in Python"
optional = false
-python-versions = "<3.13,>=3.9"
-files = [
- {file = "numpy-1.26.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:82e871307a6331b5f09efda3c22e03c095d957f04bf6bc1804f30048d0e5e7af"},
- {file = "numpy-1.26.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cdd9ec98f0063d93baeb01aad472a1a0840dee302842a2746a7a8e92968f9575"},
- {file = "numpy-1.26.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d78f269e0c4fd365fc2992c00353e4530d274ba68f15e968d8bc3c69ce5f5244"},
- {file = "numpy-1.26.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ab9163ca8aeb7fd32fe93866490654d2f7dda4e61bc6297bf72ce07fdc02f67"},
- {file = "numpy-1.26.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:78ca54b2f9daffa5f323f34cdf21e1d9779a54073f0018a3094ab907938331a2"},
- {file = "numpy-1.26.1-cp310-cp310-win32.whl", hash = "sha256:d1cfc92db6af1fd37a7bb58e55c8383b4aa1ba23d012bdbba26b4bcca45ac297"},
- {file = "numpy-1.26.1-cp310-cp310-win_amd64.whl", hash = "sha256:d2984cb6caaf05294b8466966627e80bf6c7afd273279077679cb010acb0e5ab"},
- {file = "numpy-1.26.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cd7837b2b734ca72959a1caf3309457a318c934abef7a43a14bb984e574bbb9a"},
- {file = "numpy-1.26.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1c59c046c31a43310ad0199d6299e59f57a289e22f0f36951ced1c9eac3665b9"},
- {file = "numpy-1.26.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d58e8c51a7cf43090d124d5073bc29ab2755822181fcad978b12e144e5e5a4b3"},
- {file = "numpy-1.26.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6081aed64714a18c72b168a9276095ef9155dd7888b9e74b5987808f0dd0a974"},
- {file = "numpy-1.26.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:97e5d6a9f0702c2863aaabf19f0d1b6c2628fbe476438ce0b5ce06e83085064c"},
- {file = "numpy-1.26.1-cp311-cp311-win32.whl", hash = "sha256:b9d45d1dbb9de84894cc50efece5b09939752a2d75aab3a8b0cef6f3a35ecd6b"},
- {file = "numpy-1.26.1-cp311-cp311-win_amd64.whl", hash = "sha256:3649d566e2fc067597125428db15d60eb42a4e0897fc48d28cb75dc2e0454e53"},
- {file = "numpy-1.26.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:1d1bd82d539607951cac963388534da3b7ea0e18b149a53cf883d8f699178c0f"},
- {file = "numpy-1.26.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:afd5ced4e5a96dac6725daeb5242a35494243f2239244fad10a90ce58b071d24"},
- {file = "numpy-1.26.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a03fb25610ef560a6201ff06df4f8105292ba56e7cdd196ea350d123fc32e24e"},
- {file = "numpy-1.26.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dcfaf015b79d1f9f9c9fd0731a907407dc3e45769262d657d754c3a028586124"},
- {file = "numpy-1.26.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e509cbc488c735b43b5ffea175235cec24bbc57b227ef1acc691725beb230d1c"},
- {file = "numpy-1.26.1-cp312-cp312-win32.whl", hash = "sha256:af22f3d8e228d84d1c0c44c1fbdeb80f97a15a0abe4f080960393a00db733b66"},
- {file = "numpy-1.26.1-cp312-cp312-win_amd64.whl", hash = "sha256:9f42284ebf91bdf32fafac29d29d4c07e5e9d1af862ea73686581773ef9e73a7"},
- {file = "numpy-1.26.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bb894accfd16b867d8643fc2ba6c8617c78ba2828051e9a69511644ce86ce83e"},
- {file = "numpy-1.26.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e44ccb93f30c75dfc0c3aa3ce38f33486a75ec9abadabd4e59f114994a9c4617"},
- {file = "numpy-1.26.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9696aa2e35cc41e398a6d42d147cf326f8f9d81befcb399bc1ed7ffea339b64e"},
- {file = "numpy-1.26.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a5b411040beead47a228bde3b2241100454a6abde9df139ed087bd73fc0a4908"},
- {file = "numpy-1.26.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:1e11668d6f756ca5ef534b5be8653d16c5352cbb210a5c2a79ff288e937010d5"},
- {file = "numpy-1.26.1-cp39-cp39-win32.whl", hash = "sha256:d1d2c6b7dd618c41e202c59c1413ef9b2c8e8a15f5039e344af64195459e3104"},
- {file = "numpy-1.26.1-cp39-cp39-win_amd64.whl", hash = "sha256:59227c981d43425ca5e5c01094d59eb14e8772ce6975d4b2fc1e106a833d5ae2"},
- {file = "numpy-1.26.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:06934e1a22c54636a059215d6da99e23286424f316fddd979f5071093b648668"},
- {file = "numpy-1.26.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76ff661a867d9272cd2a99eed002470f46dbe0943a5ffd140f49be84f68ffc42"},
- {file = "numpy-1.26.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:6965888d65d2848e8768824ca8288db0a81263c1efccec881cb35a0d805fcd2f"},
- {file = "numpy-1.26.1.tar.gz", hash = "sha256:c8c6c72d4a9f831f328efb1312642a1cafafaa88981d9ab76368d50d07d93cbe"},
+python-versions = ">=3.9"
+files = [
+ {file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"},
+ {file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"},
+ {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"},
+ {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"},
+ {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"},
+ {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"},
+ {file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"},
+ {file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"},
+ {file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"},
+ {file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"},
+ {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"},
+ {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"},
+ {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"},
+ {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"},
+ {file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"},
+ {file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"},
+ {file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"},
+ {file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"},
+ {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"},
+ {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"},
+ {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"},
+ {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"},
+ {file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"},
+ {file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"},
+ {file = "numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c"},
+ {file = "numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be"},
+ {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764"},
+ {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3"},
+ {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd"},
+ {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c"},
+ {file = "numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6"},
+ {file = "numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea"},
+ {file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"},
+ {file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"},
+ {file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"},
+ {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"},
]
[[package]]
@@ -2582,13 +2683,13 @@ files = [
[[package]]
name = "nvidia-nvjitlink-cu12"
-version = "12.3.52"
+version = "12.4.127"
description = "Nvidia JIT LTO Library"
optional = false
python-versions = ">=3"
files = [
- {file = "nvidia_nvjitlink_cu12-12.3.52-py3-none-manylinux1_x86_64.whl", hash = "sha256:93db4dba8cb66fe2a351791e557208345bb9d0ace1bfb9dd05a4812f9a3ac74e"},
- {file = "nvidia_nvjitlink_cu12-12.3.52-py3-none-win_amd64.whl", hash = "sha256:9e403610da6ebceee897371a6982433ec997a9279d2320840413ce82a1d28ddc"},
+ {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"},
+ {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1"},
]
[[package]]
@@ -2604,24 +2705,24 @@ files = [
[[package]]
name = "overrides"
-version = "7.4.0"
+version = "7.7.0"
description = "A decorator to automatically detect mismatch when overriding a method."
optional = false
python-versions = ">=3.6"
files = [
- {file = "overrides-7.4.0-py3-none-any.whl", hash = "sha256:3ad24583f86d6d7a49049695efe9933e67ba62f0c7625d53c59fa832ce4b8b7d"},
- {file = "overrides-7.4.0.tar.gz", hash = "sha256:9502a3cca51f4fac40b5feca985b6703a5c1f6ad815588a7ca9e285b9dca6757"},
+ {file = "overrides-7.7.0-py3-none-any.whl", hash = "sha256:c7ed9d062f78b8e4c1a7b70bd8796b35ead4d9f510227ef9c5dc7626c60d7e49"},
+ {file = "overrides-7.7.0.tar.gz", hash = "sha256:55158fa3d93b98cc75299b1e67078ad9003ca27945c76162c1c0766d6f91820a"},
]
[[package]]
name = "packaging"
-version = "23.2"
+version = "24.0"
description = "Core utilities for Python packages"
optional = false
python-versions = ">=3.7"
files = [
- {file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"},
- {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"},
+ {file = "packaging-24.0-py3-none-any.whl", hash = "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5"},
+ {file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"},
]
[[package]]
@@ -2661,7 +2762,7 @@ files = [
[package.dependencies]
numpy = [
{version = ">=1.20.3", markers = "python_version < \"3.10\""},
- {version = ">=1.21.0", markers = "python_version >= \"3.10\""},
+ {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""},
{version = ">=1.23.2", markers = "python_version >= \"3.11\""},
]
python-dateutil = ">=2.8.2"
@@ -2707,60 +2808,50 @@ ply = "*"
[[package]]
name = "pandocfilters"
-version = "1.5.0"
+version = "1.5.1"
description = "Utilities for writing pandoc filters in python"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
files = [
- {file = "pandocfilters-1.5.0-py2.py3-none-any.whl", hash = "sha256:33aae3f25fd1a026079f5d27bdd52496f0e0803b3469282162bafdcbdf6ef14f"},
- {file = "pandocfilters-1.5.0.tar.gz", hash = "sha256:0b679503337d233b4339a817bfc8c50064e2eff681314376a47cb582305a7a38"},
+ {file = "pandocfilters-1.5.1-py2.py3-none-any.whl", hash = "sha256:93be382804a9cdb0a7267585f157e5d1731bbe5545a85b268d6f5fe6232de2bc"},
+ {file = "pandocfilters-1.5.1.tar.gz", hash = "sha256:002b4a555ee4ebc03f8b66307e287fa492e4a77b4ea14d3f934328297bb4939e"},
]
[[package]]
name = "parso"
-version = "0.8.3"
+version = "0.8.4"
description = "A Python Parser"
optional = false
python-versions = ">=3.6"
files = [
- {file = "parso-0.8.3-py2.py3-none-any.whl", hash = "sha256:c001d4636cd3aecdaf33cbb40aebb59b094be2a74c556778ef5576c175e19e75"},
- {file = "parso-0.8.3.tar.gz", hash = "sha256:8c07be290bb59f03588915921e29e8a50002acaf2cdc5fa0e0114f91709fafa0"},
+ {file = "parso-0.8.4-py2.py3-none-any.whl", hash = "sha256:a418670a20291dacd2dddc80c377c5c3791378ee1e8d12bffc35420643d43f18"},
+ {file = "parso-0.8.4.tar.gz", hash = "sha256:eb3a7b58240fb99099a345571deecc0f9540ea5f4dd2fe14c2a99d6b281ab92d"},
]
[package.extras]
-qa = ["flake8 (==3.8.3)", "mypy (==0.782)"]
-testing = ["docopt", "pytest (<6.0.0)"]
+qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"]
+testing = ["docopt", "pytest"]
[[package]]
name = "pathspec"
-version = "0.11.2"
+version = "0.12.1"
description = "Utility library for gitignore style pattern matching of file paths."
optional = false
-python-versions = ">=3.7"
-files = [
- {file = "pathspec-0.11.2-py3-none-any.whl", hash = "sha256:1d6ed233af05e679efb96b1851550ea95bbb64b7c490b0f5aa52996c11e92a20"},
- {file = "pathspec-0.11.2.tar.gz", hash = "sha256:e0d8d0ac2f12da61956eb2306b69f9469b42f4deb0f3cb6ed47b9cce9996ced3"},
-]
-
-[[package]]
-name = "pathtools"
-version = "0.1.2"
-description = "File system general utilities"
-optional = false
-python-versions = "*"
+python-versions = ">=3.8"
files = [
- {file = "pathtools-0.1.2.tar.gz", hash = "sha256:7c35c5421a39bb82e58018febd90e3b6e5db34c5443aaaf742b3f33d4655f1c0"},
+ {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"},
+ {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"},
]
[[package]]
name = "pexpect"
-version = "4.8.0"
+version = "4.9.0"
description = "Pexpect allows easy control of interactive console applications."
optional = false
python-versions = "*"
files = [
- {file = "pexpect-4.8.0-py2.py3-none-any.whl", hash = "sha256:0b48a55dcb3c05f3329815901ea4fc1537514d6ba867a152b581d69ae3710937"},
- {file = "pexpect-4.8.0.tar.gz", hash = "sha256:fc65a43959d153d0114afe13997d439c22823a27cefceb5ff35c2178c6784c0c"},
+ {file = "pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523"},
+ {file = "pexpect-4.9.0.tar.gz", hash = "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f"},
]
[package.dependencies]
@@ -2790,28 +2881,28 @@ files = [
[[package]]
name = "platformdirs"
-version = "3.11.0"
+version = "4.2.0"
description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
files = [
- {file = "platformdirs-3.11.0-py3-none-any.whl", hash = "sha256:e9d171d00af68be50e9202731309c4e658fd8bc76f55c11c7dd760d023bda68e"},
- {file = "platformdirs-3.11.0.tar.gz", hash = "sha256:cf8ee52a3afdb965072dcc652433e0c7e3e40cf5ea1477cd4b3b1d2eb75495b3"},
+ {file = "platformdirs-4.2.0-py3-none-any.whl", hash = "sha256:0614df2a2f37e1a662acbd8e2b25b92ccf8632929bc6d43467e17fe89c75e068"},
+ {file = "platformdirs-4.2.0.tar.gz", hash = "sha256:ef0cc731df711022c174543cb70a9b5bd22e5a9337c8624ef2c2ceb8ddad8768"},
]
[package.extras]
-docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.1)", "sphinx-autodoc-typehints (>=1.24)"]
-test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)"]
+docs = ["furo (>=2023.9.10)", "proselint (>=0.13)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"]
+test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)"]
[[package]]
name = "plotly"
-version = "5.18.0"
+version = "5.20.0"
description = "An open-source, interactive data visualization library for Python"
optional = false
-python-versions = ">=3.6"
+python-versions = ">=3.8"
files = [
- {file = "plotly-5.18.0-py3-none-any.whl", hash = "sha256:23aa8ea2f4fb364a20d34ad38235524bd9d691bf5299e800bca608c31e8db8de"},
- {file = "plotly-5.18.0.tar.gz", hash = "sha256:360a31e6fbb49d12b007036eb6929521343d6bee2236f8459915821baefa2cbb"},
+ {file = "plotly-5.20.0-py3-none-any.whl", hash = "sha256:837a9c8aa90f2c0a2f0d747b82544d014dc2a2bdde967b5bb1da25b53932d1a9"},
+ {file = "plotly-5.20.0.tar.gz", hash = "sha256:bf901c805d22032cfa534b2ff7c5aa6b0659e037f19ec1e0cca7f585918b5c89"},
]
[package.dependencies]
@@ -2820,13 +2911,13 @@ tenacity = ">=6.2.0"
[[package]]
name = "pluggy"
-version = "1.3.0"
+version = "1.4.0"
description = "plugin and hook calling mechanisms for python"
optional = false
python-versions = ">=3.8"
files = [
- {file = "pluggy-1.3.0-py3-none-any.whl", hash = "sha256:d89c696a773f8bd377d18e5ecda92b7a3793cbe66c87060a6fb58c7b6e1061f7"},
- {file = "pluggy-1.3.0.tar.gz", hash = "sha256:cf61ae8f126ac6f7c451172cf30e3e43d3ca77615509771b3a984a0730651e12"},
+ {file = "pluggy-1.4.0-py3-none-any.whl", hash = "sha256:7db9f7b503d67d1c5b95f59773ebb58a8c1c288129a88665838012cfb07b8981"},
+ {file = "pluggy-1.4.0.tar.gz", hash = "sha256:8c85c2876142a764e5b7548e7d9a0e0ddb46f5185161049a79b7e974454223be"},
]
[package.extras]
@@ -2879,13 +2970,13 @@ six = ">=1.5.2"
[[package]]
name = "prometheus-client"
-version = "0.17.1"
+version = "0.20.0"
description = "Python client for the Prometheus monitoring system."
optional = false
-python-versions = ">=3.6"
+python-versions = ">=3.8"
files = [
- {file = "prometheus_client-0.17.1-py3-none-any.whl", hash = "sha256:e537f37160f6807b8202a6fc4764cdd19bac5480ddd3e0d463c3002b34462101"},
- {file = "prometheus_client-0.17.1.tar.gz", hash = "sha256:21e674f39831ae3f8acde238afd9a27a37d0d2fb5a28ea094f0ce25d2cbf2091"},
+ {file = "prometheus_client-0.20.0-py3-none-any.whl", hash = "sha256:cde524a85bce83ca359cc837f28b8c0db5cac7aa653a588fd7e84ba061c329e7"},
+ {file = "prometheus_client-0.20.0.tar.gz", hash = "sha256:287629d00b147a32dcb2be0b9df905da599b2d82f80377083ec8463309a4bb89"},
]
[package.extras]
@@ -2893,13 +2984,13 @@ twisted = ["twisted"]
[[package]]
name = "prompt-toolkit"
-version = "3.0.39"
+version = "3.0.43"
description = "Library for building powerful interactive command lines in Python"
optional = false
python-versions = ">=3.7.0"
files = [
- {file = "prompt_toolkit-3.0.39-py3-none-any.whl", hash = "sha256:9dffbe1d8acf91e3de75f3b544e4842382fc06c6babe903ac9acb74dc6e08d88"},
- {file = "prompt_toolkit-3.0.39.tar.gz", hash = "sha256:04505ade687dc26dc4284b1ad19a83be2f2afe83e7a828ace0c72f3a1df72aac"},
+ {file = "prompt_toolkit-3.0.43-py3-none-any.whl", hash = "sha256:a11a29cb3bf0a28a387fe5122cdb649816a957cd9261dcedf8c9f1fef33eacf6"},
+ {file = "prompt_toolkit-3.0.43.tar.gz", hash = "sha256:3527b7af26106cbc65a040bcc84839a3566ec1b051bb0bfe953631e704b0ff7d"},
]
[package.dependencies]
@@ -2907,49 +2998,47 @@ wcwidth = "*"
[[package]]
name = "protobuf"
-version = "4.24.4"
+version = "4.25.3"
description = ""
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
files = [
- {file = "protobuf-4.24.4-cp310-abi3-win32.whl", hash = "sha256:ec9912d5cb6714a5710e28e592ee1093d68c5ebfeda61983b3f40331da0b1ebb"},
- {file = "protobuf-4.24.4-cp310-abi3-win_amd64.whl", hash = "sha256:1badab72aa8a3a2b812eacfede5020472e16c6b2212d737cefd685884c191085"},
- {file = "protobuf-4.24.4-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:8e61a27f362369c2f33248a0ff6896c20dcd47b5d48239cb9720134bef6082e4"},
- {file = "protobuf-4.24.4-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:bffa46ad9612e6779d0e51ae586fde768339b791a50610d85eb162daeb23661e"},
- {file = "protobuf-4.24.4-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:b493cb590960ff863743b9ff1452c413c2ee12b782f48beca77c8da3e2ffe9d9"},
- {file = "protobuf-4.24.4-cp37-cp37m-win32.whl", hash = "sha256:dbbed8a56e56cee8d9d522ce844a1379a72a70f453bde6243e3c86c30c2a3d46"},
- {file = "protobuf-4.24.4-cp37-cp37m-win_amd64.whl", hash = "sha256:6b7d2e1c753715dcfe9d284a25a52d67818dd43c4932574307daf836f0071e37"},
- {file = "protobuf-4.24.4-cp38-cp38-win32.whl", hash = "sha256:02212557a76cd99574775a81fefeba8738d0f668d6abd0c6b1d3adcc75503dbe"},
- {file = "protobuf-4.24.4-cp38-cp38-win_amd64.whl", hash = "sha256:2fa3886dfaae6b4c5ed2730d3bf47c7a38a72b3a1f0acb4d4caf68e6874b947b"},
- {file = "protobuf-4.24.4-cp39-cp39-win32.whl", hash = "sha256:b77272f3e28bb416e2071186cb39efd4abbf696d682cbb5dc731308ad37fa6dd"},
- {file = "protobuf-4.24.4-cp39-cp39-win_amd64.whl", hash = "sha256:9fee5e8aa20ef1b84123bb9232b3f4a5114d9897ed89b4b8142d81924e05d79b"},
- {file = "protobuf-4.24.4-py3-none-any.whl", hash = "sha256:80797ce7424f8c8d2f2547e2d42bfbb6c08230ce5832d6c099a37335c9c90a92"},
- {file = "protobuf-4.24.4.tar.gz", hash = "sha256:5a70731910cd9104762161719c3d883c960151eea077134458503723b60e3667"},
+ {file = "protobuf-4.25.3-cp310-abi3-win32.whl", hash = "sha256:d4198877797a83cbfe9bffa3803602bbe1625dc30d8a097365dbc762e5790faa"},
+ {file = "protobuf-4.25.3-cp310-abi3-win_amd64.whl", hash = "sha256:209ba4cc916bab46f64e56b85b090607a676f66b473e6b762e6f1d9d591eb2e8"},
+ {file = "protobuf-4.25.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:f1279ab38ecbfae7e456a108c5c0681e4956d5b1090027c1de0f934dfdb4b35c"},
+ {file = "protobuf-4.25.3-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:e7cb0ae90dd83727f0c0718634ed56837bfeeee29a5f82a7514c03ee1364c019"},
+ {file = "protobuf-4.25.3-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:7c8daa26095f82482307bc717364e7c13f4f1c99659be82890dcfc215194554d"},
+ {file = "protobuf-4.25.3-cp38-cp38-win32.whl", hash = "sha256:f4f118245c4a087776e0a8408be33cf09f6c547442c00395fbfb116fac2f8ac2"},
+ {file = "protobuf-4.25.3-cp38-cp38-win_amd64.whl", hash = "sha256:c053062984e61144385022e53678fbded7aea14ebb3e0305ae3592fb219ccfa4"},
+ {file = "protobuf-4.25.3-cp39-cp39-win32.whl", hash = "sha256:19b270aeaa0099f16d3ca02628546b8baefe2955bbe23224aaf856134eccf1e4"},
+ {file = "protobuf-4.25.3-cp39-cp39-win_amd64.whl", hash = "sha256:e3c97a1555fd6388f857770ff8b9703083de6bf1f9274a002a332d65fbb56c8c"},
+ {file = "protobuf-4.25.3-py3-none-any.whl", hash = "sha256:f0700d54bcf45424477e46a9f0944155b46fb0639d69728739c0e47bab83f2b9"},
+ {file = "protobuf-4.25.3.tar.gz", hash = "sha256:25b5d0b42fd000320bd7830b349e3b696435f3b329810427a6bcce6a5492cc5c"},
]
[[package]]
name = "psutil"
-version = "5.9.6"
+version = "5.9.8"
description = "Cross-platform lib for process and system monitoring in Python."
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"
files = [
- {file = "psutil-5.9.6-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:fb8a697f11b0f5994550555fcfe3e69799e5b060c8ecf9e2f75c69302cc35c0d"},
- {file = "psutil-5.9.6-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:91ecd2d9c00db9817a4b4192107cf6954addb5d9d67a969a4f436dbc9200f88c"},
- {file = "psutil-5.9.6-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:10e8c17b4f898d64b121149afb136c53ea8b68c7531155147867b7b1ac9e7e28"},
- {file = "psutil-5.9.6-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:18cd22c5db486f33998f37e2bb054cc62fd06646995285e02a51b1e08da97017"},
- {file = "psutil-5.9.6-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:ca2780f5e038379e520281e4c032dddd086906ddff9ef0d1b9dcf00710e5071c"},
- {file = "psutil-5.9.6-cp27-none-win32.whl", hash = "sha256:70cb3beb98bc3fd5ac9ac617a327af7e7f826373ee64c80efd4eb2856e5051e9"},
- {file = "psutil-5.9.6-cp27-none-win_amd64.whl", hash = "sha256:51dc3d54607c73148f63732c727856f5febec1c7c336f8f41fcbd6315cce76ac"},
- {file = "psutil-5.9.6-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:c69596f9fc2f8acd574a12d5f8b7b1ba3765a641ea5d60fb4736bf3c08a8214a"},
- {file = "psutil-5.9.6-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:92e0cc43c524834af53e9d3369245e6cc3b130e78e26100d1f63cdb0abeb3d3c"},
- {file = "psutil-5.9.6-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:748c9dd2583ed86347ed65d0035f45fa8c851e8d90354c122ab72319b5f366f4"},
- {file = "psutil-5.9.6-cp36-cp36m-win32.whl", hash = "sha256:3ebf2158c16cc69db777e3c7decb3c0f43a7af94a60d72e87b2823aebac3d602"},
- {file = "psutil-5.9.6-cp36-cp36m-win_amd64.whl", hash = "sha256:ff18b8d1a784b810df0b0fff3bcb50ab941c3b8e2c8de5726f9c71c601c611aa"},
- {file = "psutil-5.9.6-cp37-abi3-win32.whl", hash = "sha256:a6f01f03bf1843280f4ad16f4bde26b817847b4c1a0db59bf6419807bc5ce05c"},
- {file = "psutil-5.9.6-cp37-abi3-win_amd64.whl", hash = "sha256:6e5fb8dc711a514da83098bc5234264e551ad980cec5f85dabf4d38ed6f15e9a"},
- {file = "psutil-5.9.6-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:daecbcbd29b289aac14ece28eca6a3e60aa361754cf6da3dfb20d4d32b6c7f57"},
- {file = "psutil-5.9.6.tar.gz", hash = "sha256:e4b92ddcd7dd4cdd3f900180ea1e104932c7bce234fb88976e2a3b296441225a"},
+ {file = "psutil-5.9.8-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:26bd09967ae00920df88e0352a91cff1a78f8d69b3ecabbfe733610c0af486c8"},
+ {file = "psutil-5.9.8-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73"},
+ {file = "psutil-5.9.8-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:611052c4bc70432ec770d5d54f64206aa7203a101ec273a0cd82418c86503bb7"},
+ {file = "psutil-5.9.8-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:50187900d73c1381ba1454cf40308c2bf6f34268518b3f36a9b663ca87e65e36"},
+ {file = "psutil-5.9.8-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d"},
+ {file = "psutil-5.9.8-cp27-none-win32.whl", hash = "sha256:36f435891adb138ed3c9e58c6af3e2e6ca9ac2f365efe1f9cfef2794e6c93b4e"},
+ {file = "psutil-5.9.8-cp27-none-win_amd64.whl", hash = "sha256:bd1184ceb3f87651a67b2708d4c3338e9b10c5df903f2e3776b62303b26cb631"},
+ {file = "psutil-5.9.8-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81"},
+ {file = "psutil-5.9.8-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421"},
+ {file = "psutil-5.9.8-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4"},
+ {file = "psutil-5.9.8-cp36-cp36m-win32.whl", hash = "sha256:7d79560ad97af658a0f6adfef8b834b53f64746d45b403f225b85c5c2c140eee"},
+ {file = "psutil-5.9.8-cp36-cp36m-win_amd64.whl", hash = "sha256:27cc40c3493bb10de1be4b3f07cae4c010ce715290a5be22b98493509c6299e2"},
+ {file = "psutil-5.9.8-cp37-abi3-win32.whl", hash = "sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0"},
+ {file = "psutil-5.9.8-cp37-abi3-win_amd64.whl", hash = "sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf"},
+ {file = "psutil-5.9.8-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8"},
+ {file = "psutil-5.9.8.tar.gz", hash = "sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c"},
]
[package.extras]
@@ -2982,58 +3071,76 @@ tests = ["pytest"]
[[package]]
name = "pyarrow"
-version = "13.0.0"
+version = "15.0.2"
description = "Python library for Apache Arrow"
optional = false
python-versions = ">=3.8"
files = [
- {file = "pyarrow-13.0.0-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:1afcc2c33f31f6fb25c92d50a86b7a9f076d38acbcb6f9e74349636109550148"},
- {file = "pyarrow-13.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:70fa38cdc66b2fc1349a082987f2b499d51d072faaa6b600f71931150de2e0e3"},
- {file = "pyarrow-13.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cd57b13a6466822498238877892a9b287b0a58c2e81e4bdb0b596dbb151cbb73"},
- {file = "pyarrow-13.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8ce69f7bf01de2e2764e14df45b8404fc6f1a5ed9871e8e08a12169f87b7a26"},
- {file = "pyarrow-13.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:588f0d2da6cf1b1680974d63be09a6530fd1bd825dc87f76e162404779a157dc"},
- {file = "pyarrow-13.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:6241afd72b628787b4abea39e238e3ff9f34165273fad306c7acf780dd850956"},
- {file = "pyarrow-13.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:fda7857e35993673fcda603c07d43889fca60a5b254052a462653f8656c64f44"},
- {file = "pyarrow-13.0.0-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:aac0ae0146a9bfa5e12d87dda89d9ef7c57a96210b899459fc2f785303dcbb67"},
- {file = "pyarrow-13.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d7759994217c86c161c6a8060509cfdf782b952163569606bb373828afdd82e8"},
- {file = "pyarrow-13.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:868a073fd0ff6468ae7d869b5fc1f54de5c4255b37f44fb890385eb68b68f95d"},
- {file = "pyarrow-13.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51be67e29f3cfcde263a113c28e96aa04362ed8229cb7c6e5f5c719003659d33"},
- {file = "pyarrow-13.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:d1b4e7176443d12610874bb84d0060bf080f000ea9ed7c84b2801df851320295"},
- {file = "pyarrow-13.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:69b6f9a089d116a82c3ed819eea8fe67dae6105f0d81eaf0fdd5e60d0c6e0944"},
- {file = "pyarrow-13.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:ab1268db81aeb241200e321e220e7cd769762f386f92f61b898352dd27e402ce"},
- {file = "pyarrow-13.0.0-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:ee7490f0f3f16a6c38f8c680949551053c8194e68de5046e6c288e396dccee80"},
- {file = "pyarrow-13.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e3ad79455c197a36eefbd90ad4aa832bece7f830a64396c15c61a0985e337287"},
- {file = "pyarrow-13.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:68fcd2dc1b7d9310b29a15949cdd0cb9bc34b6de767aff979ebf546020bf0ba0"},
- {file = "pyarrow-13.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dc6fd330fd574c51d10638e63c0d00ab456498fc804c9d01f2a61b9264f2c5b2"},
- {file = "pyarrow-13.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:e66442e084979a97bb66939e18f7b8709e4ac5f887e636aba29486ffbf373763"},
- {file = "pyarrow-13.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:0f6eff839a9e40e9c5610d3ff8c5bdd2f10303408312caf4c8003285d0b49565"},
- {file = "pyarrow-13.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:8b30a27f1cddf5c6efcb67e598d7823a1e253d743d92ac32ec1eb4b6a1417867"},
- {file = "pyarrow-13.0.0-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:09552dad5cf3de2dc0aba1c7c4b470754c69bd821f5faafc3d774bedc3b04bb7"},
- {file = "pyarrow-13.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3896ae6c205d73ad192d2fc1489cd0edfab9f12867c85b4c277af4d37383c18c"},
- {file = "pyarrow-13.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6647444b21cb5e68b593b970b2a9a07748dd74ea457c7dadaa15fd469c48ada1"},
- {file = "pyarrow-13.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47663efc9c395e31d09c6aacfa860f4473815ad6804311c5433f7085415d62a7"},
- {file = "pyarrow-13.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:b9ba6b6d34bd2563345488cf444510588ea42ad5613df3b3509f48eb80250afd"},
- {file = "pyarrow-13.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:d00d374a5625beeb448a7fa23060df79adb596074beb3ddc1838adb647b6ef09"},
- {file = "pyarrow-13.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:c51afd87c35c8331b56f796eff954b9c7f8d4b7fef5903daf4e05fcf017d23a8"},
- {file = "pyarrow-13.0.0.tar.gz", hash = "sha256:83333726e83ed44b0ac94d8d7a21bbdee4a05029c3b1e8db58a863eec8fd8a33"},
+ {file = "pyarrow-15.0.2-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:88b340f0a1d05b5ccc3d2d986279045655b1fe8e41aba6ca44ea28da0d1455d8"},
+ {file = "pyarrow-15.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:eaa8f96cecf32da508e6c7f69bb8401f03745c050c1dd42ec2596f2e98deecac"},
+ {file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23c6753ed4f6adb8461e7c383e418391b8d8453c5d67e17f416c3a5d5709afbd"},
+ {file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f639c059035011db8c0497e541a8a45d98a58dbe34dc8fadd0ef128f2cee46e5"},
+ {file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:290e36a59a0993e9a5224ed2fb3e53375770f07379a0ea03ee2fce2e6d30b423"},
+ {file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:06c2bb2a98bc792f040bef31ad3e9be6a63d0cb39189227c08a7d955db96816e"},
+ {file = "pyarrow-15.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:f7a197f3670606a960ddc12adbe8075cea5f707ad7bf0dffa09637fdbb89f76c"},
+ {file = "pyarrow-15.0.2-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:5f8bc839ea36b1f99984c78e06e7a06054693dc2af8920f6fb416b5bca9944e4"},
+ {file = "pyarrow-15.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f5e81dfb4e519baa6b4c80410421528c214427e77ca0ea9461eb4097c328fa33"},
+ {file = "pyarrow-15.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a4f240852b302a7af4646c8bfe9950c4691a419847001178662a98915fd7ee7"},
+ {file = "pyarrow-15.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e7d9cfb5a1e648e172428c7a42b744610956f3b70f524aa3a6c02a448ba853e"},
+ {file = "pyarrow-15.0.2-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:2d4f905209de70c0eb5b2de6763104d5a9a37430f137678edfb9a675bac9cd98"},
+ {file = "pyarrow-15.0.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:90adb99e8ce5f36fbecbbc422e7dcbcbed07d985eed6062e459e23f9e71fd197"},
+ {file = "pyarrow-15.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:b116e7fd7889294cbd24eb90cd9bdd3850be3738d61297855a71ac3b8124ee38"},
+ {file = "pyarrow-15.0.2-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:25335e6f1f07fdaa026a61c758ee7d19ce824a866b27bba744348fa73bb5a440"},
+ {file = "pyarrow-15.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:90f19e976d9c3d8e73c80be84ddbe2f830b6304e4c576349d9360e335cd627fc"},
+ {file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a22366249bf5fd40ddacc4f03cd3160f2d7c247692945afb1899bab8a140ddfb"},
+ {file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2a335198f886b07e4b5ea16d08ee06557e07db54a8400cc0d03c7f6a22f785f"},
+ {file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:3e6d459c0c22f0b9c810a3917a1de3ee704b021a5fb8b3bacf968eece6df098f"},
+ {file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:033b7cad32198754d93465dcfb71d0ba7cb7cd5c9afd7052cab7214676eec38b"},
+ {file = "pyarrow-15.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:29850d050379d6e8b5a693098f4de7fd6a2bea4365bfd073d7c57c57b95041ee"},
+ {file = "pyarrow-15.0.2-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:7167107d7fb6dcadb375b4b691b7e316f4368f39f6f45405a05535d7ad5e5058"},
+ {file = "pyarrow-15.0.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e85241b44cc3d365ef950432a1b3bd44ac54626f37b2e3a0cc89c20e45dfd8bf"},
+ {file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:248723e4ed3255fcd73edcecc209744d58a9ca852e4cf3d2577811b6d4b59818"},
+ {file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ff3bdfe6f1b81ca5b73b70a8d482d37a766433823e0c21e22d1d7dde76ca33f"},
+ {file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:f3d77463dee7e9f284ef42d341689b459a63ff2e75cee2b9302058d0d98fe142"},
+ {file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:8c1faf2482fb89766e79745670cbca04e7018497d85be9242d5350cba21357e1"},
+ {file = "pyarrow-15.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:28f3016958a8e45a1069303a4a4f6a7d4910643fc08adb1e2e4a7ff056272ad3"},
+ {file = "pyarrow-15.0.2-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:89722cb64286ab3d4daf168386f6968c126057b8c7ec3ef96302e81d8cdb8ae4"},
+ {file = "pyarrow-15.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:cd0ba387705044b3ac77b1b317165c0498299b08261d8122c96051024f953cd5"},
+ {file = "pyarrow-15.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad2459bf1f22b6a5cdcc27ebfd99307d5526b62d217b984b9f5c974651398832"},
+ {file = "pyarrow-15.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58922e4bfece8b02abf7159f1f53a8f4d9f8e08f2d988109126c17c3bb261f22"},
+ {file = "pyarrow-15.0.2-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:adccc81d3dc0478ea0b498807b39a8d41628fa9210729b2f718b78cb997c7c91"},
+ {file = "pyarrow-15.0.2-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:8bd2baa5fe531571847983f36a30ddbf65261ef23e496862ece83bdceb70420d"},
+ {file = "pyarrow-15.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:6669799a1d4ca9da9c7e06ef48368320f5856f36f9a4dd31a11839dda3f6cc8c"},
+ {file = "pyarrow-15.0.2.tar.gz", hash = "sha256:9c9bc803cb3b7bfacc1e96ffbfd923601065d9d3f911179d81e72d99fd74a3d9"},
]
[package.dependencies]
-numpy = ">=1.16.6"
+numpy = ">=1.16.6,<2"
+
+[[package]]
+name = "pyarrow-hotfix"
+version = "0.6"
+description = ""
+optional = false
+python-versions = ">=3.5"
+files = [
+ {file = "pyarrow_hotfix-0.6-py3-none-any.whl", hash = "sha256:dcc9ae2d220dff0083be6a9aa8e0cdee5182ad358d4931fce825c545e5c89178"},
+ {file = "pyarrow_hotfix-0.6.tar.gz", hash = "sha256:79d3e030f7ff890d408a100ac16d6f00b14d44a502d7897cd9fc3e3a534e9945"},
+]
[[package]]
name = "pycln"
-version = "2.3.0"
+version = "2.4.0"
description = "A formatter for finding and removing unused import statements."
optional = false
-python-versions = ">=3.6.2,<4"
+python-versions = ">=3.7.0,<4"
files = [
- {file = "pycln-2.3.0-py3-none-any.whl", hash = "sha256:d6731e17a60728b827211de2ca4bfc9b40ea1df99a12f3e0fd06a98a0c9e6caa"},
- {file = "pycln-2.3.0.tar.gz", hash = "sha256:8759b36753234c8f95895a31dde329479ffed2218f49d1a1c77c7edccc02e09b"},
+ {file = "pycln-2.4.0-py3-none-any.whl", hash = "sha256:d1bf648df17077306100815d255d45430035b36f66bac635df04a323c61ba126"},
+ {file = "pycln-2.4.0.tar.gz", hash = "sha256:1f3eefb7be18a9ee06c3bdd0ba2e91218cd39317e20130325f107e96eb84b9f6"},
]
[package.dependencies]
-libcst = {version = ">=0.3.10", markers = "python_version >= \"3.7\""}
+libcst = ">=0.3.10"
pathspec = ">=0.9.0"
pyyaml = ">=5.3.1"
tomlkit = ">=0.11.1"
@@ -3041,38 +3148,39 @@ typer = ">=0.4.1"
[[package]]
name = "pycparser"
-version = "2.21"
+version = "2.22"
description = "C parser in Python"
optional = false
-python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+python-versions = ">=3.8"
files = [
- {file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"},
- {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"},
+ {file = "pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc"},
+ {file = "pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6"},
]
[[package]]
name = "pygments"
-version = "2.16.1"
+version = "2.17.2"
description = "Pygments is a syntax highlighting package written in Python."
optional = false
python-versions = ">=3.7"
files = [
- {file = "Pygments-2.16.1-py3-none-any.whl", hash = "sha256:13fc09fa63bc8d8671a6d247e1eb303c4b343eaee81d861f3404db2935653692"},
- {file = "Pygments-2.16.1.tar.gz", hash = "sha256:1daff0494820c69bc8941e407aa20f577374ee88364ee10a98fdbe0aece96e29"},
+ {file = "pygments-2.17.2-py3-none-any.whl", hash = "sha256:b27c2826c47d0f3219f29554824c30c5e8945175d888647acd804ddd04af846c"},
+ {file = "pygments-2.17.2.tar.gz", hash = "sha256:da46cec9fd2de5be3a8a784f434e4c4ab670b4ff54d605c4c2717e9d49c4c367"},
]
[package.extras]
plugins = ["importlib-metadata"]
+windows-terminal = ["colorama (>=0.4.6)"]
[[package]]
name = "pytest"
-version = "7.4.3"
+version = "8.1.1"
description = "pytest: simple powerful testing with Python"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
files = [
- {file = "pytest-7.4.3-py3-none-any.whl", hash = "sha256:0d009c083ea859a71b76adf7c1d502e4bc170b80a8ef002da5806527b9591fac"},
- {file = "pytest-7.4.3.tar.gz", hash = "sha256:d989d136982de4e3b29dabcc838ad581c64e8ed52c11fbe86ddebd9da0818cd5"},
+ {file = "pytest-8.1.1-py3-none-any.whl", hash = "sha256:2a8386cfc11fa9d2c50ee7b2a57e7d898ef90470a7a34c4b949ff59662bb78b7"},
+ {file = "pytest-8.1.1.tar.gz", hash = "sha256:ac978141a75948948817d360297b7aae0fcb9d6ff6bc9ec6d514b85d5a65c044"},
]
[package.dependencies]
@@ -3080,21 +3188,21 @@ colorama = {version = "*", markers = "sys_platform == \"win32\""}
exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
iniconfig = "*"
packaging = "*"
-pluggy = ">=0.12,<2.0"
-tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
+pluggy = ">=1.4,<2.0"
+tomli = {version = ">=1", markers = "python_version < \"3.11\""}
[package.extras]
-testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
+testing = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
[[package]]
name = "pytest-cov"
-version = "4.1.0"
+version = "5.0.0"
description = "Pytest plugin for measuring coverage."
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
files = [
- {file = "pytest-cov-4.1.0.tar.gz", hash = "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6"},
- {file = "pytest_cov-4.1.0-py3-none-any.whl", hash = "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a"},
+ {file = "pytest-cov-5.0.0.tar.gz", hash = "sha256:5837b58e9f6ebd335b0f8060eecce69b662415b16dc503883a02f45dfeb14857"},
+ {file = "pytest_cov-5.0.0-py3-none-any.whl", hash = "sha256:4f0764a1219df53214206bf1feea4633c3b558a2925c8b59f144f682861ce652"},
]
[package.dependencies]
@@ -3102,17 +3210,17 @@ coverage = {version = ">=5.2.1", extras = ["toml"]}
pytest = ">=4.6"
[package.extras]
-testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"]
+testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"]
[[package]]
name = "pytest-doctestplus"
-version = "1.0.0"
+version = "1.2.1"
description = "Pytest plugin with advanced doctest features."
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
files = [
- {file = "pytest-doctestplus-1.0.0.tar.gz", hash = "sha256:f650440dcaede13ed6d7da73bfb4ac585d40a80444ba3542d3e6eecdb275d49f"},
- {file = "pytest_doctestplus-1.0.0-py3-none-any.whl", hash = "sha256:dcba88e1e38bc4871c355e44b778ccfd49b25e33f6aa5393eed6b56440decb2a"},
+ {file = "pytest-doctestplus-1.2.1.tar.gz", hash = "sha256:2472a8a2c8cea34d2f65f6499543faeb748eecb59c597852fd98839b47307679"},
+ {file = "pytest_doctestplus-1.2.1-py3-none-any.whl", hash = "sha256:103705daee8d4468eb59d444c29b0d71eb85b8f6d582295c8bc3d68ee1d88911"},
]
[package.dependencies]
@@ -3125,13 +3233,13 @@ test = ["numpy", "pytest-remotedata (>=0.3.2)", "sphinx"]
[[package]]
name = "python-dateutil"
-version = "2.8.2"
+version = "2.9.0.post0"
description = "Extensions to the standard Python datetime module"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
files = [
- {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"},
- {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"},
+ {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"},
+ {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"},
]
[package.dependencies]
@@ -3150,13 +3258,13 @@ files = [
[[package]]
name = "pytz"
-version = "2023.3.post1"
+version = "2024.1"
description = "World timezone definitions, modern and historical"
optional = false
python-versions = "*"
files = [
- {file = "pytz-2023.3.post1-py2.py3-none-any.whl", hash = "sha256:ce42d816b81b68506614c11e8937d3aa9e41007ceb50bfdcb0749b921bf646c7"},
- {file = "pytz-2023.3.post1.tar.gz", hash = "sha256:7b4fddbeb94a1eba4b557da24f19fdf9db575192544270a9101d8509f9f43d7b"},
+ {file = "pytz-2024.1-py2.py3-none-any.whl", hash = "sha256:328171f4e3623139da4983451950b28e95ac706e13f3f2630a879749e7a8b319"},
+ {file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"},
]
[[package]]
@@ -3184,17 +3292,17 @@ files = [
[[package]]
name = "pywinpty"
-version = "2.0.12"
+version = "2.0.13"
description = "Pseudo terminal support for Windows from Python."
optional = false
python-versions = ">=3.8"
files = [
- {file = "pywinpty-2.0.12-cp310-none-win_amd64.whl", hash = "sha256:21319cd1d7c8844fb2c970fb3a55a3db5543f112ff9cfcd623746b9c47501575"},
- {file = "pywinpty-2.0.12-cp311-none-win_amd64.whl", hash = "sha256:853985a8f48f4731a716653170cd735da36ffbdc79dcb4c7b7140bce11d8c722"},
- {file = "pywinpty-2.0.12-cp312-none-win_amd64.whl", hash = "sha256:1617b729999eb6713590e17665052b1a6ae0ad76ee31e60b444147c5b6a35dca"},
- {file = "pywinpty-2.0.12-cp38-none-win_amd64.whl", hash = "sha256:189380469ca143d06e19e19ff3fba0fcefe8b4a8cc942140a6b863aed7eebb2d"},
- {file = "pywinpty-2.0.12-cp39-none-win_amd64.whl", hash = "sha256:7520575b6546db23e693cbd865db2764097bd6d4ef5dc18c92555904cd62c3d4"},
- {file = "pywinpty-2.0.12.tar.gz", hash = "sha256:8197de460ae8ebb7f5d1701dfa1b5df45b157bb832e92acba316305e18ca00dd"},
+ {file = "pywinpty-2.0.13-cp310-none-win_amd64.whl", hash = "sha256:697bff211fb5a6508fee2dc6ff174ce03f34a9a233df9d8b5fe9c8ce4d5eaf56"},
+ {file = "pywinpty-2.0.13-cp311-none-win_amd64.whl", hash = "sha256:b96fb14698db1284db84ca38c79f15b4cfdc3172065b5137383910567591fa99"},
+ {file = "pywinpty-2.0.13-cp312-none-win_amd64.whl", hash = "sha256:2fd876b82ca750bb1333236ce98488c1be96b08f4f7647cfdf4129dfad83c2d4"},
+ {file = "pywinpty-2.0.13-cp38-none-win_amd64.whl", hash = "sha256:61d420c2116c0212808d31625611b51caf621fe67f8a6377e2e8b617ea1c1f7d"},
+ {file = "pywinpty-2.0.13-cp39-none-win_amd64.whl", hash = "sha256:71cb613a9ee24174730ac7ae439fd179ca34ccb8c5349e8d7b72ab5dea2c6f4b"},
+ {file = "pywinpty-2.0.13.tar.gz", hash = "sha256:c34e32351a3313ddd0d7da23d27f835c860d32fe4ac814d372a3ea9594f41dde"},
]
[[package]]
@@ -3209,6 +3317,7 @@ files = [
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"},
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"},
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"},
+ {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"},
{file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"},
{file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"},
{file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"},
@@ -3216,8 +3325,15 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"},
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"},
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"},
+ {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"},
{file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"},
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
+ {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
+ {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
+ {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
+ {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
+ {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
+ {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"},
{file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"},
{file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"},
{file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"},
@@ -3234,6 +3350,7 @@ files = [
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"},
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"},
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"},
+ {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"},
{file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"},
{file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"},
{file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"},
@@ -3241,6 +3358,7 @@ files = [
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"},
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"},
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"},
+ {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"},
{file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"},
{file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"},
{file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"},
@@ -3248,104 +3366,104 @@ files = [
[[package]]
name = "pyzmq"
-version = "25.1.1"
+version = "25.1.2"
description = "Python bindings for 0MQ"
optional = false
python-versions = ">=3.6"
files = [
- {file = "pyzmq-25.1.1-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:381469297409c5adf9a0e884c5eb5186ed33137badcbbb0560b86e910a2f1e76"},
- {file = "pyzmq-25.1.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:955215ed0604dac5b01907424dfa28b40f2b2292d6493445dd34d0dfa72586a8"},
- {file = "pyzmq-25.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:985bbb1316192b98f32e25e7b9958088431d853ac63aca1d2c236f40afb17c83"},
- {file = "pyzmq-25.1.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:afea96f64efa98df4da6958bae37f1cbea7932c35878b185e5982821bc883369"},
- {file = "pyzmq-25.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76705c9325d72a81155bb6ab48d4312e0032bf045fb0754889133200f7a0d849"},
- {file = "pyzmq-25.1.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:77a41c26205d2353a4c94d02be51d6cbdf63c06fbc1295ea57dad7e2d3381b71"},
- {file = "pyzmq-25.1.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:12720a53e61c3b99d87262294e2b375c915fea93c31fc2336898c26d7aed34cd"},
- {file = "pyzmq-25.1.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:57459b68e5cd85b0be8184382cefd91959cafe79ae019e6b1ae6e2ba8a12cda7"},
- {file = "pyzmq-25.1.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:292fe3fc5ad4a75bc8df0dfaee7d0babe8b1f4ceb596437213821f761b4589f9"},
- {file = "pyzmq-25.1.1-cp310-cp310-win32.whl", hash = "sha256:35b5ab8c28978fbbb86ea54958cd89f5176ce747c1fb3d87356cf698048a7790"},
- {file = "pyzmq-25.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:11baebdd5fc5b475d484195e49bae2dc64b94a5208f7c89954e9e354fc609d8f"},
- {file = "pyzmq-25.1.1-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:d20a0ddb3e989e8807d83225a27e5c2eb2260eaa851532086e9e0fa0d5287d83"},
- {file = "pyzmq-25.1.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e1c1be77bc5fb77d923850f82e55a928f8638f64a61f00ff18a67c7404faf008"},
- {file = "pyzmq-25.1.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d89528b4943d27029a2818f847c10c2cecc79fa9590f3cb1860459a5be7933eb"},
- {file = "pyzmq-25.1.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:90f26dc6d5f241ba358bef79be9ce06de58d477ca8485e3291675436d3827cf8"},
- {file = "pyzmq-25.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2b92812bd214018e50b6380ea3ac0c8bb01ac07fcc14c5f86a5bb25e74026e9"},
- {file = "pyzmq-25.1.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:2f957ce63d13c28730f7fd6b72333814221c84ca2421298f66e5143f81c9f91f"},
- {file = "pyzmq-25.1.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:047a640f5c9c6ade7b1cc6680a0e28c9dd5a0825135acbd3569cc96ea00b2505"},
- {file = "pyzmq-25.1.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:7f7e58effd14b641c5e4dec8c7dab02fb67a13df90329e61c869b9cc607ef752"},
- {file = "pyzmq-25.1.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c2910967e6ab16bf6fbeb1f771c89a7050947221ae12a5b0b60f3bca2ee19bca"},
- {file = "pyzmq-25.1.1-cp311-cp311-win32.whl", hash = "sha256:76c1c8efb3ca3a1818b837aea423ff8a07bbf7aafe9f2f6582b61a0458b1a329"},
- {file = "pyzmq-25.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:44e58a0554b21fc662f2712814a746635ed668d0fbc98b7cb9d74cb798d202e6"},
- {file = "pyzmq-25.1.1-cp312-cp312-macosx_10_15_universal2.whl", hash = "sha256:e1ffa1c924e8c72778b9ccd386a7067cddf626884fd8277f503c48bb5f51c762"},
- {file = "pyzmq-25.1.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:1af379b33ef33757224da93e9da62e6471cf4a66d10078cf32bae8127d3d0d4a"},
- {file = "pyzmq-25.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cff084c6933680d1f8b2f3b4ff5bbb88538a4aac00d199ac13f49d0698727ecb"},
- {file = "pyzmq-25.1.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2400a94f7dd9cb20cd012951a0cbf8249e3d554c63a9c0cdfd5cbb6c01d2dec"},
- {file = "pyzmq-25.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d81f1ddae3858b8299d1da72dd7d19dd36aab654c19671aa8a7e7fb02f6638a"},
- {file = "pyzmq-25.1.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:255ca2b219f9e5a3a9ef3081512e1358bd4760ce77828e1028b818ff5610b87b"},
- {file = "pyzmq-25.1.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a882ac0a351288dd18ecae3326b8a49d10c61a68b01419f3a0b9a306190baf69"},
- {file = "pyzmq-25.1.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:724c292bb26365659fc434e9567b3f1adbdb5e8d640c936ed901f49e03e5d32e"},
- {file = "pyzmq-25.1.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4ca1ed0bb2d850aa8471387882247c68f1e62a4af0ce9c8a1dbe0d2bf69e41fb"},
- {file = "pyzmq-25.1.1-cp312-cp312-win32.whl", hash = "sha256:b3451108ab861040754fa5208bca4a5496c65875710f76789a9ad27c801a0075"},
- {file = "pyzmq-25.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:eadbefd5e92ef8a345f0525b5cfd01cf4e4cc651a2cffb8f23c0dd184975d787"},
- {file = "pyzmq-25.1.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:db0b2af416ba735c6304c47f75d348f498b92952f5e3e8bff449336d2728795d"},
- {file = "pyzmq-25.1.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c7c133e93b405eb0d36fa430c94185bdd13c36204a8635470cccc200723c13bb"},
- {file = "pyzmq-25.1.1-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:273bc3959bcbff3f48606b28229b4721716598d76b5aaea2b4a9d0ab454ec062"},
- {file = "pyzmq-25.1.1-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:cbc8df5c6a88ba5ae385d8930da02201165408dde8d8322072e3e5ddd4f68e22"},
- {file = "pyzmq-25.1.1-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:18d43df3f2302d836f2a56f17e5663e398416e9dd74b205b179065e61f1a6edf"},
- {file = "pyzmq-25.1.1-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:73461eed88a88c866656e08f89299720a38cb4e9d34ae6bf5df6f71102570f2e"},
- {file = "pyzmq-25.1.1-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:34c850ce7976d19ebe7b9d4b9bb8c9dfc7aac336c0958e2651b88cbd46682123"},
- {file = "pyzmq-25.1.1-cp36-cp36m-win32.whl", hash = "sha256:d2045d6d9439a0078f2a34b57c7b18c4a6aef0bee37f22e4ec9f32456c852c71"},
- {file = "pyzmq-25.1.1-cp36-cp36m-win_amd64.whl", hash = "sha256:458dea649f2f02a0b244ae6aef8dc29325a2810aa26b07af8374dc2a9faf57e3"},
- {file = "pyzmq-25.1.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:7cff25c5b315e63b07a36f0c2bab32c58eafbe57d0dce61b614ef4c76058c115"},
- {file = "pyzmq-25.1.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1579413ae492b05de5a6174574f8c44c2b9b122a42015c5292afa4be2507f28"},
- {file = "pyzmq-25.1.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:3d0a409d3b28607cc427aa5c30a6f1e4452cc44e311f843e05edb28ab5e36da0"},
- {file = "pyzmq-25.1.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:21eb4e609a154a57c520e3d5bfa0d97e49b6872ea057b7c85257b11e78068222"},
- {file = "pyzmq-25.1.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:034239843541ef7a1aee0c7b2cb7f6aafffb005ede965ae9cbd49d5ff4ff73cf"},
- {file = "pyzmq-25.1.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:f8115e303280ba09f3898194791a153862cbf9eef722ad8f7f741987ee2a97c7"},
- {file = "pyzmq-25.1.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:1a5d26fe8f32f137e784f768143728438877d69a586ddeaad898558dc971a5ae"},
- {file = "pyzmq-25.1.1-cp37-cp37m-win32.whl", hash = "sha256:f32260e556a983bc5c7ed588d04c942c9a8f9c2e99213fec11a031e316874c7e"},
- {file = "pyzmq-25.1.1-cp37-cp37m-win_amd64.whl", hash = "sha256:abf34e43c531bbb510ae7e8f5b2b1f2a8ab93219510e2b287a944432fad135f3"},
- {file = "pyzmq-25.1.1-cp38-cp38-macosx_10_15_universal2.whl", hash = "sha256:87e34f31ca8f168c56d6fbf99692cc8d3b445abb5bfd08c229ae992d7547a92a"},
- {file = "pyzmq-25.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c9c6c9b2c2f80747a98f34ef491c4d7b1a8d4853937bb1492774992a120f475d"},
- {file = "pyzmq-25.1.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:5619f3f5a4db5dbb572b095ea3cb5cc035335159d9da950830c9c4db2fbb6995"},
- {file = "pyzmq-25.1.1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:5a34d2395073ef862b4032343cf0c32a712f3ab49d7ec4f42c9661e0294d106f"},
- {file = "pyzmq-25.1.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25f0e6b78220aba09815cd1f3a32b9c7cb3e02cb846d1cfc526b6595f6046618"},
- {file = "pyzmq-25.1.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:3669cf8ee3520c2f13b2e0351c41fea919852b220988d2049249db10046a7afb"},
- {file = "pyzmq-25.1.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:2d163a18819277e49911f7461567bda923461c50b19d169a062536fffe7cd9d2"},
- {file = "pyzmq-25.1.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:df27ffddff4190667d40de7beba4a950b5ce78fe28a7dcc41d6f8a700a80a3c0"},
- {file = "pyzmq-25.1.1-cp38-cp38-win32.whl", hash = "sha256:a382372898a07479bd34bda781008e4a954ed8750f17891e794521c3e21c2e1c"},
- {file = "pyzmq-25.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:52533489f28d62eb1258a965f2aba28a82aa747202c8fa5a1c7a43b5db0e85c1"},
- {file = "pyzmq-25.1.1-cp39-cp39-macosx_10_15_universal2.whl", hash = "sha256:03b3f49b57264909aacd0741892f2aecf2f51fb053e7d8ac6767f6c700832f45"},
- {file = "pyzmq-25.1.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:330f9e188d0d89080cde66dc7470f57d1926ff2fb5576227f14d5be7ab30b9fa"},
- {file = "pyzmq-25.1.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:2ca57a5be0389f2a65e6d3bb2962a971688cbdd30b4c0bd188c99e39c234f414"},
- {file = "pyzmq-25.1.1-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:d457aed310f2670f59cc5b57dcfced452aeeed77f9da2b9763616bd57e4dbaae"},
- {file = "pyzmq-25.1.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c56d748ea50215abef7030c72b60dd723ed5b5c7e65e7bc2504e77843631c1a6"},
- {file = "pyzmq-25.1.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:8f03d3f0d01cb5a018debeb412441996a517b11c5c17ab2001aa0597c6d6882c"},
- {file = "pyzmq-25.1.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:820c4a08195a681252f46926de10e29b6bbf3e17b30037bd4250d72dd3ddaab8"},
- {file = "pyzmq-25.1.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:17ef5f01d25b67ca8f98120d5fa1d21efe9611604e8eb03a5147360f517dd1e2"},
- {file = "pyzmq-25.1.1-cp39-cp39-win32.whl", hash = "sha256:04ccbed567171579ec2cebb9c8a3e30801723c575601f9a990ab25bcac6b51e2"},
- {file = "pyzmq-25.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:e61f091c3ba0c3578411ef505992d356a812fb200643eab27f4f70eed34a29ef"},
- {file = "pyzmq-25.1.1-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:ade6d25bb29c4555d718ac6d1443a7386595528c33d6b133b258f65f963bb0f6"},
- {file = "pyzmq-25.1.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e0c95ddd4f6e9fca4e9e3afaa4f9df8552f0ba5d1004e89ef0a68e1f1f9807c7"},
- {file = "pyzmq-25.1.1-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:48e466162a24daf86f6b5ca72444d2bf39a5e58da5f96370078be67c67adc978"},
- {file = "pyzmq-25.1.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:abc719161780932c4e11aaebb203be3d6acc6b38d2f26c0f523b5b59d2fc1996"},
- {file = "pyzmq-25.1.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:1ccf825981640b8c34ae54231b7ed00271822ea1c6d8ba1090ebd4943759abf5"},
- {file = "pyzmq-25.1.1-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:c2f20ce161ebdb0091a10c9ca0372e023ce24980d0e1f810f519da6f79c60800"},
- {file = "pyzmq-25.1.1-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:deee9ca4727f53464daf089536e68b13e6104e84a37820a88b0a057b97bba2d2"},
- {file = "pyzmq-25.1.1-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:aa8d6cdc8b8aa19ceb319aaa2b660cdaccc533ec477eeb1309e2a291eaacc43a"},
- {file = "pyzmq-25.1.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:019e59ef5c5256a2c7378f2fb8560fc2a9ff1d315755204295b2eab96b254d0a"},
- {file = "pyzmq-25.1.1-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:b9af3757495c1ee3b5c4e945c1df7be95562277c6e5bccc20a39aec50f826cd0"},
- {file = "pyzmq-25.1.1-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:548d6482dc8aadbe7e79d1b5806585c8120bafa1ef841167bc9090522b610fa6"},
- {file = "pyzmq-25.1.1-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:057e824b2aae50accc0f9a0570998adc021b372478a921506fddd6c02e60308e"},
- {file = "pyzmq-25.1.1-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2243700cc5548cff20963f0ca92d3e5e436394375ab8a354bbea2b12911b20b0"},
- {file = "pyzmq-25.1.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79986f3b4af059777111409ee517da24a529bdbd46da578b33f25580adcff728"},
- {file = "pyzmq-25.1.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:11d58723d44d6ed4dd677c5615b2ffb19d5c426636345567d6af82be4dff8a55"},
- {file = "pyzmq-25.1.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:49d238cf4b69652257db66d0c623cd3e09b5d2e9576b56bc067a396133a00d4a"},
- {file = "pyzmq-25.1.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fedbdc753827cf014c01dbbee9c3be17e5a208dcd1bf8641ce2cd29580d1f0d4"},
- {file = "pyzmq-25.1.1-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bc16ac425cc927d0a57d242589f87ee093884ea4804c05a13834d07c20db203c"},
- {file = "pyzmq-25.1.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11c1d2aed9079c6b0c9550a7257a836b4a637feb334904610f06d70eb44c56d2"},
- {file = "pyzmq-25.1.1-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e8a701123029cc240cea61dd2d16ad57cab4691804143ce80ecd9286b464d180"},
- {file = "pyzmq-25.1.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:61706a6b6c24bdece85ff177fec393545a3191eeda35b07aaa1458a027ad1304"},
- {file = "pyzmq-25.1.1.tar.gz", hash = "sha256:259c22485b71abacdfa8bf79720cd7bcf4b9d128b30ea554f01ae71fdbfdaa23"},
+ {file = "pyzmq-25.1.2-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:e624c789359f1a16f83f35e2c705d07663ff2b4d4479bad35621178d8f0f6ea4"},
+ {file = "pyzmq-25.1.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:49151b0efece79f6a79d41a461d78535356136ee70084a1c22532fc6383f4ad0"},
+ {file = "pyzmq-25.1.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9a5f194cf730f2b24d6af1f833c14c10f41023da46a7f736f48b6d35061e76e"},
+ {file = "pyzmq-25.1.2-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:faf79a302f834d9e8304fafdc11d0d042266667ac45209afa57e5efc998e3872"},
+ {file = "pyzmq-25.1.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f51a7b4ead28d3fca8dda53216314a553b0f7a91ee8fc46a72b402a78c3e43d"},
+ {file = "pyzmq-25.1.2-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:0ddd6d71d4ef17ba5a87becf7ddf01b371eaba553c603477679ae817a8d84d75"},
+ {file = "pyzmq-25.1.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:246747b88917e4867e2367b005fc8eefbb4a54b7db363d6c92f89d69abfff4b6"},
+ {file = "pyzmq-25.1.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:00c48ae2fd81e2a50c3485de1b9d5c7c57cd85dc8ec55683eac16846e57ac979"},
+ {file = "pyzmq-25.1.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5a68d491fc20762b630e5db2191dd07ff89834086740f70e978bb2ef2668be08"},
+ {file = "pyzmq-25.1.2-cp310-cp310-win32.whl", hash = "sha256:09dfe949e83087da88c4a76767df04b22304a682d6154de2c572625c62ad6886"},
+ {file = "pyzmq-25.1.2-cp310-cp310-win_amd64.whl", hash = "sha256:fa99973d2ed20417744fca0073390ad65ce225b546febb0580358e36aa90dba6"},
+ {file = "pyzmq-25.1.2-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:82544e0e2d0c1811482d37eef297020a040c32e0687c1f6fc23a75b75db8062c"},
+ {file = "pyzmq-25.1.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:01171fc48542348cd1a360a4b6c3e7d8f46cdcf53a8d40f84db6707a6768acc1"},
+ {file = "pyzmq-25.1.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bc69c96735ab501419c432110016329bf0dea8898ce16fab97c6d9106dc0b348"},
+ {file = "pyzmq-25.1.2-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3e124e6b1dd3dfbeb695435dff0e383256655bb18082e094a8dd1f6293114642"},
+ {file = "pyzmq-25.1.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7598d2ba821caa37a0f9d54c25164a4fa351ce019d64d0b44b45540950458840"},
+ {file = "pyzmq-25.1.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:d1299d7e964c13607efd148ca1f07dcbf27c3ab9e125d1d0ae1d580a1682399d"},
+ {file = "pyzmq-25.1.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:4e6f689880d5ad87918430957297c975203a082d9a036cc426648fcbedae769b"},
+ {file = "pyzmq-25.1.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cc69949484171cc961e6ecd4a8911b9ce7a0d1f738fcae717177c231bf77437b"},
+ {file = "pyzmq-25.1.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9880078f683466b7f567b8624bfc16cad65077be046b6e8abb53bed4eeb82dd3"},
+ {file = "pyzmq-25.1.2-cp311-cp311-win32.whl", hash = "sha256:4e5837af3e5aaa99a091302df5ee001149baff06ad22b722d34e30df5f0d9097"},
+ {file = "pyzmq-25.1.2-cp311-cp311-win_amd64.whl", hash = "sha256:25c2dbb97d38b5ac9fd15586e048ec5eb1e38f3d47fe7d92167b0c77bb3584e9"},
+ {file = "pyzmq-25.1.2-cp312-cp312-macosx_10_15_universal2.whl", hash = "sha256:11e70516688190e9c2db14fcf93c04192b02d457b582a1f6190b154691b4c93a"},
+ {file = "pyzmq-25.1.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:313c3794d650d1fccaaab2df942af9f2c01d6217c846177cfcbc693c7410839e"},
+ {file = "pyzmq-25.1.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b3cbba2f47062b85fe0ef9de5b987612140a9ba3a9c6d2543c6dec9f7c2ab27"},
+ {file = "pyzmq-25.1.2-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fc31baa0c32a2ca660784d5af3b9487e13b61b3032cb01a115fce6588e1bed30"},
+ {file = "pyzmq-25.1.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02c9087b109070c5ab0b383079fa1b5f797f8d43e9a66c07a4b8b8bdecfd88ee"},
+ {file = "pyzmq-25.1.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:f8429b17cbb746c3e043cb986328da023657e79d5ed258b711c06a70c2ea7537"},
+ {file = "pyzmq-25.1.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:5074adeacede5f810b7ef39607ee59d94e948b4fd954495bdb072f8c54558181"},
+ {file = "pyzmq-25.1.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:7ae8f354b895cbd85212da245f1a5ad8159e7840e37d78b476bb4f4c3f32a9fe"},
+ {file = "pyzmq-25.1.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:b264bf2cc96b5bc43ce0e852be995e400376bd87ceb363822e2cb1964fcdc737"},
+ {file = "pyzmq-25.1.2-cp312-cp312-win32.whl", hash = "sha256:02bbc1a87b76e04fd780b45e7f695471ae6de747769e540da909173d50ff8e2d"},
+ {file = "pyzmq-25.1.2-cp312-cp312-win_amd64.whl", hash = "sha256:ced111c2e81506abd1dc142e6cd7b68dd53747b3b7ae5edbea4578c5eeff96b7"},
+ {file = "pyzmq-25.1.2-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:7b6d09a8962a91151f0976008eb7b29b433a560fde056ec7a3db9ec8f1075438"},
+ {file = "pyzmq-25.1.2-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:967668420f36878a3c9ecb5ab33c9d0ff8d054f9c0233d995a6d25b0e95e1b6b"},
+ {file = "pyzmq-25.1.2-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:5edac3f57c7ddaacdb4d40f6ef2f9e299471fc38d112f4bc6d60ab9365445fb0"},
+ {file = "pyzmq-25.1.2-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:0dabfb10ef897f3b7e101cacba1437bd3a5032ee667b7ead32bbcdd1a8422fe7"},
+ {file = "pyzmq-25.1.2-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:2c6441e0398c2baacfe5ba30c937d274cfc2dc5b55e82e3749e333aabffde561"},
+ {file = "pyzmq-25.1.2-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:16b726c1f6c2e7625706549f9dbe9b06004dfbec30dbed4bf50cbdfc73e5b32a"},
+ {file = "pyzmq-25.1.2-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:a86c2dd76ef71a773e70551a07318b8e52379f58dafa7ae1e0a4be78efd1ff16"},
+ {file = "pyzmq-25.1.2-cp36-cp36m-win32.whl", hash = "sha256:359f7f74b5d3c65dae137f33eb2bcfa7ad9ebefd1cab85c935f063f1dbb245cc"},
+ {file = "pyzmq-25.1.2-cp36-cp36m-win_amd64.whl", hash = "sha256:55875492f820d0eb3417b51d96fea549cde77893ae3790fd25491c5754ea2f68"},
+ {file = "pyzmq-25.1.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b8c8a419dfb02e91b453615c69568442e897aaf77561ee0064d789705ff37a92"},
+ {file = "pyzmq-25.1.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8807c87fa893527ae8a524c15fc505d9950d5e856f03dae5921b5e9aa3b8783b"},
+ {file = "pyzmq-25.1.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:5e319ed7d6b8f5fad9b76daa0a68497bc6f129858ad956331a5835785761e003"},
+ {file = "pyzmq-25.1.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:3c53687dde4d9d473c587ae80cc328e5b102b517447456184b485587ebd18b62"},
+ {file = "pyzmq-25.1.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:9add2e5b33d2cd765ad96d5eb734a5e795a0755f7fc49aa04f76d7ddda73fd70"},
+ {file = "pyzmq-25.1.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:e690145a8c0c273c28d3b89d6fb32c45e0d9605b2293c10e650265bf5c11cfec"},
+ {file = "pyzmq-25.1.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:00a06faa7165634f0cac1abb27e54d7a0b3b44eb9994530b8ec73cf52e15353b"},
+ {file = "pyzmq-25.1.2-cp37-cp37m-win32.whl", hash = "sha256:0f97bc2f1f13cb16905a5f3e1fbdf100e712d841482b2237484360f8bc4cb3d7"},
+ {file = "pyzmq-25.1.2-cp37-cp37m-win_amd64.whl", hash = "sha256:6cc0020b74b2e410287e5942e1e10886ff81ac77789eb20bec13f7ae681f0fdd"},
+ {file = "pyzmq-25.1.2-cp38-cp38-macosx_10_15_universal2.whl", hash = "sha256:bef02cfcbded83473bdd86dd8d3729cd82b2e569b75844fb4ea08fee3c26ae41"},
+ {file = "pyzmq-25.1.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e10a4b5a4b1192d74853cc71a5e9fd022594573926c2a3a4802020360aa719d8"},
+ {file = "pyzmq-25.1.2-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:8c5f80e578427d4695adac6fdf4370c14a2feafdc8cb35549c219b90652536ae"},
+ {file = "pyzmq-25.1.2-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:5dde6751e857910c1339890f3524de74007958557593b9e7e8c5f01cd919f8a7"},
+ {file = "pyzmq-25.1.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ea1608dd169da230a0ad602d5b1ebd39807ac96cae1845c3ceed39af08a5c6df"},
+ {file = "pyzmq-25.1.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:0f513130c4c361201da9bc69df25a086487250e16b5571ead521b31ff6b02220"},
+ {file = "pyzmq-25.1.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:019744b99da30330798bb37df33549d59d380c78e516e3bab9c9b84f87a9592f"},
+ {file = "pyzmq-25.1.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:2e2713ef44be5d52dd8b8e2023d706bf66cb22072e97fc71b168e01d25192755"},
+ {file = "pyzmq-25.1.2-cp38-cp38-win32.whl", hash = "sha256:07cd61a20a535524906595e09344505a9bd46f1da7a07e504b315d41cd42eb07"},
+ {file = "pyzmq-25.1.2-cp38-cp38-win_amd64.whl", hash = "sha256:eb7e49a17fb8c77d3119d41a4523e432eb0c6932187c37deb6fbb00cc3028088"},
+ {file = "pyzmq-25.1.2-cp39-cp39-macosx_10_15_universal2.whl", hash = "sha256:94504ff66f278ab4b7e03e4cba7e7e400cb73bfa9d3d71f58d8972a8dc67e7a6"},
+ {file = "pyzmq-25.1.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6dd0d50bbf9dca1d0bdea219ae6b40f713a3fb477c06ca3714f208fd69e16fd8"},
+ {file = "pyzmq-25.1.2-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:004ff469d21e86f0ef0369717351073e0e577428e514c47c8480770d5e24a565"},
+ {file = "pyzmq-25.1.2-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:c0b5ca88a8928147b7b1e2dfa09f3b6c256bc1135a1338536cbc9ea13d3b7add"},
+ {file = "pyzmq-25.1.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c9a79f1d2495b167119d02be7448bfba57fad2a4207c4f68abc0bab4b92925b"},
+ {file = "pyzmq-25.1.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:518efd91c3d8ac9f9b4f7dd0e2b7b8bf1a4fe82a308009016b07eaa48681af82"},
+ {file = "pyzmq-25.1.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:1ec23bd7b3a893ae676d0e54ad47d18064e6c5ae1fadc2f195143fb27373f7f6"},
+ {file = "pyzmq-25.1.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:db36c27baed588a5a8346b971477b718fdc66cf5b80cbfbd914b4d6d355e44e2"},
+ {file = "pyzmq-25.1.2-cp39-cp39-win32.whl", hash = "sha256:39b1067f13aba39d794a24761e385e2eddc26295826530a8c7b6c6c341584289"},
+ {file = "pyzmq-25.1.2-cp39-cp39-win_amd64.whl", hash = "sha256:8e9f3fabc445d0ce320ea2c59a75fe3ea591fdbdeebec5db6de530dd4b09412e"},
+ {file = "pyzmq-25.1.2-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a8c1d566344aee826b74e472e16edae0a02e2a044f14f7c24e123002dcff1c05"},
+ {file = "pyzmq-25.1.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:759cfd391a0996345ba94b6a5110fca9c557ad4166d86a6e81ea526c376a01e8"},
+ {file = "pyzmq-25.1.2-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7c61e346ac34b74028ede1c6b4bcecf649d69b707b3ff9dc0fab453821b04d1e"},
+ {file = "pyzmq-25.1.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4cb8fc1f8d69b411b8ec0b5f1ffbcaf14c1db95b6bccea21d83610987435f1a4"},
+ {file = "pyzmq-25.1.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:3c00c9b7d1ca8165c610437ca0c92e7b5607b2f9076f4eb4b095c85d6e680a1d"},
+ {file = "pyzmq-25.1.2-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:df0c7a16ebb94452d2909b9a7b3337940e9a87a824c4fc1c7c36bb4404cb0cde"},
+ {file = "pyzmq-25.1.2-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:45999e7f7ed5c390f2e87ece7f6c56bf979fb213550229e711e45ecc7d42ccb8"},
+ {file = "pyzmq-25.1.2-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ac170e9e048b40c605358667aca3d94e98f604a18c44bdb4c102e67070f3ac9b"},
+ {file = "pyzmq-25.1.2-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d1b604734bec94f05f81b360a272fc824334267426ae9905ff32dc2be433ab96"},
+ {file = "pyzmq-25.1.2-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:a793ac733e3d895d96f865f1806f160696422554e46d30105807fdc9841b9f7d"},
+ {file = "pyzmq-25.1.2-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:0806175f2ae5ad4b835ecd87f5f85583316b69f17e97786f7443baaf54b9bb98"},
+ {file = "pyzmq-25.1.2-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:ef12e259e7bc317c7597d4f6ef59b97b913e162d83b421dd0db3d6410f17a244"},
+ {file = "pyzmq-25.1.2-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ea253b368eb41116011add00f8d5726762320b1bda892f744c91997b65754d73"},
+ {file = "pyzmq-25.1.2-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b9b1f2ad6498445a941d9a4fee096d387fee436e45cc660e72e768d3d8ee611"},
+ {file = "pyzmq-25.1.2-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:8b14c75979ce932c53b79976a395cb2a8cd3aaf14aef75e8c2cb55a330b9b49d"},
+ {file = "pyzmq-25.1.2-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:889370d5174a741a62566c003ee8ddba4b04c3f09a97b8000092b7ca83ec9c49"},
+ {file = "pyzmq-25.1.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9a18fff090441a40ffda8a7f4f18f03dc56ae73f148f1832e109f9bffa85df15"},
+ {file = "pyzmq-25.1.2-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:99a6b36f95c98839ad98f8c553d8507644c880cf1e0a57fe5e3a3f3969040882"},
+ {file = "pyzmq-25.1.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4345c9a27f4310afbb9c01750e9461ff33d6fb74cd2456b107525bbeebcb5be3"},
+ {file = "pyzmq-25.1.2-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:3516e0b6224cf6e43e341d56da15fd33bdc37fa0c06af4f029f7d7dfceceabbc"},
+ {file = "pyzmq-25.1.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:146b9b1f29ead41255387fb07be56dc29639262c0f7344f570eecdcd8d683314"},
+ {file = "pyzmq-25.1.2.tar.gz", hash = "sha256:93f1aa311e8bb912e34f004cf186407a4e90eec4f0ecc0efd26056bf7eda0226"},
]
[package.dependencies]
@@ -3353,18 +3471,17 @@ cffi = {version = "*", markers = "implementation_name == \"pypy\""}
[[package]]
name = "qtconsole"
-version = "5.4.4"
+version = "5.5.1"
description = "Jupyter Qt console"
optional = false
-python-versions = ">= 3.7"
+python-versions = ">= 3.8"
files = [
- {file = "qtconsole-5.4.4-py3-none-any.whl", hash = "sha256:a3b69b868e041c2c698bdc75b0602f42e130ffb256d6efa48f9aa756c97672aa"},
- {file = "qtconsole-5.4.4.tar.gz", hash = "sha256:b7ffb53d74f23cee29f4cdb55dd6fabc8ec312d94f3c46ba38e1dde458693dfb"},
+ {file = "qtconsole-5.5.1-py3-none-any.whl", hash = "sha256:8c75fa3e9b4ed884880ff7cea90a1b67451219279ec33deaee1d59e3df1a5d2b"},
+ {file = "qtconsole-5.5.1.tar.gz", hash = "sha256:a0e806c6951db9490628e4df80caec9669b65149c7ba40f9bf033c025a5b56bc"},
]
[package.dependencies]
ipykernel = ">=4.1"
-ipython-genutils = "*"
jupyter-client = ">=4.1"
jupyter-core = "*"
packaging = "*"
@@ -3396,13 +3513,13 @@ test = ["pytest (>=6,!=7.0.0,!=7.0.1)", "pytest-cov (>=3.0.0)", "pytest-qt"]
[[package]]
name = "referencing"
-version = "0.30.2"
+version = "0.34.0"
description = "JSON Referencing + Python"
optional = false
python-versions = ">=3.8"
files = [
- {file = "referencing-0.30.2-py3-none-any.whl", hash = "sha256:449b6669b6121a9e96a7f9e410b245d471e8d48964c67113ce9afe50c8dd7bdf"},
- {file = "referencing-0.30.2.tar.gz", hash = "sha256:794ad8003c65938edcdbc027f1933215e0d0ccc0291e3ce20a4d87432b59efc0"},
+ {file = "referencing-0.34.0-py3-none-any.whl", hash = "sha256:d53ae300ceddd3169f1ffa9caf2cb7b769e92657e4fafb23d34b93679116dfd4"},
+ {file = "referencing-0.34.0.tar.gz", hash = "sha256:5773bd84ef41799a5a8ca72dc34590c041eb01bf9aa02632b4a973fb0181a844"},
]
[package.dependencies]
@@ -3411,99 +3528,104 @@ rpds-py = ">=0.7.0"
[[package]]
name = "regex"
-version = "2023.10.3"
+version = "2023.12.25"
description = "Alternative regular expression module, to replace re."
optional = false
python-versions = ">=3.7"
files = [
- {file = "regex-2023.10.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4c34d4f73ea738223a094d8e0ffd6d2c1a1b4c175da34d6b0de3d8d69bee6bcc"},
- {file = "regex-2023.10.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a8f4e49fc3ce020f65411432183e6775f24e02dff617281094ba6ab079ef0915"},
- {file = "regex-2023.10.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4cd1bccf99d3ef1ab6ba835308ad85be040e6a11b0977ef7ea8c8005f01a3c29"},
- {file = "regex-2023.10.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:81dce2ddc9f6e8f543d94b05d56e70d03a0774d32f6cca53e978dc01e4fc75b8"},
- {file = "regex-2023.10.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9c6b4d23c04831e3ab61717a707a5d763b300213db49ca680edf8bf13ab5d91b"},
- {file = "regex-2023.10.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c15ad0aee158a15e17e0495e1e18741573d04eb6da06d8b84af726cfc1ed02ee"},
- {file = "regex-2023.10.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6239d4e2e0b52c8bd38c51b760cd870069f0bdf99700a62cd509d7a031749a55"},
- {file = "regex-2023.10.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:4a8bf76e3182797c6b1afa5b822d1d5802ff30284abe4599e1247be4fd6b03be"},
- {file = "regex-2023.10.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d9c727bbcf0065cbb20f39d2b4f932f8fa1631c3e01fcedc979bd4f51fe051c5"},
- {file = "regex-2023.10.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:3ccf2716add72f80714b9a63899b67fa711b654be3fcdd34fa391d2d274ce767"},
- {file = "regex-2023.10.3-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:107ac60d1bfdc3edb53be75e2a52aff7481b92817cfdddd9b4519ccf0e54a6ff"},
- {file = "regex-2023.10.3-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:00ba3c9818e33f1fa974693fb55d24cdc8ebafcb2e4207680669d8f8d7cca79a"},
- {file = "regex-2023.10.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f0a47efb1dbef13af9c9a54a94a0b814902e547b7f21acb29434504d18f36e3a"},
- {file = "regex-2023.10.3-cp310-cp310-win32.whl", hash = "sha256:36362386b813fa6c9146da6149a001b7bd063dabc4d49522a1f7aa65b725c7ec"},
- {file = "regex-2023.10.3-cp310-cp310-win_amd64.whl", hash = "sha256:c65a3b5330b54103e7d21cac3f6bf3900d46f6d50138d73343d9e5b2900b2353"},
- {file = "regex-2023.10.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:90a79bce019c442604662d17bf69df99090e24cdc6ad95b18b6725c2988a490e"},
- {file = "regex-2023.10.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c7964c2183c3e6cce3f497e3a9f49d182e969f2dc3aeeadfa18945ff7bdd7051"},
- {file = "regex-2023.10.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ef80829117a8061f974b2fda8ec799717242353bff55f8a29411794d635d964"},
- {file = "regex-2023.10.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5addc9d0209a9afca5fc070f93b726bf7003bd63a427f65ef797a931782e7edc"},
- {file = "regex-2023.10.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c148bec483cc4b421562b4bcedb8e28a3b84fcc8f0aa4418e10898f3c2c0eb9b"},
- {file = "regex-2023.10.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d1f21af4c1539051049796a0f50aa342f9a27cde57318f2fc41ed50b0dbc4ac"},
- {file = "regex-2023.10.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0b9ac09853b2a3e0d0082104036579809679e7715671cfbf89d83c1cb2a30f58"},
- {file = "regex-2023.10.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ebedc192abbc7fd13c5ee800e83a6df252bec691eb2c4bedc9f8b2e2903f5e2a"},
- {file = "regex-2023.10.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d8a993c0a0ffd5f2d3bda23d0cd75e7086736f8f8268de8a82fbc4bd0ac6791e"},
- {file = "regex-2023.10.3-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:be6b7b8d42d3090b6c80793524fa66c57ad7ee3fe9722b258aec6d0672543fd0"},
- {file = "regex-2023.10.3-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:4023e2efc35a30e66e938de5aef42b520c20e7eda7bb5fb12c35e5d09a4c43f6"},
- {file = "regex-2023.10.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0d47840dc05e0ba04fe2e26f15126de7c755496d5a8aae4a08bda4dd8d646c54"},
- {file = "regex-2023.10.3-cp311-cp311-win32.whl", hash = "sha256:9145f092b5d1977ec8c0ab46e7b3381b2fd069957b9862a43bd383e5c01d18c2"},
- {file = "regex-2023.10.3-cp311-cp311-win_amd64.whl", hash = "sha256:b6104f9a46bd8743e4f738afef69b153c4b8b592d35ae46db07fc28ae3d5fb7c"},
- {file = "regex-2023.10.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:bff507ae210371d4b1fe316d03433ac099f184d570a1a611e541923f78f05037"},
- {file = "regex-2023.10.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:be5e22bbb67924dea15039c3282fa4cc6cdfbe0cbbd1c0515f9223186fc2ec5f"},
- {file = "regex-2023.10.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a992f702c9be9c72fa46f01ca6e18d131906a7180950958f766c2aa294d4b41"},
- {file = "regex-2023.10.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7434a61b158be563c1362d9071358f8ab91b8d928728cd2882af060481244c9e"},
- {file = "regex-2023.10.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c2169b2dcabf4e608416f7f9468737583ce5f0a6e8677c4efbf795ce81109d7c"},
- {file = "regex-2023.10.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9e908ef5889cda4de038892b9accc36d33d72fb3e12c747e2799a0e806ec841"},
- {file = "regex-2023.10.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:12bd4bc2c632742c7ce20db48e0d99afdc05e03f0b4c1af90542e05b809a03d9"},
- {file = "regex-2023.10.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:bc72c231f5449d86d6c7d9cc7cd819b6eb30134bb770b8cfdc0765e48ef9c420"},
- {file = "regex-2023.10.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bce8814b076f0ce5766dc87d5a056b0e9437b8e0cd351b9a6c4e1134a7dfbda9"},
- {file = "regex-2023.10.3-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:ba7cd6dc4d585ea544c1412019921570ebd8a597fabf475acc4528210d7c4a6f"},
- {file = "regex-2023.10.3-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:b0c7d2f698e83f15228ba41c135501cfe7d5740181d5903e250e47f617eb4292"},
- {file = "regex-2023.10.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5a8f91c64f390ecee09ff793319f30a0f32492e99f5dc1c72bc361f23ccd0a9a"},
- {file = "regex-2023.10.3-cp312-cp312-win32.whl", hash = "sha256:ad08a69728ff3c79866d729b095872afe1e0557251da4abb2c5faff15a91d19a"},
- {file = "regex-2023.10.3-cp312-cp312-win_amd64.whl", hash = "sha256:39cdf8d141d6d44e8d5a12a8569d5a227f645c87df4f92179bd06e2e2705e76b"},
- {file = "regex-2023.10.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:4a3ee019a9befe84fa3e917a2dd378807e423d013377a884c1970a3c2792d293"},
- {file = "regex-2023.10.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76066d7ff61ba6bf3cb5efe2428fc82aac91802844c022d849a1f0f53820502d"},
- {file = "regex-2023.10.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bfe50b61bab1b1ec260fa7cd91106fa9fece57e6beba05630afe27c71259c59b"},
- {file = "regex-2023.10.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9fd88f373cb71e6b59b7fa597e47e518282455c2734fd4306a05ca219a1991b0"},
- {file = "regex-2023.10.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3ab05a182c7937fb374f7e946f04fb23a0c0699c0450e9fb02ef567412d2fa3"},
- {file = "regex-2023.10.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dac37cf08fcf2094159922edc7a2784cfcc5c70f8354469f79ed085f0328ebdf"},
- {file = "regex-2023.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:e54ddd0bb8fb626aa1f9ba7b36629564544954fff9669b15da3610c22b9a0991"},
- {file = "regex-2023.10.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:3367007ad1951fde612bf65b0dffc8fd681a4ab98ac86957d16491400d661302"},
- {file = "regex-2023.10.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:16f8740eb6dbacc7113e3097b0a36065a02e37b47c936b551805d40340fb9971"},
- {file = "regex-2023.10.3-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:f4f2ca6df64cbdd27f27b34f35adb640b5d2d77264228554e68deda54456eb11"},
- {file = "regex-2023.10.3-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:39807cbcbe406efca2a233884e169d056c35aa7e9f343d4e78665246a332f597"},
- {file = "regex-2023.10.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:7eece6fbd3eae4a92d7c748ae825cbc1ee41a89bb1c3db05b5578ed3cfcfd7cb"},
- {file = "regex-2023.10.3-cp37-cp37m-win32.whl", hash = "sha256:ce615c92d90df8373d9e13acddd154152645c0dc060871abf6bd43809673d20a"},
- {file = "regex-2023.10.3-cp37-cp37m-win_amd64.whl", hash = "sha256:0f649fa32fe734c4abdfd4edbb8381c74abf5f34bc0b3271ce687b23729299ed"},
- {file = "regex-2023.10.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9b98b7681a9437262947f41c7fac567c7e1f6eddd94b0483596d320092004533"},
- {file = "regex-2023.10.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:91dc1d531f80c862441d7b66c4505cd6ea9d312f01fb2f4654f40c6fdf5cc37a"},
- {file = "regex-2023.10.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82fcc1f1cc3ff1ab8a57ba619b149b907072e750815c5ba63e7aa2e1163384a4"},
- {file = "regex-2023.10.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7979b834ec7a33aafae34a90aad9f914c41fd6eaa8474e66953f3f6f7cbd4368"},
- {file = "regex-2023.10.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ef71561f82a89af6cfcbee47f0fabfdb6e63788a9258e913955d89fdd96902ab"},
- {file = "regex-2023.10.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd829712de97753367153ed84f2de752b86cd1f7a88b55a3a775eb52eafe8a94"},
- {file = "regex-2023.10.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:00e871d83a45eee2f8688d7e6849609c2ca2a04a6d48fba3dff4deef35d14f07"},
- {file = "regex-2023.10.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:706e7b739fdd17cb89e1fbf712d9dc21311fc2333f6d435eac2d4ee81985098c"},
- {file = "regex-2023.10.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:cc3f1c053b73f20c7ad88b0d1d23be7e7b3901229ce89f5000a8399746a6e039"},
- {file = "regex-2023.10.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:6f85739e80d13644b981a88f529d79c5bdf646b460ba190bffcaf6d57b2a9863"},
- {file = "regex-2023.10.3-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:741ba2f511cc9626b7561a440f87d658aabb3d6b744a86a3c025f866b4d19e7f"},
- {file = "regex-2023.10.3-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:e77c90ab5997e85901da85131fd36acd0ed2221368199b65f0d11bca44549711"},
- {file = "regex-2023.10.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:979c24cbefaf2420c4e377ecd1f165ea08cc3d1fbb44bdc51bccbbf7c66a2cb4"},
- {file = "regex-2023.10.3-cp38-cp38-win32.whl", hash = "sha256:58837f9d221744d4c92d2cf7201c6acd19623b50c643b56992cbd2b745485d3d"},
- {file = "regex-2023.10.3-cp38-cp38-win_amd64.whl", hash = "sha256:c55853684fe08d4897c37dfc5faeff70607a5f1806c8be148f1695be4a63414b"},
- {file = "regex-2023.10.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2c54e23836650bdf2c18222c87f6f840d4943944146ca479858404fedeb9f9af"},
- {file = "regex-2023.10.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:69c0771ca5653c7d4b65203cbfc5e66db9375f1078689459fe196fe08b7b4930"},
- {file = "regex-2023.10.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ac965a998e1388e6ff2e9781f499ad1eaa41e962a40d11c7823c9952c77123e"},
- {file = "regex-2023.10.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1c0e8fae5b27caa34177bdfa5a960c46ff2f78ee2d45c6db15ae3f64ecadde14"},
- {file = "regex-2023.10.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6c56c3d47da04f921b73ff9415fbaa939f684d47293f071aa9cbb13c94afc17d"},
- {file = "regex-2023.10.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ef1e014eed78ab650bef9a6a9cbe50b052c0aebe553fb2881e0453717573f52"},
- {file = "regex-2023.10.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d29338556a59423d9ff7b6eb0cb89ead2b0875e08fe522f3e068b955c3e7b59b"},
- {file = "regex-2023.10.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:9c6d0ced3c06d0f183b73d3c5920727268d2201aa0fe6d55c60d68c792ff3588"},
- {file = "regex-2023.10.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:994645a46c6a740ee8ce8df7911d4aee458d9b1bc5639bc968226763d07f00fa"},
- {file = "regex-2023.10.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:66e2fe786ef28da2b28e222c89502b2af984858091675044d93cb50e6f46d7af"},
- {file = "regex-2023.10.3-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:11175910f62b2b8c055f2b089e0fedd694fe2be3941b3e2633653bc51064c528"},
- {file = "regex-2023.10.3-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:06e9abc0e4c9ab4779c74ad99c3fc10d3967d03114449acc2c2762ad4472b8ca"},
- {file = "regex-2023.10.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:fb02e4257376ae25c6dd95a5aec377f9b18c09be6ebdefa7ad209b9137b73d48"},
- {file = "regex-2023.10.3-cp39-cp39-win32.whl", hash = "sha256:3b2c3502603fab52d7619b882c25a6850b766ebd1b18de3df23b2f939360e1bd"},
- {file = "regex-2023.10.3-cp39-cp39-win_amd64.whl", hash = "sha256:adbccd17dcaff65704c856bd29951c58a1bd4b2b0f8ad6b826dbd543fe740988"},
- {file = "regex-2023.10.3.tar.gz", hash = "sha256:3fef4f844d2290ee0ba57addcec17eec9e3df73f10a2748485dfd6a3a188cc0f"},
+ {file = "regex-2023.12.25-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0694219a1d54336fd0445ea382d49d36882415c0134ee1e8332afd1529f0baa5"},
+ {file = "regex-2023.12.25-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b014333bd0217ad3d54c143de9d4b9a3ca1c5a29a6d0d554952ea071cff0f1f8"},
+ {file = "regex-2023.12.25-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d865984b3f71f6d0af64d0d88f5733521698f6c16f445bb09ce746c92c97c586"},
+ {file = "regex-2023.12.25-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1e0eabac536b4cc7f57a5f3d095bfa557860ab912f25965e08fe1545e2ed8b4c"},
+ {file = "regex-2023.12.25-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c25a8ad70e716f96e13a637802813f65d8a6760ef48672aa3502f4c24ea8b400"},
+ {file = "regex-2023.12.25-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a9b6d73353f777630626f403b0652055ebfe8ff142a44ec2cf18ae470395766e"},
+ {file = "regex-2023.12.25-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9cc99d6946d750eb75827cb53c4371b8b0fe89c733a94b1573c9dd16ea6c9e4"},
+ {file = "regex-2023.12.25-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:88d1f7bef20c721359d8675f7d9f8e414ec5003d8f642fdfd8087777ff7f94b5"},
+ {file = "regex-2023.12.25-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:cb3fe77aec8f1995611f966d0c656fdce398317f850d0e6e7aebdfe61f40e1cd"},
+ {file = "regex-2023.12.25-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:7aa47c2e9ea33a4a2a05f40fcd3ea36d73853a2aae7b4feab6fc85f8bf2c9704"},
+ {file = "regex-2023.12.25-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:df26481f0c7a3f8739fecb3e81bc9da3fcfae34d6c094563b9d4670b047312e1"},
+ {file = "regex-2023.12.25-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:c40281f7d70baf6e0db0c2f7472b31609f5bc2748fe7275ea65a0b4601d9b392"},
+ {file = "regex-2023.12.25-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:d94a1db462d5690ebf6ae86d11c5e420042b9898af5dcf278bd97d6bda065423"},
+ {file = "regex-2023.12.25-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ba1b30765a55acf15dce3f364e4928b80858fa8f979ad41f862358939bdd1f2f"},
+ {file = "regex-2023.12.25-cp310-cp310-win32.whl", hash = "sha256:150c39f5b964e4d7dba46a7962a088fbc91f06e606f023ce57bb347a3b2d4630"},
+ {file = "regex-2023.12.25-cp310-cp310-win_amd64.whl", hash = "sha256:09da66917262d9481c719599116c7dc0c321ffcec4b1f510c4f8a066f8768105"},
+ {file = "regex-2023.12.25-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:1b9d811f72210fa9306aeb88385b8f8bcef0dfbf3873410413c00aa94c56c2b6"},
+ {file = "regex-2023.12.25-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d902a43085a308cef32c0d3aea962524b725403fd9373dea18110904003bac97"},
+ {file = "regex-2023.12.25-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d166eafc19f4718df38887b2bbe1467a4f74a9830e8605089ea7a30dd4da8887"},
+ {file = "regex-2023.12.25-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c7ad32824b7f02bb3c9f80306d405a1d9b7bb89362d68b3c5a9be53836caebdb"},
+ {file = "regex-2023.12.25-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:636ba0a77de609d6510235b7f0e77ec494d2657108f777e8765efc060094c98c"},
+ {file = "regex-2023.12.25-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0fda75704357805eb953a3ee15a2b240694a9a514548cd49b3c5124b4e2ad01b"},
+ {file = "regex-2023.12.25-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f72cbae7f6b01591f90814250e636065850c5926751af02bb48da94dfced7baa"},
+ {file = "regex-2023.12.25-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:db2a0b1857f18b11e3b0e54ddfefc96af46b0896fb678c85f63fb8c37518b3e7"},
+ {file = "regex-2023.12.25-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:7502534e55c7c36c0978c91ba6f61703faf7ce733715ca48f499d3dbbd7657e0"},
+ {file = "regex-2023.12.25-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:e8c7e08bb566de4faaf11984af13f6bcf6a08f327b13631d41d62592681d24fe"},
+ {file = "regex-2023.12.25-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:283fc8eed679758de38fe493b7d7d84a198b558942b03f017b1f94dda8efae80"},
+ {file = "regex-2023.12.25-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:f44dd4d68697559d007462b0a3a1d9acd61d97072b71f6d1968daef26bc744bd"},
+ {file = "regex-2023.12.25-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:67d3ccfc590e5e7197750fcb3a2915b416a53e2de847a728cfa60141054123d4"},
+ {file = "regex-2023.12.25-cp311-cp311-win32.whl", hash = "sha256:68191f80a9bad283432385961d9efe09d783bcd36ed35a60fb1ff3f1ec2efe87"},
+ {file = "regex-2023.12.25-cp311-cp311-win_amd64.whl", hash = "sha256:7d2af3f6b8419661a0c421584cfe8aaec1c0e435ce7e47ee2a97e344b98f794f"},
+ {file = "regex-2023.12.25-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:8a0ccf52bb37d1a700375a6b395bff5dd15c50acb745f7db30415bae3c2b0715"},
+ {file = "regex-2023.12.25-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c3c4a78615b7762740531c27cf46e2f388d8d727d0c0c739e72048beb26c8a9d"},
+ {file = "regex-2023.12.25-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ad83e7545b4ab69216cef4cc47e344d19622e28aabec61574b20257c65466d6a"},
+ {file = "regex-2023.12.25-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7a635871143661feccce3979e1727c4e094f2bdfd3ec4b90dfd4f16f571a87a"},
+ {file = "regex-2023.12.25-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d498eea3f581fbe1b34b59c697512a8baef88212f92e4c7830fcc1499f5b45a5"},
+ {file = "regex-2023.12.25-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:43f7cd5754d02a56ae4ebb91b33461dc67be8e3e0153f593c509e21d219c5060"},
+ {file = "regex-2023.12.25-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51f4b32f793812714fd5307222a7f77e739b9bc566dc94a18126aba3b92b98a3"},
+ {file = "regex-2023.12.25-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ba99d8077424501b9616b43a2d208095746fb1284fc5ba490139651f971d39d9"},
+ {file = "regex-2023.12.25-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:4bfc2b16e3ba8850e0e262467275dd4d62f0d045e0e9eda2bc65078c0110a11f"},
+ {file = "regex-2023.12.25-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8c2c19dae8a3eb0ea45a8448356ed561be843b13cbc34b840922ddf565498c1c"},
+ {file = "regex-2023.12.25-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:60080bb3d8617d96f0fb7e19796384cc2467447ef1c491694850ebd3670bc457"},
+ {file = "regex-2023.12.25-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:b77e27b79448e34c2c51c09836033056a0547aa360c45eeeb67803da7b0eedaf"},
+ {file = "regex-2023.12.25-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:518440c991f514331f4850a63560321f833979d145d7d81186dbe2f19e27ae3d"},
+ {file = "regex-2023.12.25-cp312-cp312-win32.whl", hash = "sha256:e2610e9406d3b0073636a3a2e80db05a02f0c3169b5632022b4e81c0364bcda5"},
+ {file = "regex-2023.12.25-cp312-cp312-win_amd64.whl", hash = "sha256:cc37b9aeebab425f11f27e5e9e6cf580be7206c6582a64467a14dda211abc232"},
+ {file = "regex-2023.12.25-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:da695d75ac97cb1cd725adac136d25ca687da4536154cdc2815f576e4da11c69"},
+ {file = "regex-2023.12.25-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d126361607b33c4eb7b36debc173bf25d7805847346dd4d99b5499e1fef52bc7"},
+ {file = "regex-2023.12.25-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4719bb05094d7d8563a450cf8738d2e1061420f79cfcc1fa7f0a44744c4d8f73"},
+ {file = "regex-2023.12.25-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5dd58946bce44b53b06d94aa95560d0b243eb2fe64227cba50017a8d8b3cd3e2"},
+ {file = "regex-2023.12.25-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:22a86d9fff2009302c440b9d799ef2fe322416d2d58fc124b926aa89365ec482"},
+ {file = "regex-2023.12.25-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2aae8101919e8aa05ecfe6322b278f41ce2994c4a430303c4cd163fef746e04f"},
+ {file = "regex-2023.12.25-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:e692296c4cc2873967771345a876bcfc1c547e8dd695c6b89342488b0ea55cd8"},
+ {file = "regex-2023.12.25-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:263ef5cc10979837f243950637fffb06e8daed7f1ac1e39d5910fd29929e489a"},
+ {file = "regex-2023.12.25-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:d6f7e255e5fa94642a0724e35406e6cb7001c09d476ab5fce002f652b36d0c39"},
+ {file = "regex-2023.12.25-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:88ad44e220e22b63b0f8f81f007e8abbb92874d8ced66f32571ef8beb0643b2b"},
+ {file = "regex-2023.12.25-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:3a17d3ede18f9cedcbe23d2daa8a2cd6f59fe2bf082c567e43083bba3fb00347"},
+ {file = "regex-2023.12.25-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d15b274f9e15b1a0b7a45d2ac86d1f634d983ca40d6b886721626c47a400bf39"},
+ {file = "regex-2023.12.25-cp37-cp37m-win32.whl", hash = "sha256:ed19b3a05ae0c97dd8f75a5d8f21f7723a8c33bbc555da6bbe1f96c470139d3c"},
+ {file = "regex-2023.12.25-cp37-cp37m-win_amd64.whl", hash = "sha256:a6d1047952c0b8104a1d371f88f4ab62e6275567d4458c1e26e9627ad489b445"},
+ {file = "regex-2023.12.25-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:b43523d7bc2abd757119dbfb38af91b5735eea45537ec6ec3a5ec3f9562a1c53"},
+ {file = "regex-2023.12.25-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:efb2d82f33b2212898f1659fb1c2e9ac30493ac41e4d53123da374c3b5541e64"},
+ {file = "regex-2023.12.25-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b7fca9205b59c1a3d5031f7e64ed627a1074730a51c2a80e97653e3e9fa0d415"},
+ {file = "regex-2023.12.25-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:086dd15e9435b393ae06f96ab69ab2d333f5d65cbe65ca5a3ef0ec9564dfe770"},
+ {file = "regex-2023.12.25-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e81469f7d01efed9b53740aedd26085f20d49da65f9c1f41e822a33992cb1590"},
+ {file = "regex-2023.12.25-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:34e4af5b27232f68042aa40a91c3b9bb4da0eeb31b7632e0091afc4310afe6cb"},
+ {file = "regex-2023.12.25-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9852b76ab558e45b20bf1893b59af64a28bd3820b0c2efc80e0a70a4a3ea51c1"},
+ {file = "regex-2023.12.25-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ff100b203092af77d1a5a7abe085b3506b7eaaf9abf65b73b7d6905b6cb76988"},
+ {file = "regex-2023.12.25-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:cc038b2d8b1470364b1888a98fd22d616fba2b6309c5b5f181ad4483e0017861"},
+ {file = "regex-2023.12.25-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:094ba386bb5c01e54e14434d4caabf6583334090865b23ef58e0424a6286d3dc"},
+ {file = "regex-2023.12.25-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:5cd05d0f57846d8ba4b71d9c00f6f37d6b97d5e5ef8b3c3840426a475c8f70f4"},
+ {file = "regex-2023.12.25-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:9aa1a67bbf0f957bbe096375887b2505f5d8ae16bf04488e8b0f334c36e31360"},
+ {file = "regex-2023.12.25-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:98a2636994f943b871786c9e82bfe7883ecdaba2ef5df54e1450fa9869d1f756"},
+ {file = "regex-2023.12.25-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:37f8e93a81fc5e5bd8db7e10e62dc64261bcd88f8d7e6640aaebe9bc180d9ce2"},
+ {file = "regex-2023.12.25-cp38-cp38-win32.whl", hash = "sha256:d78bd484930c1da2b9679290a41cdb25cc127d783768a0369d6b449e72f88beb"},
+ {file = "regex-2023.12.25-cp38-cp38-win_amd64.whl", hash = "sha256:b521dcecebc5b978b447f0f69b5b7f3840eac454862270406a39837ffae4e697"},
+ {file = "regex-2023.12.25-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:f7bc09bc9c29ebead055bcba136a67378f03d66bf359e87d0f7c759d6d4ffa31"},
+ {file = "regex-2023.12.25-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e14b73607d6231f3cc4622809c196b540a6a44e903bcfad940779c80dffa7be7"},
+ {file = "regex-2023.12.25-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9eda5f7a50141291beda3edd00abc2d4a5b16c29c92daf8d5bd76934150f3edc"},
+ {file = "regex-2023.12.25-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc6bb9aa69aacf0f6032c307da718f61a40cf970849e471254e0e91c56ffca95"},
+ {file = "regex-2023.12.25-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:298dc6354d414bc921581be85695d18912bea163a8b23cac9a2562bbcd5088b1"},
+ {file = "regex-2023.12.25-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2f4e475a80ecbd15896a976aa0b386c5525d0ed34d5c600b6d3ebac0a67c7ddf"},
+ {file = "regex-2023.12.25-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:531ac6cf22b53e0696f8e1d56ce2396311254eb806111ddd3922c9d937151dae"},
+ {file = "regex-2023.12.25-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:22f3470f7524b6da61e2020672df2f3063676aff444db1daa283c2ea4ed259d6"},
+ {file = "regex-2023.12.25-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:89723d2112697feaa320c9d351e5f5e7b841e83f8b143dba8e2d2b5f04e10923"},
+ {file = "regex-2023.12.25-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0ecf44ddf9171cd7566ef1768047f6e66975788258b1c6c6ca78098b95cf9a3d"},
+ {file = "regex-2023.12.25-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:905466ad1702ed4acfd67a902af50b8db1feeb9781436372261808df7a2a7bca"},
+ {file = "regex-2023.12.25-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:4558410b7a5607a645e9804a3e9dd509af12fb72b9825b13791a37cd417d73a5"},
+ {file = "regex-2023.12.25-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:7e316026cc1095f2a3e8cc012822c99f413b702eaa2ca5408a513609488cb62f"},
+ {file = "regex-2023.12.25-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3b1de218d5375cd6ac4b5493e0b9f3df2be331e86520f23382f216c137913d20"},
+ {file = "regex-2023.12.25-cp39-cp39-win32.whl", hash = "sha256:11a963f8e25ab5c61348d090bf1b07f1953929c13bd2309a0662e9ff680763c9"},
+ {file = "regex-2023.12.25-cp39-cp39-win_amd64.whl", hash = "sha256:e693e233ac92ba83a87024e1d32b5f9ab15ca55ddd916d878146f4e3406b5c91"},
+ {file = "regex-2023.12.25.tar.gz", hash = "sha256:29171aa128da69afdf4bde412d5bedc335f2ca8fcfe4489038577d05f16181e5"},
]
[[package]]
@@ -3554,13 +3676,13 @@ files = [
[[package]]
name = "rich"
-version = "13.6.0"
+version = "13.7.1"
description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
optional = false
python-versions = ">=3.7.0"
files = [
- {file = "rich-13.6.0-py3-none-any.whl", hash = "sha256:2b38e2fe9ca72c9a00170a1a2d20c63c790d0e10ef1fe35eba76e1e7b1d7d245"},
- {file = "rich-13.6.0.tar.gz", hash = "sha256:5c14d22737e6d5084ef4771b62d5d4363165b403455a30a1c8ca39dc7b644bef"},
+ {file = "rich-13.7.1-py3-none-any.whl", hash = "sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222"},
+ {file = "rich-13.7.1.tar.gz", hash = "sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432"},
]
[package.dependencies]
@@ -3573,223 +3695,236 @@ jupyter = ["ipywidgets (>=7.5.1,<9)"]
[[package]]
name = "rpds-py"
-version = "0.10.6"
+version = "0.18.0"
description = "Python bindings to Rust's persistent data structures (rpds)"
optional = false
python-versions = ">=3.8"
files = [
- {file = "rpds_py-0.10.6-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:6bdc11f9623870d75692cc33c59804b5a18d7b8a4b79ef0b00b773a27397d1f6"},
- {file = "rpds_py-0.10.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:26857f0f44f0e791f4a266595a7a09d21f6b589580ee0585f330aaccccb836e3"},
- {file = "rpds_py-0.10.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7f5e15c953ace2e8dde9824bdab4bec50adb91a5663df08d7d994240ae6fa31"},
- {file = "rpds_py-0.10.6-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:61fa268da6e2e1cd350739bb61011121fa550aa2545762e3dc02ea177ee4de35"},
- {file = "rpds_py-0.10.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c48f3fbc3e92c7dd6681a258d22f23adc2eb183c8cb1557d2fcc5a024e80b094"},
- {file = "rpds_py-0.10.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c0503c5b681566e8b722fe8c4c47cce5c7a51f6935d5c7012c4aefe952a35eed"},
- {file = "rpds_py-0.10.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:734c41f9f57cc28658d98270d3436dba65bed0cfc730d115b290e970150c540d"},
- {file = "rpds_py-0.10.6-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a5d7ed104d158c0042a6a73799cf0eb576dfd5fc1ace9c47996e52320c37cb7c"},
- {file = "rpds_py-0.10.6-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e3df0bc35e746cce42579826b89579d13fd27c3d5319a6afca9893a9b784ff1b"},
- {file = "rpds_py-0.10.6-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:73e0a78a9b843b8c2128028864901f55190401ba38aae685350cf69b98d9f7c9"},
- {file = "rpds_py-0.10.6-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:5ed505ec6305abd2c2c9586a7b04fbd4baf42d4d684a9c12ec6110deefe2a063"},
- {file = "rpds_py-0.10.6-cp310-none-win32.whl", hash = "sha256:d97dd44683802000277bbf142fd9f6b271746b4846d0acaf0cefa6b2eaf2a7ad"},
- {file = "rpds_py-0.10.6-cp310-none-win_amd64.whl", hash = "sha256:b455492cab07107bfe8711e20cd920cc96003e0da3c1f91297235b1603d2aca7"},
- {file = "rpds_py-0.10.6-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:e8cdd52744f680346ff8c1ecdad5f4d11117e1724d4f4e1874f3a67598821069"},
- {file = "rpds_py-0.10.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:66414dafe4326bca200e165c2e789976cab2587ec71beb80f59f4796b786a238"},
- {file = "rpds_py-0.10.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc435d059f926fdc5b05822b1be4ff2a3a040f3ae0a7bbbe672babb468944722"},
- {file = "rpds_py-0.10.6-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8e7f2219cb72474571974d29a191714d822e58be1eb171f229732bc6fdedf0ac"},
- {file = "rpds_py-0.10.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3953c6926a63f8ea5514644b7afb42659b505ece4183fdaaa8f61d978754349e"},
- {file = "rpds_py-0.10.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2bb2e4826be25e72013916eecd3d30f66fd076110de09f0e750163b416500721"},
- {file = "rpds_py-0.10.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7bf347b495b197992efc81a7408e9a83b931b2f056728529956a4d0858608b80"},
- {file = "rpds_py-0.10.6-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:102eac53bb0bf0f9a275b438e6cf6904904908562a1463a6fc3323cf47d7a532"},
- {file = "rpds_py-0.10.6-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:40f93086eef235623aa14dbddef1b9fb4b22b99454cb39a8d2e04c994fb9868c"},
- {file = "rpds_py-0.10.6-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e22260a4741a0e7a206e175232867b48a16e0401ef5bce3c67ca5b9705879066"},
- {file = "rpds_py-0.10.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f4e56860a5af16a0fcfa070a0a20c42fbb2012eed1eb5ceeddcc7f8079214281"},
- {file = "rpds_py-0.10.6-cp311-none-win32.whl", hash = "sha256:0774a46b38e70fdde0c6ded8d6d73115a7c39d7839a164cc833f170bbf539116"},
- {file = "rpds_py-0.10.6-cp311-none-win_amd64.whl", hash = "sha256:4a5ee600477b918ab345209eddafde9f91c0acd931f3776369585a1c55b04c57"},
- {file = "rpds_py-0.10.6-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:5ee97c683eaface61d38ec9a489e353d36444cdebb128a27fe486a291647aff6"},
- {file = "rpds_py-0.10.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0713631d6e2d6c316c2f7b9320a34f44abb644fc487b77161d1724d883662e31"},
- {file = "rpds_py-0.10.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5a53f5998b4bbff1cb2e967e66ab2addc67326a274567697379dd1e326bded7"},
- {file = "rpds_py-0.10.6-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6a555ae3d2e61118a9d3e549737bb4a56ff0cec88a22bd1dfcad5b4e04759175"},
- {file = "rpds_py-0.10.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:945eb4b6bb8144909b203a88a35e0a03d22b57aefb06c9b26c6e16d72e5eb0f0"},
- {file = "rpds_py-0.10.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:52c215eb46307c25f9fd2771cac8135d14b11a92ae48d17968eda5aa9aaf5071"},
- {file = "rpds_py-0.10.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1b3cd23d905589cb205710b3988fc8f46d4a198cf12862887b09d7aaa6bf9b9"},
- {file = "rpds_py-0.10.6-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:64ccc28683666672d7c166ed465c09cee36e306c156e787acef3c0c62f90da5a"},
- {file = "rpds_py-0.10.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:516a611a2de12fbea70c78271e558f725c660ce38e0006f75139ba337d56b1f6"},
- {file = "rpds_py-0.10.6-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:9ff93d3aedef11f9c4540cf347f8bb135dd9323a2fc705633d83210d464c579d"},
- {file = "rpds_py-0.10.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:d858532212f0650be12b6042ff4378dc2efbb7792a286bee4489eaa7ba010586"},
- {file = "rpds_py-0.10.6-cp312-none-win32.whl", hash = "sha256:3c4eff26eddac49d52697a98ea01b0246e44ca82ab09354e94aae8823e8bda02"},
- {file = "rpds_py-0.10.6-cp312-none-win_amd64.whl", hash = "sha256:150eec465dbc9cbca943c8e557a21afdcf9bab8aaabf386c44b794c2f94143d2"},
- {file = "rpds_py-0.10.6-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:cf693eb4a08eccc1a1b636e4392322582db2a47470d52e824b25eca7a3977b53"},
- {file = "rpds_py-0.10.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4134aa2342f9b2ab6c33d5c172e40f9ef802c61bb9ca30d21782f6e035ed0043"},
- {file = "rpds_py-0.10.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e782379c2028a3611285a795b89b99a52722946d19fc06f002f8b53e3ea26ea9"},
- {file = "rpds_py-0.10.6-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2f6da6d842195fddc1cd34c3da8a40f6e99e4a113918faa5e60bf132f917c247"},
- {file = "rpds_py-0.10.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b4a9fe992887ac68256c930a2011255bae0bf5ec837475bc6f7edd7c8dfa254e"},
- {file = "rpds_py-0.10.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b788276a3c114e9f51e257f2a6f544c32c02dab4aa7a5816b96444e3f9ffc336"},
- {file = "rpds_py-0.10.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:caa1afc70a02645809c744eefb7d6ee8fef7e2fad170ffdeacca267fd2674f13"},
- {file = "rpds_py-0.10.6-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:bddd4f91eede9ca5275e70479ed3656e76c8cdaaa1b354e544cbcf94c6fc8ac4"},
- {file = "rpds_py-0.10.6-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:775049dfa63fb58293990fc59473e659fcafd953bba1d00fc5f0631a8fd61977"},
- {file = "rpds_py-0.10.6-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:c6c45a2d2b68c51fe3d9352733fe048291e483376c94f7723458cfd7b473136b"},
- {file = "rpds_py-0.10.6-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:0699ab6b8c98df998c3eacf51a3b25864ca93dab157abe358af46dc95ecd9801"},
- {file = "rpds_py-0.10.6-cp38-none-win32.whl", hash = "sha256:ebdab79f42c5961682654b851f3f0fc68e6cc7cd8727c2ac4ffff955154123c1"},
- {file = "rpds_py-0.10.6-cp38-none-win_amd64.whl", hash = "sha256:24656dc36f866c33856baa3ab309da0b6a60f37d25d14be916bd3e79d9f3afcf"},
- {file = "rpds_py-0.10.6-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:0898173249141ee99ffcd45e3829abe7bcee47d941af7434ccbf97717df020e5"},
- {file = "rpds_py-0.10.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9e9184fa6c52a74a5521e3e87badbf9692549c0fcced47443585876fcc47e469"},
- {file = "rpds_py-0.10.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5752b761902cd15073a527b51de76bbae63d938dc7c5c4ad1e7d8df10e765138"},
- {file = "rpds_py-0.10.6-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:99a57006b4ec39dbfb3ed67e5b27192792ffb0553206a107e4aadb39c5004cd5"},
- {file = "rpds_py-0.10.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:09586f51a215d17efdb3a5f090d7cbf1633b7f3708f60a044757a5d48a83b393"},
- {file = "rpds_py-0.10.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e225a6a14ecf44499aadea165299092ab0cba918bb9ccd9304eab1138844490b"},
- {file = "rpds_py-0.10.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2039f8d545f20c4e52713eea51a275e62153ee96c8035a32b2abb772b6fc9e5"},
- {file = "rpds_py-0.10.6-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:34ad87a831940521d462ac11f1774edf867c34172010f5390b2f06b85dcc6014"},
- {file = "rpds_py-0.10.6-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:dcdc88b6b01015da066da3fb76545e8bb9a6880a5ebf89e0f0b2e3ca557b3ab7"},
- {file = "rpds_py-0.10.6-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:25860ed5c4e7f5e10c496ea78af46ae8d8468e0be745bd233bab9ca99bfd2647"},
- {file = "rpds_py-0.10.6-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:7854a207ef77319ec457c1eb79c361b48807d252d94348305db4f4b62f40f7f3"},
- {file = "rpds_py-0.10.6-cp39-none-win32.whl", hash = "sha256:e6fcc026a3f27c1282c7ed24b7fcac82cdd70a0e84cc848c0841a3ab1e3dea2d"},
- {file = "rpds_py-0.10.6-cp39-none-win_amd64.whl", hash = "sha256:e98c4c07ee4c4b3acf787e91b27688409d918212dfd34c872201273fdd5a0e18"},
- {file = "rpds_py-0.10.6-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:68fe9199184c18d997d2e4293b34327c0009a78599ce703e15cd9a0f47349bba"},
- {file = "rpds_py-0.10.6-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:3339eca941568ed52d9ad0f1b8eb9fe0958fa245381747cecf2e9a78a5539c42"},
- {file = "rpds_py-0.10.6-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a360cfd0881d36c6dc271992ce1eda65dba5e9368575663de993eeb4523d895f"},
- {file = "rpds_py-0.10.6-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:031f76fc87644a234883b51145e43985aa2d0c19b063e91d44379cd2786144f8"},
- {file = "rpds_py-0.10.6-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1f36a9d751f86455dc5278517e8b65580eeee37d61606183897f122c9e51cef3"},
- {file = "rpds_py-0.10.6-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:052a832078943d2b2627aea0d19381f607fe331cc0eb5df01991268253af8417"},
- {file = "rpds_py-0.10.6-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:023574366002bf1bd751ebaf3e580aef4a468b3d3c216d2f3f7e16fdabd885ed"},
- {file = "rpds_py-0.10.6-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:defa2c0c68734f4a82028c26bcc85e6b92cced99866af118cd6a89b734ad8e0d"},
- {file = "rpds_py-0.10.6-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:879fb24304ead6b62dbe5034e7b644b71def53c70e19363f3c3be2705c17a3b4"},
- {file = "rpds_py-0.10.6-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:53c43e10d398e365da2d4cc0bcaf0854b79b4c50ee9689652cdc72948e86f487"},
- {file = "rpds_py-0.10.6-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:3777cc9dea0e6c464e4b24760664bd8831738cc582c1d8aacf1c3f546bef3f65"},
- {file = "rpds_py-0.10.6-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:40578a6469e5d1df71b006936ce95804edb5df47b520c69cf5af264d462f2cbb"},
- {file = "rpds_py-0.10.6-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:cf71343646756a072b85f228d35b1d7407da1669a3de3cf47f8bbafe0c8183a4"},
- {file = "rpds_py-0.10.6-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:10f32b53f424fc75ff7b713b2edb286fdbfc94bf16317890260a81c2c00385dc"},
- {file = "rpds_py-0.10.6-pp38-pypy38_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:81de24a1c51cfb32e1fbf018ab0bdbc79c04c035986526f76c33e3f9e0f3356c"},
- {file = "rpds_py-0.10.6-pp38-pypy38_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ac17044876e64a8ea20ab132080ddc73b895b4abe9976e263b0e30ee5be7b9c2"},
- {file = "rpds_py-0.10.6-pp38-pypy38_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5e8a78bd4879bff82daef48c14d5d4057f6856149094848c3ed0ecaf49f5aec2"},
- {file = "rpds_py-0.10.6-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78ca33811e1d95cac8c2e49cb86c0fb71f4d8409d8cbea0cb495b6dbddb30a55"},
- {file = "rpds_py-0.10.6-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c63c3ef43f0b3fb00571cff6c3967cc261c0ebd14a0a134a12e83bdb8f49f21f"},
- {file = "rpds_py-0.10.6-pp38-pypy38_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:7fde6d0e00b2fd0dbbb40c0eeec463ef147819f23725eda58105ba9ca48744f4"},
- {file = "rpds_py-0.10.6-pp38-pypy38_pp73-musllinux_1_2_i686.whl", hash = "sha256:79edd779cfc46b2e15b0830eecd8b4b93f1a96649bcb502453df471a54ce7977"},
- {file = "rpds_py-0.10.6-pp38-pypy38_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:9164ec8010327ab9af931d7ccd12ab8d8b5dc2f4c6a16cbdd9d087861eaaefa1"},
- {file = "rpds_py-0.10.6-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:d29ddefeab1791e3c751e0189d5f4b3dbc0bbe033b06e9c333dca1f99e1d523e"},
- {file = "rpds_py-0.10.6-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:30adb75ecd7c2a52f5e76af50644b3e0b5ba036321c390b8e7ec1bb2a16dd43c"},
- {file = "rpds_py-0.10.6-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dd609fafdcdde6e67a139898196698af37438b035b25ad63704fd9097d9a3482"},
- {file = "rpds_py-0.10.6-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6eef672de005736a6efd565577101277db6057f65640a813de6c2707dc69f396"},
- {file = "rpds_py-0.10.6-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6cf4393c7b41abbf07c88eb83e8af5013606b1cdb7f6bc96b1b3536b53a574b8"},
- {file = "rpds_py-0.10.6-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ad857f42831e5b8d41a32437f88d86ead6c191455a3499c4b6d15e007936d4cf"},
- {file = "rpds_py-0.10.6-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d7360573f1e046cb3b0dceeb8864025aa78d98be4bb69f067ec1c40a9e2d9df"},
- {file = "rpds_py-0.10.6-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d08f63561c8a695afec4975fae445245386d645e3e446e6f260e81663bfd2e38"},
- {file = "rpds_py-0.10.6-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:f0f17f2ce0f3529177a5fff5525204fad7b43dd437d017dd0317f2746773443d"},
- {file = "rpds_py-0.10.6-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:442626328600bde1d09dc3bb00434f5374948838ce75c41a52152615689f9403"},
- {file = "rpds_py-0.10.6-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:e9616f5bd2595f7f4a04b67039d890348ab826e943a9bfdbe4938d0eba606971"},
- {file = "rpds_py-0.10.6.tar.gz", hash = "sha256:4ce5a708d65a8dbf3748d2474b580d606b1b9f91b5c6ab2a316e0b0cf7a4ba50"},
+ {file = "rpds_py-0.18.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:5b4e7d8d6c9b2e8ee2d55c90b59c707ca59bc30058269b3db7b1f8df5763557e"},
+ {file = "rpds_py-0.18.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c463ed05f9dfb9baebef68048aed8dcdc94411e4bf3d33a39ba97e271624f8f7"},
+ {file = "rpds_py-0.18.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:01e36a39af54a30f28b73096dd39b6802eddd04c90dbe161c1b8dbe22353189f"},
+ {file = "rpds_py-0.18.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d62dec4976954a23d7f91f2f4530852b0c7608116c257833922a896101336c51"},
+ {file = "rpds_py-0.18.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dd18772815d5f008fa03d2b9a681ae38d5ae9f0e599f7dda233c439fcaa00d40"},
+ {file = "rpds_py-0.18.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:923d39efa3cfb7279a0327e337a7958bff00cc447fd07a25cddb0a1cc9a6d2da"},
+ {file = "rpds_py-0.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39514da80f971362f9267c600b6d459bfbbc549cffc2cef8e47474fddc9b45b1"},
+ {file = "rpds_py-0.18.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a34d557a42aa28bd5c48a023c570219ba2593bcbbb8dc1b98d8cf5d529ab1434"},
+ {file = "rpds_py-0.18.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:93df1de2f7f7239dc9cc5a4a12408ee1598725036bd2dedadc14d94525192fc3"},
+ {file = "rpds_py-0.18.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:34b18ba135c687f4dac449aa5157d36e2cbb7c03cbea4ddbd88604e076aa836e"},
+ {file = "rpds_py-0.18.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c0b5dcf9193625afd8ecc92312d6ed78781c46ecbf39af9ad4681fc9f464af88"},
+ {file = "rpds_py-0.18.0-cp310-none-win32.whl", hash = "sha256:c4325ff0442a12113a6379af66978c3fe562f846763287ef66bdc1d57925d337"},
+ {file = "rpds_py-0.18.0-cp310-none-win_amd64.whl", hash = "sha256:7223a2a5fe0d217e60a60cdae28d6949140dde9c3bcc714063c5b463065e3d66"},
+ {file = "rpds_py-0.18.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:3a96e0c6a41dcdba3a0a581bbf6c44bb863f27c541547fb4b9711fd8cf0ffad4"},
+ {file = "rpds_py-0.18.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:30f43887bbae0d49113cbaab729a112251a940e9b274536613097ab8b4899cf6"},
+ {file = "rpds_py-0.18.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fcb25daa9219b4cf3a0ab24b0eb9a5cc8949ed4dc72acb8fa16b7e1681aa3c58"},
+ {file = "rpds_py-0.18.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d68c93e381010662ab873fea609bf6c0f428b6d0bb00f2c6939782e0818d37bf"},
+ {file = "rpds_py-0.18.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b34b7aa8b261c1dbf7720b5d6f01f38243e9b9daf7e6b8bc1fd4657000062f2c"},
+ {file = "rpds_py-0.18.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2e6d75ab12b0bbab7215e5d40f1e5b738aa539598db27ef83b2ec46747df90e1"},
+ {file = "rpds_py-0.18.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b8612cd233543a3781bc659c731b9d607de65890085098986dfd573fc2befe5"},
+ {file = "rpds_py-0.18.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:aec493917dd45e3c69d00a8874e7cbed844efd935595ef78a0f25f14312e33c6"},
+ {file = "rpds_py-0.18.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:661d25cbffaf8cc42e971dd570d87cb29a665f49f4abe1f9e76be9a5182c4688"},
+ {file = "rpds_py-0.18.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1df3659d26f539ac74fb3b0c481cdf9d725386e3552c6fa2974f4d33d78e544b"},
+ {file = "rpds_py-0.18.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a1ce3ba137ed54f83e56fb983a5859a27d43a40188ba798993812fed73c70836"},
+ {file = "rpds_py-0.18.0-cp311-none-win32.whl", hash = "sha256:69e64831e22a6b377772e7fb337533c365085b31619005802a79242fee620bc1"},
+ {file = "rpds_py-0.18.0-cp311-none-win_amd64.whl", hash = "sha256:998e33ad22dc7ec7e030b3df701c43630b5bc0d8fbc2267653577e3fec279afa"},
+ {file = "rpds_py-0.18.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:7f2facbd386dd60cbbf1a794181e6aa0bd429bd78bfdf775436020172e2a23f0"},
+ {file = "rpds_py-0.18.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1d9a5be316c15ffb2b3c405c4ff14448c36b4435be062a7f578ccd8b01f0c4d8"},
+ {file = "rpds_py-0.18.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cd5bf1af8efe569654bbef5a3e0a56eca45f87cfcffab31dd8dde70da5982475"},
+ {file = "rpds_py-0.18.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5417558f6887e9b6b65b4527232553c139b57ec42c64570569b155262ac0754f"},
+ {file = "rpds_py-0.18.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:56a737287efecafc16f6d067c2ea0117abadcd078d58721f967952db329a3e5c"},
+ {file = "rpds_py-0.18.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8f03bccbd8586e9dd37219bce4d4e0d3ab492e6b3b533e973fa08a112cb2ffc9"},
+ {file = "rpds_py-0.18.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4457a94da0d5c53dc4b3e4de1158bdab077db23c53232f37a3cb7afdb053a4e3"},
+ {file = "rpds_py-0.18.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0ab39c1ba9023914297dd88ec3b3b3c3f33671baeb6acf82ad7ce883f6e8e157"},
+ {file = "rpds_py-0.18.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:9d54553c1136b50fd12cc17e5b11ad07374c316df307e4cfd6441bea5fb68496"},
+ {file = "rpds_py-0.18.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:0af039631b6de0397ab2ba16eaf2872e9f8fca391b44d3d8cac317860a700a3f"},
+ {file = "rpds_py-0.18.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:84ffab12db93b5f6bad84c712c92060a2d321b35c3c9960b43d08d0f639d60d7"},
+ {file = "rpds_py-0.18.0-cp312-none-win32.whl", hash = "sha256:685537e07897f173abcf67258bee3c05c374fa6fff89d4c7e42fb391b0605e98"},
+ {file = "rpds_py-0.18.0-cp312-none-win_amd64.whl", hash = "sha256:e003b002ec72c8d5a3e3da2989c7d6065b47d9eaa70cd8808b5384fbb970f4ec"},
+ {file = "rpds_py-0.18.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:08f9ad53c3f31dfb4baa00da22f1e862900f45908383c062c27628754af2e88e"},
+ {file = "rpds_py-0.18.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c0013fe6b46aa496a6749c77e00a3eb07952832ad6166bd481c74bda0dcb6d58"},
+ {file = "rpds_py-0.18.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e32a92116d4f2a80b629778280103d2a510a5b3f6314ceccd6e38006b5e92dcb"},
+ {file = "rpds_py-0.18.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e541ec6f2ec456934fd279a3120f856cd0aedd209fc3852eca563f81738f6861"},
+ {file = "rpds_py-0.18.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bed88b9a458e354014d662d47e7a5baafd7ff81c780fd91584a10d6ec842cb73"},
+ {file = "rpds_py-0.18.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2644e47de560eb7bd55c20fc59f6daa04682655c58d08185a9b95c1970fa1e07"},
+ {file = "rpds_py-0.18.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e8916ae4c720529e18afa0b879473049e95949bf97042e938530e072fde061d"},
+ {file = "rpds_py-0.18.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:465a3eb5659338cf2a9243e50ad9b2296fa15061736d6e26240e713522b6235c"},
+ {file = "rpds_py-0.18.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:ea7d4a99f3b38c37eac212dbd6ec42b7a5ec51e2c74b5d3223e43c811609e65f"},
+ {file = "rpds_py-0.18.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:67071a6171e92b6da534b8ae326505f7c18022c6f19072a81dcf40db2638767c"},
+ {file = "rpds_py-0.18.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:41ef53e7c58aa4ef281da975f62c258950f54b76ec8e45941e93a3d1d8580594"},
+ {file = "rpds_py-0.18.0-cp38-none-win32.whl", hash = "sha256:fdea4952db2793c4ad0bdccd27c1d8fdd1423a92f04598bc39425bcc2b8ee46e"},
+ {file = "rpds_py-0.18.0-cp38-none-win_amd64.whl", hash = "sha256:7cd863afe7336c62ec78d7d1349a2f34c007a3cc6c2369d667c65aeec412a5b1"},
+ {file = "rpds_py-0.18.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:5307def11a35f5ae4581a0b658b0af8178c65c530e94893345bebf41cc139d33"},
+ {file = "rpds_py-0.18.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:77f195baa60a54ef9d2de16fbbfd3ff8b04edc0c0140a761b56c267ac11aa467"},
+ {file = "rpds_py-0.18.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39f5441553f1c2aed4de4377178ad8ff8f9d733723d6c66d983d75341de265ab"},
+ {file = "rpds_py-0.18.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9a00312dea9310d4cb7dbd7787e722d2e86a95c2db92fbd7d0155f97127bcb40"},
+ {file = "rpds_py-0.18.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8f2fc11e8fe034ee3c34d316d0ad8808f45bc3b9ce5857ff29d513f3ff2923a1"},
+ {file = "rpds_py-0.18.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:586f8204935b9ec884500498ccc91aa869fc652c40c093bd9e1471fbcc25c022"},
+ {file = "rpds_py-0.18.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddc2f4dfd396c7bfa18e6ce371cba60e4cf9d2e5cdb71376aa2da264605b60b9"},
+ {file = "rpds_py-0.18.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:5ddcba87675b6d509139d1b521e0c8250e967e63b5909a7e8f8944d0f90ff36f"},
+ {file = "rpds_py-0.18.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:7bd339195d84439cbe5771546fe8a4e8a7a045417d8f9de9a368c434e42a721e"},
+ {file = "rpds_py-0.18.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:d7c36232a90d4755b720fbd76739d8891732b18cf240a9c645d75f00639a9024"},
+ {file = "rpds_py-0.18.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:6b0817e34942b2ca527b0e9298373e7cc75f429e8da2055607f4931fded23e20"},
+ {file = "rpds_py-0.18.0-cp39-none-win32.whl", hash = "sha256:99f70b740dc04d09e6b2699b675874367885217a2e9f782bdf5395632ac663b7"},
+ {file = "rpds_py-0.18.0-cp39-none-win_amd64.whl", hash = "sha256:6ef687afab047554a2d366e112dd187b62d261d49eb79b77e386f94644363294"},
+ {file = "rpds_py-0.18.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:ad36cfb355e24f1bd37cac88c112cd7730873f20fb0bdaf8ba59eedf8216079f"},
+ {file = "rpds_py-0.18.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:36b3ee798c58ace201289024b52788161e1ea133e4ac93fba7d49da5fec0ef9e"},
+ {file = "rpds_py-0.18.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f8a2f084546cc59ea99fda8e070be2fd140c3092dc11524a71aa8f0f3d5a55ca"},
+ {file = "rpds_py-0.18.0-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e4461d0f003a0aa9be2bdd1b798a041f177189c1a0f7619fe8c95ad08d9a45d7"},
+ {file = "rpds_py-0.18.0-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8db715ebe3bb7d86d77ac1826f7d67ec11a70dbd2376b7cc214199360517b641"},
+ {file = "rpds_py-0.18.0-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:793968759cd0d96cac1e367afd70c235867831983f876a53389ad869b043c948"},
+ {file = "rpds_py-0.18.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66e6a3af5a75363d2c9a48b07cb27c4ea542938b1a2e93b15a503cdfa8490795"},
+ {file = "rpds_py-0.18.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6ef0befbb5d79cf32d0266f5cff01545602344eda89480e1dd88aca964260b18"},
+ {file = "rpds_py-0.18.0-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:1d4acf42190d449d5e89654d5c1ed3a4f17925eec71f05e2a41414689cda02d1"},
+ {file = "rpds_py-0.18.0-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:a5f446dd5055667aabaee78487f2b5ab72e244f9bc0b2ffebfeec79051679984"},
+ {file = "rpds_py-0.18.0-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:9dbbeb27f4e70bfd9eec1be5477517365afe05a9b2c441a0b21929ee61048124"},
+ {file = "rpds_py-0.18.0-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:22806714311a69fd0af9b35b7be97c18a0fc2826e6827dbb3a8c94eac6cf7eeb"},
+ {file = "rpds_py-0.18.0-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:b34ae4636dfc4e76a438ab826a0d1eed2589ca7d9a1b2d5bb546978ac6485461"},
+ {file = "rpds_py-0.18.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c8370641f1a7f0e0669ddccca22f1da893cef7628396431eb445d46d893e5cd"},
+ {file = "rpds_py-0.18.0-pp38-pypy38_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c8362467a0fdeccd47935f22c256bec5e6abe543bf0d66e3d3d57a8fb5731863"},
+ {file = "rpds_py-0.18.0-pp38-pypy38_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:11a8c85ef4a07a7638180bf04fe189d12757c696eb41f310d2426895356dcf05"},
+ {file = "rpds_py-0.18.0-pp38-pypy38_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b316144e85316da2723f9d8dc75bada12fa58489a527091fa1d5a612643d1a0e"},
+ {file = "rpds_py-0.18.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf1ea2e34868f6fbf070e1af291c8180480310173de0b0c43fc38a02929fc0e3"},
+ {file = "rpds_py-0.18.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e546e768d08ad55b20b11dbb78a745151acbd938f8f00d0cfbabe8b0199b9880"},
+ {file = "rpds_py-0.18.0-pp38-pypy38_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:4901165d170a5fde6f589acb90a6b33629ad1ec976d4529e769c6f3d885e3e80"},
+ {file = "rpds_py-0.18.0-pp38-pypy38_pp73-musllinux_1_2_i686.whl", hash = "sha256:618a3d6cae6ef8ec88bb76dd80b83cfe415ad4f1d942ca2a903bf6b6ff97a2da"},
+ {file = "rpds_py-0.18.0-pp38-pypy38_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:ed4eb745efbff0a8e9587d22a84be94a5eb7d2d99c02dacf7bd0911713ed14dd"},
+ {file = "rpds_py-0.18.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:6c81e5f372cd0dc5dc4809553d34f832f60a46034a5f187756d9b90586c2c307"},
+ {file = "rpds_py-0.18.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:43fbac5f22e25bee1d482c97474f930a353542855f05c1161fd804c9dc74a09d"},
+ {file = "rpds_py-0.18.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6d7faa6f14017c0b1e69f5e2c357b998731ea75a442ab3841c0dbbbfe902d2c4"},
+ {file = "rpds_py-0.18.0-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:08231ac30a842bd04daabc4d71fddd7e6d26189406d5a69535638e4dcb88fe76"},
+ {file = "rpds_py-0.18.0-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:044a3e61a7c2dafacae99d1e722cc2d4c05280790ec5a05031b3876809d89a5c"},
+ {file = "rpds_py-0.18.0-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3f26b5bd1079acdb0c7a5645e350fe54d16b17bfc5e71f371c449383d3342e17"},
+ {file = "rpds_py-0.18.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:482103aed1dfe2f3b71a58eff35ba105289b8d862551ea576bd15479aba01f66"},
+ {file = "rpds_py-0.18.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1374f4129f9bcca53a1bba0bb86bf78325a0374577cf7e9e4cd046b1e6f20e24"},
+ {file = "rpds_py-0.18.0-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:635dc434ff724b178cb192c70016cc0ad25a275228f749ee0daf0eddbc8183b1"},
+ {file = "rpds_py-0.18.0-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:bc362ee4e314870a70f4ae88772d72d877246537d9f8cb8f7eacf10884862432"},
+ {file = "rpds_py-0.18.0-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:4832d7d380477521a8c1644bbab6588dfedea5e30a7d967b5fb75977c45fd77f"},
+ {file = "rpds_py-0.18.0.tar.gz", hash = "sha256:42821446ee7a76f5d9f71f9e33a4fb2ffd724bb3e7f93386150b61a43115788d"},
]
[[package]]
name = "safetensors"
-version = "0.4.0"
+version = "0.4.2"
description = ""
optional = false
python-versions = ">=3.7"
files = [
- {file = "safetensors-0.4.0-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:2289ae6dbe6d027ecee016b28ced13a2e21a0b3a3a757a23033a2d1c0b1bad55"},
- {file = "safetensors-0.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:bf6458959f310f551cbbeef2255527ade5f783f952738e73e4d0136198cc3bfe"},
- {file = "safetensors-0.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b6b60a58a8f7cc7aed3b5b73dce1f5259a53c83d9ba43a76a874e6ad868c1b4d"},
- {file = "safetensors-0.4.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:491b3477e4d0d4599bb75d79da4b75af2e6ed9b1f6ec2b715991f0bc927bf09a"},
- {file = "safetensors-0.4.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:59d2e10b7e0cd18bb73ed7c17c624a5957b003b81345e18159591771c26ee428"},
- {file = "safetensors-0.4.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3f667a4c12fb593f5f66ce966cb1b14a7148898b2b1a7f79e0761040ae1e3c51"},
- {file = "safetensors-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f9909512bcb6f712bdd04c296cdfb0d8ff73d258ffc5af884bb62ea02d221e0"},
- {file = "safetensors-0.4.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d33d29e846821f0e4f92614022949b09ccf063cb36fe2f9fe099cde1efbfbb87"},
- {file = "safetensors-0.4.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:4d512525a8e05a045ce6698066ba0c5378c174a83e0b3720a8c7799dc1bb06f3"},
- {file = "safetensors-0.4.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0219cea445177f6ad1f9acd3a8d025440c8ff436d70a4a7c7ba9c36066aa9474"},
- {file = "safetensors-0.4.0-cp310-none-win32.whl", hash = "sha256:67ab171eeaad6972d3971c53d29d53353c67f6743284c6d637b59fa3e54c8a94"},
- {file = "safetensors-0.4.0-cp310-none-win_amd64.whl", hash = "sha256:7ffc736039f08a9ca1f09816a7481b8e4469c06e8f8a5ffa8cb67ddd79e6d77f"},
- {file = "safetensors-0.4.0-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:4fe9e3737b30de458225a23926219ca30b902ee779b6a3df96eaab2b6d625ec2"},
- {file = "safetensors-0.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e7916e814a90008de767b1c164a1d83803693c661ffe9af5a697b22e2752edb0"},
- {file = "safetensors-0.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cbc4a4da01143472323c145f3c289e5f6fabde0ac0a3414dabf912a21692fff4"},
- {file = "safetensors-0.4.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a54c21654a47669b38e359e8f852af754b786c9da884bb61ad5e9af12bd71ccb"},
- {file = "safetensors-0.4.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:25cd407955bad5340ba17f9f8ac789a0d751601a311e2f7b2733f9384478c95e"},
- {file = "safetensors-0.4.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:82e8fc4e3503cd738fd40718a430fe0e5ce6e7ff91a73d6ce628bbb89c41e8ce"},
- {file = "safetensors-0.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:48b92059b1a4ad163024d4f526e0e73ebe2bb3ae70537e15e347820b4de5dc27"},
- {file = "safetensors-0.4.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:5daa05058f7dce85b5f9f60c4eab483ed7859d63978f08a76e52e78859ff20ca"},
- {file = "safetensors-0.4.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:a86565a5c112dd855909e20144947b4f53abb78c4de207f36ca71ee63ba5b90d"},
- {file = "safetensors-0.4.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:38032078ed9fea52d06584e441bccc73fb475c4581600c6d6166de2fe2deb3d1"},
- {file = "safetensors-0.4.0-cp311-none-win32.whl", hash = "sha256:2f99d90c91b7c76b40a862acd9085bc77f7974a27dee7cfcebe46149af5a99a1"},
- {file = "safetensors-0.4.0-cp311-none-win_amd64.whl", hash = "sha256:74e2a448ffe19be188b457b130168190ee73b5a75e45ba96796320c1f5ae35d2"},
- {file = "safetensors-0.4.0-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:1e2f9c69b41d03b4826ffb96b29e07444bb6b34a78a7bafd0b88d59e8ec75b8a"},
- {file = "safetensors-0.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3910fb5bf747413b59f1a34e6d2a993b589fa7d919709518823c70efaaa350bd"},
- {file = "safetensors-0.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf8fdca709b2470a35a59b1e6dffea75cbe1214b22612b5dd4c93947697aea8b"},
- {file = "safetensors-0.4.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2f27b8ef814c5fb43456caeb7f3cbb889b76115180aad1f42402839c14a47c5b"},
- {file = "safetensors-0.4.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7b2d6101eccc43c7be0cb052f13ceda64288b3d8b344b988ed08d7133cbce2f3"},
- {file = "safetensors-0.4.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fdc34027b545a69be3d4220c140b276129523e4e46db06ad1a0b60d6a4cf9214"},
- {file = "safetensors-0.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db7bb48ca9e90bb9526c71b388d38d8de160c0354f4c5126df23e8701a870dcb"},
- {file = "safetensors-0.4.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a78ffc0795d3595cd9e4d453502e35f764276c49e434b25556a15a337db4dafc"},
- {file = "safetensors-0.4.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:8e735b0f79090f6855b55e205e820b7b595502ffca0009a5c13eef3661ce465b"},
- {file = "safetensors-0.4.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:f8d2416734e850d5392afffbcb2b8985ea29fb171f1cb197e2ae51b8e35d6438"},
- {file = "safetensors-0.4.0-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:e853e189ba7d47eaf561094586692ba2bbdd258c096f1755805cac098de0e6ab"},
- {file = "safetensors-0.4.0-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:4b2aa57b5a4d576f3d1dd6e56980026340f156f8a13c13016bfac4e25295b53f"},
- {file = "safetensors-0.4.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b6c1316ffde6cb4bf22c7445bc9fd224b4d1b9dd7320695f5611c89e802e4b6"},
- {file = "safetensors-0.4.0-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:003077ec85261d00061058fa12e3c1d2055366b02ce8f2938929359ffbaff2b8"},
- {file = "safetensors-0.4.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bd63d83a92f1437a8b0431779320376030ae43ace980bea5686d515de0784100"},
- {file = "safetensors-0.4.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2077801800b4b13301d8d6290c7fb5bd60737320001717153ebc4371776643b5"},
- {file = "safetensors-0.4.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7abe0e157a49a75aeeccfbc4f3dac38d8f98512d3cdb35c200f8e628dc5773cf"},
- {file = "safetensors-0.4.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:3bfed574f6b1e7e7fe1f17213278875ef6c6e8b1582ab6eda93947db1178cae6"},
- {file = "safetensors-0.4.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:964ef166a286ce3b023d0d0bd0e21d440a1c8028981c8abdb136bc7872ba9b3d"},
- {file = "safetensors-0.4.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:44f84373e42183bd56a13a1f2d8acb1db7fedaeffbd83e79cec861477eee1af4"},
- {file = "safetensors-0.4.0-cp37-none-win32.whl", hash = "sha256:c68132727dd86fb641102e494d445f705efe402f4d5e24b278183a15499ab400"},
- {file = "safetensors-0.4.0-cp37-none-win_amd64.whl", hash = "sha256:1db87155454c168aef118d5657a403aee48a4cb08d8851a981157f07351ea317"},
- {file = "safetensors-0.4.0-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:9e583fa68e5a07cc859c4e13c1ebff12029904aa2e27185cf04a1f57fe9a81c4"},
- {file = "safetensors-0.4.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:73e7696dcf3f72f99545eb1abe6106ad65ff1f62381d6ce4b34be3272552897a"},
- {file = "safetensors-0.4.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4936096a57c62e84e200f92620a536be067fc5effe46ecc7f230ebb496ecd579"},
- {file = "safetensors-0.4.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:87b328ee1591adac332543e1f5fc2c2d7f149b745ebb0d58d7850818ff9cee27"},
- {file = "safetensors-0.4.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b69554c143336256260eceff1d3c0969172a641b54d4668489a711b05f92a2c0"},
- {file = "safetensors-0.4.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3ebf6bcece5d5d1bd6416472f94604d2c834ca752ac60ed42dba7157e595a990"},
- {file = "safetensors-0.4.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6686ce01b8602d55a7d9903c90d4a6e6f90aeb6ddced7cf4605892d0ba94bcb8"},
- {file = "safetensors-0.4.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9b8fd6cc2f3bda444a048b541c843c7b7fefc89c4120d7898ea7d5b026e93891"},
- {file = "safetensors-0.4.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8a6abfe67692f81b8bdb99c837f28351c17e624ebf136970c850ee989c720446"},
- {file = "safetensors-0.4.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:27a24ca8822c469ee452db4c13418ba983315a0d863c018a9af15f2305eac38c"},
- {file = "safetensors-0.4.0-cp38-none-win32.whl", hash = "sha256:c4a0a47c8640167792d8261ee21b26430bbc39130a7edaad7f4c0bc05669d00e"},
- {file = "safetensors-0.4.0-cp38-none-win_amd64.whl", hash = "sha256:a738970a367f39249e2abb900d9441a8a86d7ff50083e5eaa6e7760a9f216014"},
- {file = "safetensors-0.4.0-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:806379f37e1abd5d302288c4b2f4186dd7ea7143d4c7811f90a8077f0ae8967b"},
- {file = "safetensors-0.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2b9b94133ed2ae9dda0e95dcace7b7556eba023ffa4c4ae6df8f99377f571d6a"},
- {file = "safetensors-0.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b563a14c43614815a6b524d2e4edeaace50b717f7e7487bb227dd5b68350f5a"},
- {file = "safetensors-0.4.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:00a9b157be660fb7ba88fa2eedd05ec93793a5b61e43e783e10cb0b995372802"},
- {file = "safetensors-0.4.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c8f194f45ab6aa767993c24f0aeb950af169dbc5d611b94c9021a1d13b8a1a34"},
- {file = "safetensors-0.4.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:469360b9451db10bfed3881378d5a71b347ecb1ab4f42367d77b8164a13af70b"},
- {file = "safetensors-0.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5f75fa97ccf32a3c7af476c6a0e851023197d3c078f6de3612008fff94735f9"},
- {file = "safetensors-0.4.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:acf0180283c2efae72f1d8c0a4a7974662091df01be3aa43b5237b1e52ed0a01"},
- {file = "safetensors-0.4.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:cd02b495ba0814619f40bda46771bb06dbbf1d42524b66fa03b2a736c77e4515"},
- {file = "safetensors-0.4.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c42bdea183dbaa99e2f0e6120dc524df79cf4289a6f90f30a534444ef20f49fa"},
- {file = "safetensors-0.4.0-cp39-none-win32.whl", hash = "sha256:cef7bb5d9feae7146c3c3c7b3aef7d2c8b39ba7f5ff4252d368eb69462a47076"},
- {file = "safetensors-0.4.0-cp39-none-win_amd64.whl", hash = "sha256:79dd46fb1f19282fd12f544471efb97823ede927cedbf9cf35550d92b349fdd2"},
- {file = "safetensors-0.4.0-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:002301c1afa32909f83745b0c124d002e7ae07e15671f3b43cbebd0ffc5e6037"},
- {file = "safetensors-0.4.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:67762d36ae088c73d4a3c96bfc4ea8d31233554f35b6cace3a18533238d462ea"},
- {file = "safetensors-0.4.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f45230f20a206e5e4c7f7bbf9342178410c6f8b0af889843aa99045a76f7691"},
- {file = "safetensors-0.4.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f2ca939bbd8fb2f4dfa28e39a146dad03bc9325e9fc831b68f7b98f69a5a2f1"},
- {file = "safetensors-0.4.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:61a00f281391fae5ce91df70918bb61c12d2d514a493fd8056e12114be729911"},
- {file = "safetensors-0.4.0-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:435fd136a42492b280cb55126f9ce9535b35dd49df2c5d572a5945455a439448"},
- {file = "safetensors-0.4.0-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f0daa788273d683258fb1e4a5e16bef4486b2fca536451a2591bc0f4a6488895"},
- {file = "safetensors-0.4.0-pp37-pypy37_pp73-macosx_10_7_x86_64.whl", hash = "sha256:0620ab0d41e390ccb1c4ea8f63dc00cb5f0b96a5cdd3cd0d64c21765720c074a"},
- {file = "safetensors-0.4.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bc1fa8d067733cb67f22926689ee808f08afacf7700d2ffb44efae90a0693eb1"},
- {file = "safetensors-0.4.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dcaa40bc363edda145db75cd030f3b1822e5478d550c3500a42502ecef32c959"},
- {file = "safetensors-0.4.0-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b561fbc044db7beff2ece0ec219a291809d45a38d30c6b38e7cc46482582f4ba"},
- {file = "safetensors-0.4.0-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:79a983b09782dacf9a1adb19bb98f4a8f6c3144108939f572c047b5797e43cf5"},
- {file = "safetensors-0.4.0-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:10b65cd3ad79f5d0daf281523b4146bc271a34bb7430d4e03212e0de8622dab8"},
- {file = "safetensors-0.4.0-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:114decacc475a6a9e2f9102a00c171d113ddb5d35cb0bda0db2c0c82b2eaa9ce"},
- {file = "safetensors-0.4.0-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:72ddb741dd5fe42521db76a70e012f76995516a12e7e0ef26be03ea9be77802a"},
- {file = "safetensors-0.4.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c5556c2ec75f5a6134866eddd7341cb36062e6edaea343478a279591b63ddba"},
- {file = "safetensors-0.4.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ed50f239b0ce7ae85b078395593b4a351ede7e6f73af25f4873e3392336f64c9"},
- {file = "safetensors-0.4.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:495dcaea8fbab70b927d2274e2547824462737acbf98ccd851a71124f779a5c6"},
- {file = "safetensors-0.4.0-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:3f4d90c79a65ba2fe2ff0876f6140748f0a3ce6a21e27a35190f4f96321803f8"},
- {file = "safetensors-0.4.0-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:7a524382b5c55b5fbb168e0e9d3f502450c8cf3fb81b93e880018437c206a482"},
- {file = "safetensors-0.4.0-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:9849ea60c7e840bfdd6030ad454d4a6ba837b3398c902f15a30460dd6961c28c"},
- {file = "safetensors-0.4.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:6c42623ae7045615d9eaa6877b9df1db4e9cc71ecc14bcc721ea1e475dddd595"},
- {file = "safetensors-0.4.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:80cb8342f00f3c41b3b93b1a599b84723280d3ac90829bc62262efc03ab28793"},
- {file = "safetensors-0.4.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8c4f5ed4ede384dea8c99bae76b0718a828dbf7b2c8ced1f44e3b9b1a124475"},
- {file = "safetensors-0.4.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:40d7cf03493bfe75ef62e2c716314474b28d9ba5bf4909763e4b8dd14330c01a"},
- {file = "safetensors-0.4.0-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:232029f0a9fa6fa1f737324eda98a700409811186888536a2333cbbf64e41741"},
- {file = "safetensors-0.4.0-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:9ed55f4a20c78ff3e8477efb63c8303c2152cdfb3bfea4d025a80f54d38fd628"},
- {file = "safetensors-0.4.0.tar.gz", hash = "sha256:b985953c3cf11e942eac4317ef3db3da713e274109cf7cfb6076d877054f013e"},
+ {file = "safetensors-0.4.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:69d8bb8384dc2cb5b72c36c4d6980771b293d1a1377b378763f5e37b6bb8d133"},
+ {file = "safetensors-0.4.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3d420e19fcef96d0067f4de4699682b4bbd85fc8fea0bd45fcd961fdf3e8c82c"},
+ {file = "safetensors-0.4.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9ca54742122fa3c4821754adb67318e1cd25c3a22bbf0c5520d5176e77a099ac"},
+ {file = "safetensors-0.4.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8b47aa643afdfd66cf7ce4c184092ae734e15d10aba2c2948f24270211801c3c"},
+ {file = "safetensors-0.4.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d88a16bbc330f27e7f2d4caaf6fb061ad0b8a756ecc4033260b0378e128ce8a2"},
+ {file = "safetensors-0.4.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e9223b8ac21085db614a510eb3445e7083cae915a9202357555fa939695d4f57"},
+ {file = "safetensors-0.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce6cb86133dc8930a7ab5e7438545a7f205f7a1cdd5aaf108c1d0da6bdcfbc2b"},
+ {file = "safetensors-0.4.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b8a628e0ae2bbc334b62952c384aa5f41621d01850f8d67b04a96b9c39dd7326"},
+ {file = "safetensors-0.4.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:88d6beb7f811a081e0e5f1d9669fdac816c45340c04b1eaf7ebfda0ce93ea403"},
+ {file = "safetensors-0.4.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b57fc5b1b54cb12d8690a58a4cf4b7144730d4bde9d98aa0e1dab6295a1cd579"},
+ {file = "safetensors-0.4.2-cp310-none-win32.whl", hash = "sha256:9d87a1c98803c16cf113b9ba03f07b2dce5e8eabfd1811a7f7323fcaa2a1bf47"},
+ {file = "safetensors-0.4.2-cp310-none-win_amd64.whl", hash = "sha256:18930ec1d1ecb526d3d9835abc2489b8f1530877518f0c541e77ef0b7abcbd99"},
+ {file = "safetensors-0.4.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:c5dd2ed788730ed56b415d1a11c62026b8cc8c573f55a2092afb3ab383e94fff"},
+ {file = "safetensors-0.4.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:cc41791b33efb9c83a59b731619f3d15f543dfe71f3a793cb8fbf9bd5d0d5d71"},
+ {file = "safetensors-0.4.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4c888bf71d5ca12a720f1ed87d407c4918afa022fb247a6546d8fac15b1f112b"},
+ {file = "safetensors-0.4.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e6b2feb4b47226a16a792e6fac3f49442714884a3d4c1008569d5068a3941be9"},
+ {file = "safetensors-0.4.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f41cc0ee4b838ae8f4d8364a1b162067693d11a3893f0863be8c228d40e4d0ee"},
+ {file = "safetensors-0.4.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:51b7228e46c0a483c40ba4b9470dea00fb1ff8685026bb4766799000f6328ac2"},
+ {file = "safetensors-0.4.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02697f8f2be8ca3c37a4958702dbdb1864447ef765e18b5328a1617022dcf164"},
+ {file = "safetensors-0.4.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:27fd8f65cf7c80e4280cae1ee6bcd85c483882f6580821abe71ee1a0d3dcfca7"},
+ {file = "safetensors-0.4.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c487b5f113b0924c9534a07dc034830fb4ef05ce9bb6d78cfe016a7dedfe281f"},
+ {file = "safetensors-0.4.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:da7f6483f3fe67ff39b3a55552552c67930ea10a36e9f2539d36fc205273d767"},
+ {file = "safetensors-0.4.2-cp311-none-win32.whl", hash = "sha256:52a7012f6cb9cb4a132760b6308daede18a9f5f8952ce08adc7c67a7d865c2d8"},
+ {file = "safetensors-0.4.2-cp311-none-win_amd64.whl", hash = "sha256:4d1361a097ac430b310ce9eed8ed4746edee33ddafdfbb965debc8966fc34dc2"},
+ {file = "safetensors-0.4.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:77af8aa0edcc2863760fd6febbfdb82e88fd75d0e60c1ce4ba57208ba5e4a89b"},
+ {file = "safetensors-0.4.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:846666c1c5a8c8888d2dfda8d3921cb9cb8e2c5f78365be756c11021e75a0a2a"},
+ {file = "safetensors-0.4.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f4bfc7ea19b446bfad41510d4b4c76101698c00caaa8a332c8edd8090a412ef"},
+ {file = "safetensors-0.4.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:233436fd30f27ffeb3c3780d0b84f496518868445c7a8db003639a649cc98453"},
+ {file = "safetensors-0.4.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7a09237a795d11cd11f9dae505d170a29b5616151db1e10c14f892b11caadc7d"},
+ {file = "safetensors-0.4.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:de01c9a3a3b7b69627d624ff69d9f11d28ce9908eea2fb6245adafa4b1d43df6"},
+ {file = "safetensors-0.4.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c1f25c5069ee42a5bcffdc66c300a407941edd73f3239e9fdefd26216407391"},
+ {file = "safetensors-0.4.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7a73b3649456d09ca8506140d44484b63154a7378434cc1e8719f8056550b224"},
+ {file = "safetensors-0.4.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e1625a8d07d046e968bd5c4961810aba1225984e4fb9243626f9d04a06ed3fee"},
+ {file = "safetensors-0.4.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f74c86b25615cb24ad4cff765a2eefc09d71bf0fed97588cf585aad9c38fbb4"},
+ {file = "safetensors-0.4.2-cp312-none-win32.whl", hash = "sha256:8523b9c5777d771bcde5c2389c03f1cdf7ebe8797432a1bd5e345efe25c55987"},
+ {file = "safetensors-0.4.2-cp312-none-win_amd64.whl", hash = "sha256:dcff0243e1737a21f83d664c63fed89d1f532c23fc6830d0427279fabd789ccb"},
+ {file = "safetensors-0.4.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:96ad3d7d472612e26cbe413922b4fb13933310f0511d346ea5cc9a1e856e52eb"},
+ {file = "safetensors-0.4.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:88250922401b5ae4e37de929178caf46be47ed16c817b2237b81679bec07c120"},
+ {file = "safetensors-0.4.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d40443554142fc0ab30652d5cc8554c4b7a613513bde00373e18afd5de8cbe4b"},
+ {file = "safetensors-0.4.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:27f53f70106224d32d874aacecbeb4a6e4c5b16a1d2006d0e876d97229086d71"},
+ {file = "safetensors-0.4.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cc068afe23734dfb26ce19db0a7877499ddf73b1d55ceb762417e8da4a1b05fb"},
+ {file = "safetensors-0.4.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9be1918eb8d43a11a6f8806759fccfa0eeb0542b12924caba66af8a7800ad01a"},
+ {file = "safetensors-0.4.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41911087d20a7bbd78cb4ad4f98aab0c431533107584df6635d8b54b99945573"},
+ {file = "safetensors-0.4.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:50771c662aab909f31e94d048e76861fd027d66076ea773eef2e66c717766e24"},
+ {file = "safetensors-0.4.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:13f2e57be007b7ea9329133d2399e6bdfcf1910f655440a4da17df3a45afcd30"},
+ {file = "safetensors-0.4.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:c772147e6395bc829842e0a98e1b30c67fe25d816299c28196488511d5a5e951"},
+ {file = "safetensors-0.4.2-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:36239a0060b537a3e8c473df78cffee14c3ec4f51d5f1a853af99371a2fb2a35"},
+ {file = "safetensors-0.4.2-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:d0cbb7664fad2c307f95195f951b7059e95dc23e0e1822e5978c8b500098543c"},
+ {file = "safetensors-0.4.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2b3e55adb6bd9dc1c2a341e72f48f075953fa35d173dd8e29a95b3b02d0d1462"},
+ {file = "safetensors-0.4.2-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:42f743b3cca863fba53ca57a193f510e5ec359b97f38c282437716b6768e4a25"},
+ {file = "safetensors-0.4.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:04e6af4a6dbeb06c4e6e7d46cf9c716cbc4cc5ef62584fd8a7c0fe558562df45"},
+ {file = "safetensors-0.4.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a492ba21b5c8f14ee5ec9b20f42ba969e53ca1f909a4d04aad736b66a341dcc2"},
+ {file = "safetensors-0.4.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b25b8233a1a85dc67e39838951cfb01595d792f3b7b644add63edb652992e030"},
+ {file = "safetensors-0.4.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:fd27e063fbdafe776f7b1714da59110e88f270e86db00788a8fd65f4eacfeba7"},
+ {file = "safetensors-0.4.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:1b6fa399f251bbeb52029bf5a0ac2878d7705dd3612a2f8895b48e9c11f0367d"},
+ {file = "safetensors-0.4.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:de642d46b459e4afd5c2020b26c0d6d869a171ea00411897d5776c127cac74f0"},
+ {file = "safetensors-0.4.2-cp37-none-win32.whl", hash = "sha256:77b72d17754c93bb68f3598182f14d78776e0b9b31682ca5bb2c7c5bd9a75267"},
+ {file = "safetensors-0.4.2-cp37-none-win_amd64.whl", hash = "sha256:d36ee3244d461cd655aeef493792c3bccf4875282f8407fd9af99e9a41cf2530"},
+ {file = "safetensors-0.4.2-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:16b6b3884f7876c6b3b23a742428223a7170a5a9dac819d8c12a1569422c4b5a"},
+ {file = "safetensors-0.4.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ee25d311493fbbe0be9d395faee46e9d79e8948f461e388ff39e59875ed9a350"},
+ {file = "safetensors-0.4.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eed8097968585cd752a1171f86fce9aa1d89a29033e5cd8bec5a502e29f6b7af"},
+ {file = "safetensors-0.4.2-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:880e6865cf72cb67f9ab8d04a3c4b49dd95ae92fb1583929ce65aed94e1f685f"},
+ {file = "safetensors-0.4.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:91290f83daf80ce6d1a7f629b244443c200060a80f908b29d879021409e5ea94"},
+ {file = "safetensors-0.4.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3517d568486ab3508a7acc360b82d7a4a3e26b86efdf210a9ecd9d233c40708a"},
+ {file = "safetensors-0.4.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1f43a77eb38540f782999e5dc5645164fe9027d3f0194f6c9a5126168017efa"},
+ {file = "safetensors-0.4.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b684d9818aa5d63fddc65f7d0151968037d255d91adf74eba82125b41c680aaa"},
+ {file = "safetensors-0.4.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:ab1f5d84185f9fefaf21413efb764e4908057b8a9a0b987ede890c353490fd70"},
+ {file = "safetensors-0.4.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:2bd979642e6c3a517ef4b84ff36c2fee4015664fea05a61154fc565978347553"},
+ {file = "safetensors-0.4.2-cp38-none-win32.whl", hash = "sha256:11be6e7afed29e5a5628f0aa6214e34bc194da73f558dc69fc7d56e07037422a"},
+ {file = "safetensors-0.4.2-cp38-none-win_amd64.whl", hash = "sha256:2f7a6e5d29bd2cc340cffaa391fa437b1be9d21a2bd8b8724d2875d13a6ef2a9"},
+ {file = "safetensors-0.4.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:a5a921b4fe6925f9942adff3ebae8c16e0487908c54586a5a42f35b59fd69794"},
+ {file = "safetensors-0.4.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b691727228c28f2d82d8a92b2bc26e7a1f129ee40b2f2a3185b5974e038ed47c"},
+ {file = "safetensors-0.4.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91ca1056decc4e981248786e87b2a202d4841ee5f99d433f1adf3d44d4bcfa0e"},
+ {file = "safetensors-0.4.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:55969fd2e6fdb38dc221b0ab380668c21b0efa12a7562db9924759faa3c51757"},
+ {file = "safetensors-0.4.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6ae429bfaecc10ab5fe78c93009b3d1656c1581da560041e700eadb497dbe7a4"},
+ {file = "safetensors-0.4.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4ff88f194fe4ac50b463a4a6f0c03af9ad72eb5d24ec6d6730af59522e37fedb"},
+ {file = "safetensors-0.4.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a80cb48d0a447f8dd18e61813efa7d3f8f8d52edf0f05806abc0c59b83431f57"},
+ {file = "safetensors-0.4.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b286fb7adfee70a4189898ac2342b8a67d5f493e6b21b0af89ca8eac1b967cbf"},
+ {file = "safetensors-0.4.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0ceeff9ddbab4f78738489eb6682867ae946178776f33699737b2129b5394dc1"},
+ {file = "safetensors-0.4.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a26fae748a7488cb3aac381eddfa818c42052c87b5e689fb4c6e82ed58cec209"},
+ {file = "safetensors-0.4.2-cp39-none-win32.whl", hash = "sha256:039a42ab33c9d68b39706fd38f1922ace26866eff246bf20271edb619f5f848b"},
+ {file = "safetensors-0.4.2-cp39-none-win_amd64.whl", hash = "sha256:b3a3e1f5b85859e398773f064943b62a4059f225008a2a8ee6add1edcf77cacf"},
+ {file = "safetensors-0.4.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:4e70d442ad17e8b153ef9095bf48ea64f15a66bf26dc2b6ca94660c154edbc24"},
+ {file = "safetensors-0.4.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:b90f1d9809caf4ff395951b4703295a68d12907f6945bbc3129e934ff8ae46f6"},
+ {file = "safetensors-0.4.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c7ac9ad3728838006598e296b3ae9f27d80b489effd4685b92d97b3fc4c98f6"},
+ {file = "safetensors-0.4.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de5730d77e6ff7f4c7039e20913661ad0ea2f86c09e71c039e73dfdd1f394f08"},
+ {file = "safetensors-0.4.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:44feb8cb156d6803dcd19fc6b81b27235f29b877660605a6ac35e1da7d64f0e4"},
+ {file = "safetensors-0.4.2-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:523a241c33e7c827ab9a3a23760d75c7d062f43dfe55b6b019409f89b0fb52d1"},
+ {file = "safetensors-0.4.2-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:fb18300e8eb74291225214f26c9a8ae2110fd61a6c9b5a2ff4c4e0eb1bb9a998"},
+ {file = "safetensors-0.4.2-pp37-pypy37_pp73-macosx_10_12_x86_64.whl", hash = "sha256:fe5437ff9fb116e44f2ab558981249ae63f978392b4576e62fcfe167d353edbc"},
+ {file = "safetensors-0.4.2-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9304a0934ced5a5d272f39de36291dc141dfc152d277f03fb4d65f2fb2ffa7c"},
+ {file = "safetensors-0.4.2-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:160ba1b1e11cf874602c233ab80a14f588571d09556cbc3586900121d622b5ed"},
+ {file = "safetensors-0.4.2-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:04fcd6fcf7d9c13c7e5dc7e08de5e492ee4daa8f4ad74b4d8299d3eb0224292f"},
+ {file = "safetensors-0.4.2-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:906d14c4a677d35834fb0f3a5455ef8305e1bba10a5e0f2e0f357b3d1ad989f2"},
+ {file = "safetensors-0.4.2-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:df3fcdec0cd543084610d1f09c65cdb10fb3079f79bceddc092b0d187c6a265b"},
+ {file = "safetensors-0.4.2-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5ca76f13fb1cef242ea3ad2cb37388e7d005994f42af8b44bee56ba48b2d45ce"},
+ {file = "safetensors-0.4.2-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:278a1a3414c020785decdcd741c578725721274d2f9f787fcc930882e83b89cc"},
+ {file = "safetensors-0.4.2-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05b5a461cc68ecd42d9d546e5e1268a39d8ede7934a68d1ce17c3c659cb829d6"},
+ {file = "safetensors-0.4.2-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2341411412a41671d25e26bed59ec121e46bf4fadb8132895e610411c4b9681"},
+ {file = "safetensors-0.4.2-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:3497ac3895acf17c5f98197f1fa4769f09c5e7ede07fcb102f1c201e663e052c"},
+ {file = "safetensors-0.4.2-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:01b5e71d3754d2201294f1eb7a6d59cce3a5702ff96d83d226571b2ca2183837"},
+ {file = "safetensors-0.4.2-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:3627dbd1ea488dd8046a0491de5087f3c0d641e7acc80c0189a33c69398f1cd1"},
+ {file = "safetensors-0.4.2-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:9d56f0ef53afad26ec54ceede78a43e9a23a076dadbbda7b44d304c591abf4c1"},
+ {file = "safetensors-0.4.2-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:b259ca73d42daf658a1bda463f1f83885ae4d93a60869be80d7f7dfcc9d8bbb5"},
+ {file = "safetensors-0.4.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1ebc3cd401e4eb54e7c0a70346be565e81942d9a41fafd5f4bf7ab3a55d10378"},
+ {file = "safetensors-0.4.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5bc384a0309b706aa0425c93abb0390508a61bf029ce99c7d9df4220f25871a5"},
+ {file = "safetensors-0.4.2-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:af2d8f7235d8a08fbccfb8394387890e7fa38942b349a94e6eff13c52ac98087"},
+ {file = "safetensors-0.4.2-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:0911315bbcc5289087d063c2c2c7ccd711ea97a7e557a7bce005ac2cf80146aa"},
+ {file = "safetensors-0.4.2-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:1efe31673be91832d73439a2af426743e1395fc9ef7b081914e9e1d567bd7b5f"},
+ {file = "safetensors-0.4.2.tar.gz", hash = "sha256:acc85dcb09ec5e8aa787f588d7ad4d55c103f31e4ff060e17d92cc0e8b8cac73"},
]
[package.extras]
all = ["safetensors[jax]", "safetensors[numpy]", "safetensors[paddlepaddle]", "safetensors[pinned-tf]", "safetensors[quality]", "safetensors[testing]", "safetensors[torch]"]
dev = ["safetensors[all]"]
jax = ["flax (>=0.6.3)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)", "safetensors[numpy]"]
+mlx = ["mlx (>=0.0.9)"]
numpy = ["numpy (>=1.21.6)"]
paddlepaddle = ["paddlepaddle (>=2.4.1)", "safetensors[numpy]"]
pinned-tf = ["safetensors[numpy]", "tensorflow (==2.11.0)"]
@@ -3814,15 +3949,77 @@ nativelib = ["pyobjc-framework-Cocoa", "pywin32"]
objc = ["pyobjc-framework-Cocoa"]
win32 = ["pywin32"]
+[[package]]
+name = "sentencepiece"
+version = "0.2.0"
+description = "SentencePiece python wrapper"
+optional = false
+python-versions = "*"
+files = [
+ {file = "sentencepiece-0.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:188779e1298a1c8b8253c7d3ad729cb0a9891e5cef5e5d07ce4592c54869e227"},
+ {file = "sentencepiece-0.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bed9cf85b296fa2b76fc2547b9cbb691a523864cebaee86304c43a7b4cb1b452"},
+ {file = "sentencepiece-0.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d7b67e724bead13f18db6e1d10b6bbdc454af574d70efbb36f27d90387be1ca3"},
+ {file = "sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2fde4b08cfe237be4484c6c7c2e2c75fb862cfeab6bd5449ce4caeafd97b767a"},
+ {file = "sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4c378492056202d1c48a4979650981635fd97875a00eabb1f00c6a236b013b5e"},
+ {file = "sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1380ce6540a368de2ef6d7e6ba14ba8f3258df650d39ba7d833b79ee68a52040"},
+ {file = "sentencepiece-0.2.0-cp310-cp310-win32.whl", hash = "sha256:a1151d6a6dd4b43e552394aed0edfe9292820272f0194bd56c7c1660a0c06c3d"},
+ {file = "sentencepiece-0.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:d490142b0521ef22bc1085f061d922a2a6666175bb6b42e588ff95c0db6819b2"},
+ {file = "sentencepiece-0.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:17982700c4f6dbb55fa3594f3d7e5dd1c8659a274af3738e33c987d2a27c9d5c"},
+ {file = "sentencepiece-0.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7c867012c0e8bcd5bdad0f791609101cb5c66acb303ab3270218d6debc68a65e"},
+ {file = "sentencepiece-0.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7fd6071249c74f779c5b27183295b9202f8dedb68034e716784364443879eaa6"},
+ {file = "sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:27f90c55a65013cbb8f4d7aab0599bf925cde4adc67ae43a0d323677b5a1c6cb"},
+ {file = "sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b293734059ef656dcd65be62ff771507bea8fed0a711b6733976e1ed3add4553"},
+ {file = "sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e58b47f933aca74c6a60a79dcb21d5b9e47416256c795c2d58d55cec27f9551d"},
+ {file = "sentencepiece-0.2.0-cp311-cp311-win32.whl", hash = "sha256:c581258cf346b327c62c4f1cebd32691826306f6a41d8c4bec43b010dee08e75"},
+ {file = "sentencepiece-0.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:0993dbc665f4113017892f1b87c3904a44d0640eda510abcacdfb07f74286d36"},
+ {file = "sentencepiece-0.2.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ea5f536e32ea8ec96086ee00d7a4a131ce583a1b18d130711707c10e69601cb2"},
+ {file = "sentencepiece-0.2.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d0cb51f53b6aae3c36bafe41e86167c71af8370a039f542c43b0cce5ef24a68c"},
+ {file = "sentencepiece-0.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3212121805afc58d8b00ab4e7dd1f8f76c203ddb9dc94aa4079618a31cf5da0f"},
+ {file = "sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2a3149e3066c2a75e0d68a43eb632d7ae728c7925b517f4c05c40f6f7280ce08"},
+ {file = "sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:632f3594d3e7ac8b367bca204cb3fd05a01d5b21455acd097ea4c0e30e2f63d7"},
+ {file = "sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f295105c6bdbb05bd5e1b0cafbd78ff95036f5d3641e7949455a3f4e5e7c3109"},
+ {file = "sentencepiece-0.2.0-cp312-cp312-win32.whl", hash = "sha256:fb89f811e5efd18bab141afc3fea3de141c3f69f3fe9e898f710ae7fe3aab251"},
+ {file = "sentencepiece-0.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:7a673a72aab81fef5ebe755c6e0cc60087d1f3a4700835d40537183c1703a45f"},
+ {file = "sentencepiece-0.2.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:4547683f330289ec4f093027bfeb87f9ef023b2eb6f879fdc4a8187c7e0ffb90"},
+ {file = "sentencepiece-0.2.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7cd6175f7eaec7142d2bf6f6597ce7db4c9ac89acf93fcdb17410c3a8b781eeb"},
+ {file = "sentencepiece-0.2.0-cp36-cp36m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:859ba1acde782609a0910a26a60e16c191a82bf39b5621107552c0cd79fad00f"},
+ {file = "sentencepiece-0.2.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcbbef6cc277f8f18f36959e305f10b1c620442d75addc79c21d7073ae581b50"},
+ {file = "sentencepiece-0.2.0-cp36-cp36m-win32.whl", hash = "sha256:536b934e244829e3fe6c4f198652cd82da48adb9aa145c9f00889542726dee3d"},
+ {file = "sentencepiece-0.2.0-cp36-cp36m-win_amd64.whl", hash = "sha256:0a91aaa3c769b52440df56fafda683b3aa48e3f2169cf7ee5b8c8454a7f3ae9b"},
+ {file = "sentencepiece-0.2.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:787e480ca4c1d08c9985a7eb1eae4345c107729c99e9b5a9a00f2575fc7d4b4b"},
+ {file = "sentencepiece-0.2.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4d158189eb2ecffea3a51edf6d25e110b3678ec47f1a40f2d541eafbd8f6250"},
+ {file = "sentencepiece-0.2.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d1e5ca43013e8935f25457a4fca47e315780172c3e821b4b13a890668911c792"},
+ {file = "sentencepiece-0.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7140d9e5a74a0908493bb4a13f1f16a401297bd755ada4c707e842fbf6f0f5bf"},
+ {file = "sentencepiece-0.2.0-cp37-cp37m-win32.whl", hash = "sha256:6cf333625234f247ab357b0bd9836638405ea9082e1543d5b8408f014979dcbf"},
+ {file = "sentencepiece-0.2.0-cp37-cp37m-win_amd64.whl", hash = "sha256:ff88712338b01031910e8e61e7239aff3ce8869ee31a47df63cb38aadd591bea"},
+ {file = "sentencepiece-0.2.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:20813a68d4c221b1849c62c30e1281ea81687894d894b8d4a0f4677d9311e0f5"},
+ {file = "sentencepiece-0.2.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:926ef920ae2e8182db31d3f5d081ada57804e3e1d3a8c4ef8b117f9d9fb5a945"},
+ {file = "sentencepiece-0.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:89f65f69636b7e9c015b79dff9c9985a9bc7d19ded6f79ef9f1ec920fdd73ecf"},
+ {file = "sentencepiece-0.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f67eae0dbe6f2d7d6ba50a354623d787c99965f068b81e145d53240198021b0"},
+ {file = "sentencepiece-0.2.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:98501e075f35dd1a1d5a20f65be26839fcb1938752ec61539af008a5aa6f510b"},
+ {file = "sentencepiece-0.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3d1d2cc4882e8d6a1adf9d5927d7716f80617fc693385661caff21888972269"},
+ {file = "sentencepiece-0.2.0-cp38-cp38-win32.whl", hash = "sha256:b99a308a2e5e569031ab164b74e6fab0b6f37dfb493c32f7816225f4d411a6dd"},
+ {file = "sentencepiece-0.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:cdb701eec783d3ec86b7cd4c763adad8eaf6b46db37ee1c36e5e6c44b3fe1b5f"},
+ {file = "sentencepiece-0.2.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:1e0f9c4d0a6b0af59b613175f019916e28ade076e21242fd5be24340d8a2f64a"},
+ {file = "sentencepiece-0.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:298f21cc1366eb60311aedba3169d30f885c363ddbf44214b0a587d2908141ad"},
+ {file = "sentencepiece-0.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3f1ec95aa1e5dab11f37ac7eff190493fd87770f7a8b81ebc9dd768d1a3c8704"},
+ {file = "sentencepiece-0.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b06b70af54daa4b4904cbb90b4eb6d35c9f3252fdc86c9c32d5afd4d30118d8"},
+ {file = "sentencepiece-0.2.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:22e37bac44dd6603388cb598c64ff7a76e41ca774646f21c23aadfbf5a2228ab"},
+ {file = "sentencepiece-0.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0461324897735512a32d222e3d886e24ad6a499761952b6bda2a9ee6e4313ea5"},
+ {file = "sentencepiece-0.2.0-cp39-cp39-win32.whl", hash = "sha256:38aed822fb76435fa1f12185f10465a94ab9e51d5e8a9159e9a540ce926f0ffd"},
+ {file = "sentencepiece-0.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:d8cf876516548b5a1d6ac4745d8b554f5c07891d55da557925e5c13ff0b4e6ad"},
+ {file = "sentencepiece-0.2.0.tar.gz", hash = "sha256:a52c19171daaf2e697dc6cbe67684e0fa341b1248966f6aebb541de654d15843"},
+]
+
[[package]]
name = "sentry-sdk"
-version = "1.32.0"
+version = "1.44.1"
description = "Python client for Sentry (https://sentry.io)"
optional = false
python-versions = "*"
files = [
- {file = "sentry-sdk-1.32.0.tar.gz", hash = "sha256:935e8fbd7787a3702457393b74b13d89a5afb67185bc0af85c00cb27cbd42e7c"},
- {file = "sentry_sdk-1.32.0-py2.py3-none-any.whl", hash = "sha256:eeb0b3550536f3bbc05bb1c7e0feb3a78d74acb43b607159a606ed2ec0a33a4d"},
+ {file = "sentry-sdk-1.44.1.tar.gz", hash = "sha256:24e6a53eeabffd2f95d952aa35ca52f0f4201d17f820ac9d3ff7244c665aaf68"},
+ {file = "sentry_sdk-1.44.1-py2.py3-none-any.whl", hash = "sha256:5f75eb91d8ab6037c754a87b8501cc581b2827e923682f593bed3539ce5b3999"},
]
[package.dependencies]
@@ -3836,6 +4033,7 @@ asyncpg = ["asyncpg (>=0.23)"]
beam = ["apache-beam (>=2.12)"]
bottle = ["bottle (>=0.12.13)"]
celery = ["celery (>=3)"]
+celery-redbeat = ["celery-redbeat (>=2)"]
chalice = ["chalice (>=1.16.0)"]
clickhouse-driver = ["clickhouse-driver (>=0.2.0)"]
django = ["django (>=1.8)"]
@@ -3846,6 +4044,7 @@ grpcio = ["grpcio (>=1.21.1)"]
httpx = ["httpx (>=0.16.0)"]
huey = ["huey (>=2)"]
loguru = ["loguru (>=0.5)"]
+openai = ["openai (>=1.0.0)", "tiktoken (>=0.3.0)"]
opentelemetry = ["opentelemetry-distro (>=0.35b0)"]
opentelemetry-experimental = ["opentelemetry-distro (>=0.40b0,<1.0)", "opentelemetry-instrumentation-aiohttp-client (>=0.40b0,<1.0)", "opentelemetry-instrumentation-django (>=0.40b0,<1.0)", "opentelemetry-instrumentation-fastapi (>=0.40b0,<1.0)", "opentelemetry-instrumentation-flask (>=0.40b0,<1.0)", "opentelemetry-instrumentation-requests (>=0.40b0,<1.0)", "opentelemetry-instrumentation-sqlite3 (>=0.40b0,<1.0)", "opentelemetry-instrumentation-urllib (>=0.40b0,<1.0)"]
pure-eval = ["asttokens", "executing", "pure-eval"]
@@ -3961,19 +4160,30 @@ test = ["pytest"]
[[package]]
name = "setuptools"
-version = "68.2.2"
+version = "69.2.0"
description = "Easily download, build, install, upgrade, and uninstall Python packages"
optional = false
python-versions = ">=3.8"
files = [
- {file = "setuptools-68.2.2-py3-none-any.whl", hash = "sha256:b454a35605876da60632df1a60f736524eb73cc47bbc9f3f1ef1b644de74fd2a"},
- {file = "setuptools-68.2.2.tar.gz", hash = "sha256:4ac1475276d2f1c48684874089fefcd83bd7162ddaafb81fac866ba0db282a87"},
+ {file = "setuptools-69.2.0-py3-none-any.whl", hash = "sha256:c21c49fb1042386df081cb5d86759792ab89efca84cf114889191cd09aacc80c"},
+ {file = "setuptools-69.2.0.tar.gz", hash = "sha256:0ff4183f8f42cd8fa3acea16c45205521a4ef28f73c6391d8a25e92893134f2e"},
]
[package.extras]
-docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"]
-testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
-testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.1)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"]
+docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"]
+testing = ["build[virtualenv]", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
+testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.2)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"]
+
+[[package]]
+name = "shellingham"
+version = "1.5.4"
+description = "Tool to Detect Surrounding Shell"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686"},
+ {file = "shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de"},
+]
[[package]]
name = "six"
@@ -3999,13 +4209,13 @@ files = [
[[package]]
name = "sniffio"
-version = "1.3.0"
+version = "1.3.1"
description = "Sniff out which async library your code is running under"
optional = false
python-versions = ">=3.7"
files = [
- {file = "sniffio-1.3.0-py3-none-any.whl", hash = "sha256:eecefdce1e5bbfb7ad2eeaabf7c1eeb404d7757c379bd1f7e5cce9d8bf425384"},
- {file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"},
+ {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"},
+ {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"},
]
[[package]]
@@ -4268,13 +4478,13 @@ doc = ["reno", "sphinx", "tornado (>=4.5)"]
[[package]]
name = "terminado"
-version = "0.17.1"
+version = "0.18.1"
description = "Tornado websocket backend for the Xterm.js Javascript terminal emulator library."
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
files = [
- {file = "terminado-0.17.1-py3-none-any.whl", hash = "sha256:8650d44334eba354dd591129ca3124a6ba42c3d5b70df5051b6921d506fdaeae"},
- {file = "terminado-0.17.1.tar.gz", hash = "sha256:6ccbbcd3a4f8a25a5ec04991f39a0b8db52dfcd487ea0e578d977e6752380333"},
+ {file = "terminado-0.18.1-py3-none-any.whl", hash = "sha256:a4468e1b37bb318f8a86514f65814e1afc977cf29b3992a4500d9dd305dcceb0"},
+ {file = "terminado-0.18.1.tar.gz", hash = "sha256:de09f2c4b85de4765f7714688fff57d3e75bad1f909b589fde880460c753fd2e"},
]
[package.dependencies]
@@ -4285,6 +4495,7 @@ tornado = ">=6.1.0"
[package.extras]
docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"]
test = ["pre-commit", "pytest (>=7.0)", "pytest-timeout"]
+typing = ["mypy (>=1.6,<2.0)", "traitlets (>=5.11.1)"]
[[package]]
name = "tinycss2"
@@ -4306,113 +4517,125 @@ test = ["flake8", "isort", "pytest"]
[[package]]
name = "tokenizers"
-version = "0.14.1"
+version = "0.15.2"
description = ""
optional = false
python-versions = ">=3.7"
files = [
- {file = "tokenizers-0.14.1-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:04ec1134a18ede355a05641cdc7700f17280e01f69f2f315769f02f7e295cf1e"},
- {file = "tokenizers-0.14.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:638abedb39375f0ddce2de536fc9c976639b2d1b7202d715c2e7a25f0ebfd091"},
- {file = "tokenizers-0.14.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:901635098565773a44f74068639d265f19deaaca47ea77b428fd9bee13a61d87"},
- {file = "tokenizers-0.14.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:72e95184bf5b9a4c08153ed07c16c130ff174835c9a1e6ee2b311be758c8b3ef"},
- {file = "tokenizers-0.14.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ebefbc26ccff5e96ae7d40772172e7310174f9aa3683d2870a1882313ec3a4d5"},
- {file = "tokenizers-0.14.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d3a6330c9f1deda22873e8b4ac849cc06d3ff33d60b3217ac0bb397b541e1509"},
- {file = "tokenizers-0.14.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6cba7483ba45600346a35c466bde32327b108575022f73c35a0f7170b5a71ae2"},
- {file = "tokenizers-0.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:60fec380778d75cbb492f14ca974f11f37b41d53c057b9c8ba213315b86e1f84"},
- {file = "tokenizers-0.14.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:930c19b699dd7e1077eac98967adc2fe5f0b104bd96cc1f26778ab82b31ceb24"},
- {file = "tokenizers-0.14.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a1e30a13376db5329570e09b14c8eb36c017909ed7e88591ca3aa81f3c7d6f32"},
- {file = "tokenizers-0.14.1-cp310-none-win32.whl", hash = "sha256:370b5b86da9bddbe65fa08711f0e8ffdf8b0036558178d1a31dfcb44efcde72a"},
- {file = "tokenizers-0.14.1-cp310-none-win_amd64.whl", hash = "sha256:c2c659f2106b6d154f118ad1b700e68148c46c59b720f04867b1fc5f26a85060"},
- {file = "tokenizers-0.14.1-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:00df4c5bf25c153b432b98689609b426ae701a44f3d8074dcb619f410bc2a870"},
- {file = "tokenizers-0.14.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fee553657dcdb7e73df8823c49e8611457ba46e9d7026b7e9c44820c08c327c3"},
- {file = "tokenizers-0.14.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:a480bd902e327dfcaa52b7dd14fdc71e7aa45d73a3d6e41e028a75891d2823cf"},
- {file = "tokenizers-0.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e448b2be0430ab839cf7954715c39d6f34ff6cf2b49393f336283b7a59f485af"},
- {file = "tokenizers-0.14.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c11444984aecd342f0cf160c3320288edeb1763871fbb560ed466654b2a7016c"},
- {file = "tokenizers-0.14.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bfe164a1c72c6be3c5c26753c6c412f81412f4dae0d7d06371e0b396a9cc0fc9"},
- {file = "tokenizers-0.14.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:72d9967fb1f927542cfb5347207fde01b29f25c9bb8cbc7ced280decfa015983"},
- {file = "tokenizers-0.14.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37cc955c84ec67c2d11183d372044399342b20a1fa447b7a33040f4889bba318"},
- {file = "tokenizers-0.14.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:db96cf092d86d4cb543daa9148e299011e0a40770380bb78333b9fd700586fcb"},
- {file = "tokenizers-0.14.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c84d3cb1349936c2b96ca6175b50f5a9518170bffd76464219ee0ea6022a64a7"},
- {file = "tokenizers-0.14.1-cp311-none-win32.whl", hash = "sha256:8db3a6f3d430ac3dc3793c53fa8e5e665c23ba359484d365a191027ad8b65a30"},
- {file = "tokenizers-0.14.1-cp311-none-win_amd64.whl", hash = "sha256:c65d76052561c60e17cb4fa289885ed00a9995d59e97019fac2138bd45142057"},
- {file = "tokenizers-0.14.1-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:c375161b588982be381c43eb7158c250f430793d0f708ce379a0f196164c6778"},
- {file = "tokenizers-0.14.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:50f03d2330a153a9114c2429061137bd323736059f384de8348d7cb1ca1baa15"},
- {file = "tokenizers-0.14.1-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:0c8ee283b249c3c3c201c41bc23adc3be2514ae4121eacdb5c5250a461eaa8c6"},
- {file = "tokenizers-0.14.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e9f27399b8d50c5d3f08f0aae961bcc66a1dead1cd0ae9401e4c2a43a623322a"},
- {file = "tokenizers-0.14.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:89cbeec7e9d5d8773ec4779c64e3cbcbff53d234ca6ad7b1a3736588003bba48"},
- {file = "tokenizers-0.14.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:08e55920b453c30b46d58accc68a38e8e7488d0c03babfdb29c55d3f39dd2052"},
- {file = "tokenizers-0.14.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:91d32bd1056c0e83a0f90e4ffa213c25096b2d8b9f0e2d172a45f138c7d8c081"},
- {file = "tokenizers-0.14.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44f1748035c36c939848c935715bde41734d9249ab7b844ff9bfbe984be8952c"},
- {file = "tokenizers-0.14.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:1ff516d129f01bb7a4aa95bc6aae88e4d86dd63bfc2d57db9302c2624d1be7cb"},
- {file = "tokenizers-0.14.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:acfc8db61c6e919d932448cc7985b85e330c8d745528e12fce6e62d40d268bce"},
- {file = "tokenizers-0.14.1-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:ba336bc9107acbc1da2ad30967df7b2db93448ca66538ad86aa1fbb91116f631"},
- {file = "tokenizers-0.14.1-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:f77371b5030e53f8bf92197640af437539e3bba1bc8342b97888c8e26567bfdc"},
- {file = "tokenizers-0.14.1-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:d72d25c57a9c814240802d188ff0a808b701e2dd2bf1c64721c7088ceeeb1ed7"},
- {file = "tokenizers-0.14.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:caf0df8657277e32671aa8a4d3cc05f2050ab19d9b49447f2265304168e9032c"},
- {file = "tokenizers-0.14.1-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cb3c6bc6e599e46a26ad559ad5dec260ffdf705663cc9b894033d64a69314e86"},
- {file = "tokenizers-0.14.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f8cf2fcdc2368df4317e05571e33810eeed24cd594acc9dfc9788b21dac6b3a8"},
- {file = "tokenizers-0.14.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f475d5eda41d2ed51ca775a07c80529a923dd759fcff7abf03ccdd83d9f7564e"},
- {file = "tokenizers-0.14.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cce4d1a97a7eb2253b5d3f29f4a478d8c37ba0303ea34024eb9e65506d4209f8"},
- {file = "tokenizers-0.14.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:ff66577ae55114f7d0f6aa0d4d335f27cae96bf245962a745b718ec887bbe7eb"},
- {file = "tokenizers-0.14.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:a687099e085f5162e5b88b3402adb6c2b41046180c015c5075c9504440b6e971"},
- {file = "tokenizers-0.14.1-cp37-none-win32.whl", hash = "sha256:49f5336b82e315a33bef1025d247ca08d95719715b29e33f0e9e8cf15ff1dfb6"},
- {file = "tokenizers-0.14.1-cp37-none-win_amd64.whl", hash = "sha256:117c8da60d1bd95a6df2692926f36de7971baa1d89ff702fae47b6689a4465ad"},
- {file = "tokenizers-0.14.1-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:01d2bd5935642de22a6c6778bb2307f9949cd6eaeeb5c77f9b98f0060b69f0db"},
- {file = "tokenizers-0.14.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b05ec04132394c20bd6bcb692d557a8eb8ab1bac1646d28e49c67c00907d17c8"},
- {file = "tokenizers-0.14.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:7d9025b185465d9d18679406f6f394850347d5ed2681efc203539d800f36f459"},
- {file = "tokenizers-0.14.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2539831838ab5393f78a893d7bbf27d5c36e43baf77e91dc9992922b2b97e09d"},
- {file = "tokenizers-0.14.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ec8f46d533092d8e20bc742c47918cbe24b8641dbfbbcb83177c5de3c9d4decb"},
- {file = "tokenizers-0.14.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8b019c4810903fdea3b230f358b9d27377c0f38454778b607676c9e1b57d14b7"},
- {file = "tokenizers-0.14.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e8984114fd83ed3913d89526c992395920930c9620a2feee61faf035f41d7b9a"},
- {file = "tokenizers-0.14.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11284b32f0036fe7ef4b8b00201dda79c00f3fcea173bc0e5c599e09c937ab0f"},
- {file = "tokenizers-0.14.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:53614f44f36917282a583180e402105bc63d61d1aca067d51cb7f051eb489901"},
- {file = "tokenizers-0.14.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e3b6082e9532309727273443c8943bb9558d52e36788b246aa278bda7c642116"},
- {file = "tokenizers-0.14.1-cp38-none-win32.whl", hash = "sha256:7560fca3e17a6bc876d20cd825d7721c101fa2b1cd0bfa0abf9a2e781e49b37b"},
- {file = "tokenizers-0.14.1-cp38-none-win_amd64.whl", hash = "sha256:c318a5acb429ca38f632577754235140bbb8c5a27faca1c51b43fbf575596e34"},
- {file = "tokenizers-0.14.1-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:b886e0f5c72aa4249c609c24b9610a9ca83fd963cbb5066b19302723ea505279"},
- {file = "tokenizers-0.14.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f522f28c88a0d5b2f9e895cf405dd594cd518e99d61905406aec74d30eb6383b"},
- {file = "tokenizers-0.14.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:5bef76c4d9329913cef2fe79ce1f4dab98f77fa4887e5f0420ffc9386941de32"},
- {file = "tokenizers-0.14.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:59c7df2103052b30b7c76d4fa8251326c9f82689578a912698a127dc1737f43e"},
- {file = "tokenizers-0.14.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:232445e7b85255ccfe68dfd42185db8a3f3349b34ad7068404856c4a5f67c355"},
- {file = "tokenizers-0.14.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8e63781da85aa8948864970e529af10abc4084a990d30850c41bbdb5f83eee45"},
- {file = "tokenizers-0.14.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5760a831c0f3c6d3229b50ef3fafa4c164ec99d7e8c2237fe144e67a9d33b120"},
- {file = "tokenizers-0.14.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c84b456ff8525ec3ff09762e32ccc27888d036dcd0ba2883e1db491e164dd725"},
- {file = "tokenizers-0.14.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:463ee5f3afbfec29cbf5652752c9d1032bdad63daf48bb8cb9970064cc81d5f9"},
- {file = "tokenizers-0.14.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ee6b63aecf929a7bcf885bdc8a8aec96c43bc4442f63fe8c6d48f24fc992b05b"},
- {file = "tokenizers-0.14.1-cp39-none-win32.whl", hash = "sha256:aae42798ba1da3bc1572b2048fe42e61dd6bacced2b424cb0f5572c5432f79c2"},
- {file = "tokenizers-0.14.1-cp39-none-win_amd64.whl", hash = "sha256:68c4699147dded6926a3d2c2f948d435d54d027f69909e0ef3c6587933723ed2"},
- {file = "tokenizers-0.14.1-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:5f9afdcf701a1aa3c41e0e748c152d2162434d61639a1e5d8523ecf60ae35aea"},
- {file = "tokenizers-0.14.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:6859d81243cd09854be9054aca3ecab14a2dee5b3c9f6d7ef12061d478ca0c57"},
- {file = "tokenizers-0.14.1-pp310-pypy310_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:7975178f9478ccedcf613332d5d6f37b67c74ef4e2e47e0c965597506b921f04"},
- {file = "tokenizers-0.14.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ce2f0ff2e5f12ac5bebaa690606395725239265d7ffa35f35c243a379316297"},
- {file = "tokenizers-0.14.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4c7cfc3d42e81cda802f93aa9e92caf79feaa1711426e28ce620560b8aaf5e4d"},
- {file = "tokenizers-0.14.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:67d3adff654dc7f7c7091dd259b3b847fe119c08d0bda61db91e2ea2b61c38c0"},
- {file = "tokenizers-0.14.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:956729b7dd599020e57133fb95b777e4f81ee069ff0a70e80f6eeac82658972f"},
- {file = "tokenizers-0.14.1-pp37-pypy37_pp73-macosx_10_7_x86_64.whl", hash = "sha256:fe2ea1177146a7ab345ab61e90a490eeea25d5f063e1cb9d4eb1425b169b64d7"},
- {file = "tokenizers-0.14.1-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:9930f31f603ecc6ea54d5c6dfa299f926ab3e921f72f94babcb02598c32b57c6"},
- {file = "tokenizers-0.14.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d49567a2754e9991c05c2b5a7e6650b56e24365b7cab504558e58033dcf0edc4"},
- {file = "tokenizers-0.14.1-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3678be5db330726f19c1949d8ae1b845a02eeb2a2e1d5a8bb8eaa82087ae25c1"},
- {file = "tokenizers-0.14.1-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:42b180ed1bec58ab9bdc65d406577e0c0fb7241b74b8c032846073c7743c9f86"},
- {file = "tokenizers-0.14.1-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:319e4367596fb0d52be645b3de1616faf0fadaf28507ce1c7595bebd9b4c402c"},
- {file = "tokenizers-0.14.1-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:2cda65b689aec63b7c76a77f43a08044fa90bbc6ad9849267cedfee9795913f3"},
- {file = "tokenizers-0.14.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:ca0bfc79b27d84fcb7fa09339b2ee39077896738d9a30ff99c0332376e985072"},
- {file = "tokenizers-0.14.1-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:a7093767e070269e22e2c5f845e46510304f124c32d2cd249633c0f27eb29d86"},
- {file = "tokenizers-0.14.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad759ba39cd32c2c2247864d02c84ea5883b5f6cc6a4ee0c95602a3dde52268f"},
- {file = "tokenizers-0.14.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26fee36a6d8f2bd9464f3566b95e3e3fb7fd7dad723f775c500aac8204ec98c6"},
- {file = "tokenizers-0.14.1-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:d091c62cb7abbd32e527a85c41f7c8eb4526a926251891fc4ecbe5f974142ffb"},
- {file = "tokenizers-0.14.1-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:ca304402ea66d58f99c05aa3d7a6052faea61e5a8313b94f6bc36fbf27960e2d"},
- {file = "tokenizers-0.14.1-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:102f118fa9b720b93c3217c1e239ed7bc1ae1e8dbfe9b4983a4f2d7b4ce6f2ec"},
- {file = "tokenizers-0.14.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:df4f058e96e8b467b7742e5dba7564255cd482d3c1e6cf81f8cb683bb0433340"},
- {file = "tokenizers-0.14.1-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:040ee44efc1806900de72b13c1c3036154077d9cde189c9a7e7a50bbbdcbf39f"},
- {file = "tokenizers-0.14.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7618b84118ae704f7fa23c4a190bd80fc605671841a4427d5ca14b9b8d9ec1a3"},
- {file = "tokenizers-0.14.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ecdfe9736c4a73343f629586016a137a10faed1a29c6dc699d8ab20c2d3cf64"},
- {file = "tokenizers-0.14.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:92c34de04fec7f4ff95f7667d4eb085c4e4db46c31ef44c3d35c38df128430da"},
- {file = "tokenizers-0.14.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:628b654ba555b2ba9111c0936d558b14bfc9d5f57b8c323b02fc846036b38b2f"},
- {file = "tokenizers-0.14.1.tar.gz", hash = "sha256:ea3b3f8908a9a5b9d6fc632b5f012ece7240031c44c6d4764809f33736534166"},
+ {file = "tokenizers-0.15.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:52f6130c9cbf70544287575a985bf44ae1bda2da7e8c24e97716080593638012"},
+ {file = "tokenizers-0.15.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:054c1cc9c6d68f7ffa4e810b3d5131e0ba511b6e4be34157aa08ee54c2f8d9ee"},
+ {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:a9b9b070fdad06e347563b88c278995735292ded1132f8657084989a4c84a6d5"},
+ {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ea621a7eef4b70e1f7a4e84dd989ae3f0eeb50fc8690254eacc08acb623e82f1"},
+ {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cf7fd9a5141634fa3aa8d6b7be362e6ae1b4cda60da81388fa533e0b552c98fd"},
+ {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:44f2a832cd0825295f7179eaf173381dc45230f9227ec4b44378322d900447c9"},
+ {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8b9ec69247a23747669ec4b0ca10f8e3dfb3545d550258129bd62291aabe8605"},
+ {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40b6a4c78da863ff26dbd5ad9a8ecc33d8a8d97b535172601cf00aee9d7ce9ce"},
+ {file = "tokenizers-0.15.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:5ab2a4d21dcf76af60e05af8063138849eb1d6553a0d059f6534357bce8ba364"},
+ {file = "tokenizers-0.15.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a47acfac7e511f6bbfcf2d3fb8c26979c780a91e06fb5b9a43831b2c0153d024"},
+ {file = "tokenizers-0.15.2-cp310-none-win32.whl", hash = "sha256:064ff87bb6acdbd693666de9a4b692add41308a2c0ec0770d6385737117215f2"},
+ {file = "tokenizers-0.15.2-cp310-none-win_amd64.whl", hash = "sha256:3b919afe4df7eb6ac7cafd2bd14fb507d3f408db7a68c43117f579c984a73843"},
+ {file = "tokenizers-0.15.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:89cd1cb93e4b12ff39bb2d626ad77e35209de9309a71e4d3d4672667b4b256e7"},
+ {file = "tokenizers-0.15.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:cfed5c64e5be23d7ee0f0e98081a25c2a46b0b77ce99a4f0605b1ec43dd481fa"},
+ {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:a907d76dcfda37023ba203ab4ceeb21bc5683436ebefbd895a0841fd52f6f6f2"},
+ {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:20ea60479de6fc7b8ae756b4b097572372d7e4032e2521c1bbf3d90c90a99ff0"},
+ {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:48e2b9335be2bc0171df9281385c2ed06a15f5cf121c44094338306ab7b33f2c"},
+ {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:112a1dd436d2cc06e6ffdc0b06d55ac019a35a63afd26475205cb4b1bf0bfbff"},
+ {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4620cca5c2817177ee8706f860364cc3a8845bc1e291aaf661fb899e5d1c45b0"},
+ {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ccd73a82751c523b3fc31ff8194702e4af4db21dc20e55b30ecc2079c5d43cb7"},
+ {file = "tokenizers-0.15.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:107089f135b4ae7817affe6264f8c7a5c5b4fd9a90f9439ed495f54fcea56fb4"},
+ {file = "tokenizers-0.15.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0ff110ecc57b7aa4a594396525a3451ad70988e517237fe91c540997c4e50e29"},
+ {file = "tokenizers-0.15.2-cp311-none-win32.whl", hash = "sha256:6d76f00f5c32da36c61f41c58346a4fa7f0a61be02f4301fd30ad59834977cc3"},
+ {file = "tokenizers-0.15.2-cp311-none-win_amd64.whl", hash = "sha256:cc90102ed17271cf0a1262babe5939e0134b3890345d11a19c3145184b706055"},
+ {file = "tokenizers-0.15.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:f86593c18d2e6248e72fb91c77d413a815153b8ea4e31f7cd443bdf28e467670"},
+ {file = "tokenizers-0.15.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0774bccc6608eca23eb9d620196687c8b2360624619623cf4ba9dc9bd53e8b51"},
+ {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:d0222c5b7c9b26c0b4822a82f6a7011de0a9d3060e1da176f66274b70f846b98"},
+ {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3835738be1de66624fff2f4f6f6684775da4e9c00bde053be7564cbf3545cc66"},
+ {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0143e7d9dcd811855c1ce1ab9bf5d96d29bf5e528fd6c7824d0465741e8c10fd"},
+ {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:db35825f6d54215f6b6009a7ff3eedee0848c99a6271c870d2826fbbedf31a38"},
+ {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3f5e64b0389a2be47091d8cc53c87859783b837ea1a06edd9d8e04004df55a5c"},
+ {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e0480c452217edd35eca56fafe2029fb4d368b7c0475f8dfa3c5c9c400a7456"},
+ {file = "tokenizers-0.15.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a33ab881c8fe70474980577e033d0bc9a27b7ab8272896e500708b212995d834"},
+ {file = "tokenizers-0.15.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a308a607ca9de2c64c1b9ba79ec9a403969715a1b8ba5f998a676826f1a7039d"},
+ {file = "tokenizers-0.15.2-cp312-none-win32.whl", hash = "sha256:b8fcfa81bcb9447df582c5bc96a031e6df4da2a774b8080d4f02c0c16b42be0b"},
+ {file = "tokenizers-0.15.2-cp312-none-win_amd64.whl", hash = "sha256:38d7ab43c6825abfc0b661d95f39c7f8af2449364f01d331f3b51c94dcff7221"},
+ {file = "tokenizers-0.15.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:38bfb0204ff3246ca4d5e726e8cc8403bfc931090151e6eede54d0e0cf162ef0"},
+ {file = "tokenizers-0.15.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:9c861d35e8286a53e06e9e28d030b5a05bcbf5ac9d7229e561e53c352a85b1fc"},
+ {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:936bf3842db5b2048eaa53dade907b1160f318e7c90c74bfab86f1e47720bdd6"},
+ {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:620beacc3373277700d0e27718aa8b25f7b383eb8001fba94ee00aeea1459d89"},
+ {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2735ecbbf37e52db4ea970e539fd2d450d213517b77745114f92867f3fc246eb"},
+ {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:473c83c5e2359bb81b0b6fde870b41b2764fcdd36d997485e07e72cc3a62264a"},
+ {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:968fa1fb3c27398b28a4eca1cbd1e19355c4d3a6007f7398d48826bbe3a0f728"},
+ {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:865c60ae6eaebdde7da66191ee9b7db52e542ed8ee9d2c653b6d190a9351b980"},
+ {file = "tokenizers-0.15.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:7c0d8b52664ab2d4a8d6686eb5effc68b78608a9008f086a122a7b2996befbab"},
+ {file = "tokenizers-0.15.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:f33dfbdec3784093a9aebb3680d1f91336c56d86cc70ddf88708251da1fe9064"},
+ {file = "tokenizers-0.15.2-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:d44ba80988ff9424e33e0a49445072ac7029d8c0e1601ad25a0ca5f41ed0c1d6"},
+ {file = "tokenizers-0.15.2-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:dce74266919b892f82b1b86025a613956ea0ea62a4843d4c4237be2c5498ed3a"},
+ {file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:0ef06b9707baeb98b316577acb04f4852239d856b93e9ec3a299622f6084e4be"},
+ {file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c73e2e74bbb07910da0d37c326869f34113137b23eadad3fc00856e6b3d9930c"},
+ {file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4eeb12daf02a59e29f578a865f55d87cd103ce62bd8a3a5874f8fdeaa82e336b"},
+ {file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9ba9f6895af58487ca4f54e8a664a322f16c26bbb442effd01087eba391a719e"},
+ {file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ccec77aa7150e38eec6878a493bf8c263ff1fa8a62404e16c6203c64c1f16a26"},
+ {file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3f40604f5042ff210ba82743dda2b6aa3e55aa12df4e9f2378ee01a17e2855e"},
+ {file = "tokenizers-0.15.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:5645938a42d78c4885086767c70923abad047163d809c16da75d6b290cb30bbe"},
+ {file = "tokenizers-0.15.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:05a77cbfebe28a61ab5c3891f9939cc24798b63fa236d84e5f29f3a85a200c00"},
+ {file = "tokenizers-0.15.2-cp37-none-win32.whl", hash = "sha256:361abdc068e8afe9c5b818769a48624687fb6aaed49636ee39bec4e95e1a215b"},
+ {file = "tokenizers-0.15.2-cp37-none-win_amd64.whl", hash = "sha256:7ef789f83eb0f9baeb4d09a86cd639c0a5518528f9992f38b28e819df397eb06"},
+ {file = "tokenizers-0.15.2-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:4fe1f74a902bee74a3b25aff180fbfbf4f8b444ab37c4d496af7afd13a784ed2"},
+ {file = "tokenizers-0.15.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4c4b89038a684f40a6b15d6b09f49650ac64d951ad0f2a3ea9169687bbf2a8ba"},
+ {file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:d05a1b06f986d41aed5f2de464c003004b2df8aaf66f2b7628254bcbfb72a438"},
+ {file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:508711a108684111ec8af89d3a9e9e08755247eda27d0ba5e3c50e9da1600f6d"},
+ {file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:daa348f02d15160cb35439098ac96e3a53bacf35885072611cd9e5be7d333daa"},
+ {file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:494fdbe5932d3416de2a85fc2470b797e6f3226c12845cadf054dd906afd0442"},
+ {file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c2d60f5246f4da9373f75ff18d64c69cbf60c3bca597290cea01059c336d2470"},
+ {file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:93268e788825f52de4c7bdcb6ebc1fcd4a5442c02e730faa9b6b08f23ead0e24"},
+ {file = "tokenizers-0.15.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6fc7083ab404019fc9acafe78662c192673c1e696bd598d16dc005bd663a5cf9"},
+ {file = "tokenizers-0.15.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:41e39b41e5531d6b2122a77532dbea60e171ef87a3820b5a3888daa847df4153"},
+ {file = "tokenizers-0.15.2-cp38-none-win32.whl", hash = "sha256:06cd0487b1cbfabefb2cc52fbd6b1f8d4c37799bd6c6e1641281adaa6b2504a7"},
+ {file = "tokenizers-0.15.2-cp38-none-win_amd64.whl", hash = "sha256:5179c271aa5de9c71712e31cb5a79e436ecd0d7532a408fa42a8dbfa4bc23fd9"},
+ {file = "tokenizers-0.15.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:82f8652a74cc107052328b87ea8b34291c0f55b96d8fb261b3880216a9f9e48e"},
+ {file = "tokenizers-0.15.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:02458bee6f5f3139f1ebbb6d042b283af712c0981f5bc50edf771d6b762d5e4f"},
+ {file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:c9a09cd26cca2e1c349f91aa665309ddb48d71636370749414fbf67bc83c5343"},
+ {file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:158be8ea8554e5ed69acc1ce3fbb23a06060bd4bbb09029431ad6b9a466a7121"},
+ {file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1ddba9a2b0c8c81633eca0bb2e1aa5b3a15362b1277f1ae64176d0f6eba78ab1"},
+ {file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3ef5dd1d39797044642dbe53eb2bc56435308432e9c7907728da74c69ee2adca"},
+ {file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:454c203164e07a860dbeb3b1f4a733be52b0edbb4dd2e5bd75023ffa8b49403a"},
+ {file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0cf6b7f1d4dc59af960e6ffdc4faffe6460bbfa8dce27a58bf75755ffdb2526d"},
+ {file = "tokenizers-0.15.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:2ef09bbc16519f6c25d0c7fc0c6a33a6f62923e263c9d7cca4e58b8c61572afb"},
+ {file = "tokenizers-0.15.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c9a2ebdd2ad4ec7a68e7615086e633857c85e2f18025bd05d2a4399e6c5f7169"},
+ {file = "tokenizers-0.15.2-cp39-none-win32.whl", hash = "sha256:918fbb0eab96fe08e72a8c2b5461e9cce95585d82a58688e7f01c2bd546c79d0"},
+ {file = "tokenizers-0.15.2-cp39-none-win_amd64.whl", hash = "sha256:524e60da0135e106b254bd71f0659be9f89d83f006ea9093ce4d1fab498c6d0d"},
+ {file = "tokenizers-0.15.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:6a9b648a58281c4672212fab04e60648fde574877d0139cd4b4f93fe28ca8944"},
+ {file = "tokenizers-0.15.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:7c7d18b733be6bbca8a55084027f7be428c947ddf871c500ee603e375013ffba"},
+ {file = "tokenizers-0.15.2-pp310-pypy310_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:13ca3611de8d9ddfbc4dc39ef54ab1d2d4aaa114ac8727dfdc6a6ec4be017378"},
+ {file = "tokenizers-0.15.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:237d1bf3361cf2e6463e6c140628e6406766e8b27274f5fcc62c747ae3c6f094"},
+ {file = "tokenizers-0.15.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:67a0fe1e49e60c664915e9fb6b0cb19bac082ab1f309188230e4b2920230edb3"},
+ {file = "tokenizers-0.15.2-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:4e022fe65e99230b8fd89ebdfea138c24421f91c1a4f4781a8f5016fd5cdfb4d"},
+ {file = "tokenizers-0.15.2-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:d857be2df69763362ac699f8b251a8cd3fac9d21893de129bc788f8baaef2693"},
+ {file = "tokenizers-0.15.2-pp37-pypy37_pp73-macosx_10_12_x86_64.whl", hash = "sha256:708bb3e4283177236309e698da5fcd0879ce8fd37457d7c266d16b550bcbbd18"},
+ {file = "tokenizers-0.15.2-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:64c35e09e9899b72a76e762f9854e8750213f67567787d45f37ce06daf57ca78"},
+ {file = "tokenizers-0.15.2-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c1257f4394be0d3b00de8c9e840ca5601d0a4a8438361ce9c2b05c7d25f6057b"},
+ {file = "tokenizers-0.15.2-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02272fe48280e0293a04245ca5d919b2c94a48b408b55e858feae9618138aeda"},
+ {file = "tokenizers-0.15.2-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:dc3ad9ebc76eabe8b1d7c04d38be884b8f9d60c0cdc09b0aa4e3bcf746de0388"},
+ {file = "tokenizers-0.15.2-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:32e16bdeffa7c4f46bf2152172ca511808b952701d13e7c18833c0b73cb5c23f"},
+ {file = "tokenizers-0.15.2-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:fb16ba563d59003028b678d2361a27f7e4ae0ab29c7a80690efa20d829c81fdb"},
+ {file = "tokenizers-0.15.2-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:2277c36d2d6cdb7876c274547921a42425b6810d38354327dd65a8009acf870c"},
+ {file = "tokenizers-0.15.2-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:1cf75d32e8d250781940d07f7eece253f2fe9ecdb1dc7ba6e3833fa17b82fcbc"},
+ {file = "tokenizers-0.15.2-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1b3b31884dc8e9b21508bb76da80ebf7308fdb947a17affce815665d5c4d028"},
+ {file = "tokenizers-0.15.2-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b10122d8d8e30afb43bb1fe21a3619f62c3e2574bff2699cf8af8b0b6c5dc4a3"},
+ {file = "tokenizers-0.15.2-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:d88b96ff0fe8e91f6ef01ba50b0d71db5017fa4e3b1d99681cec89a85faf7bf7"},
+ {file = "tokenizers-0.15.2-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:37aaec5a52e959892870a7c47cef80c53797c0db9149d458460f4f31e2fb250e"},
+ {file = "tokenizers-0.15.2-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:e2ea752f2b0fe96eb6e2f3adbbf4d72aaa1272079b0dfa1145507bd6a5d537e6"},
+ {file = "tokenizers-0.15.2-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:4b19a808d8799fda23504a5cd31d2f58e6f52f140380082b352f877017d6342b"},
+ {file = "tokenizers-0.15.2-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:64c86e5e068ac8b19204419ed8ca90f9d25db20578f5881e337d203b314f4104"},
+ {file = "tokenizers-0.15.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de19c4dc503c612847edf833c82e9f73cd79926a384af9d801dcf93f110cea4e"},
+ {file = "tokenizers-0.15.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ea09acd2fe3324174063d61ad620dec3bcf042b495515f27f638270a7d466e8b"},
+ {file = "tokenizers-0.15.2-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:cf27fd43472e07b57cf420eee1e814549203d56de00b5af8659cb99885472f1f"},
+ {file = "tokenizers-0.15.2-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:7ca22bd897537a0080521445d91a58886c8c04084a6a19e6c78c586e0cfa92a5"},
+ {file = "tokenizers-0.15.2.tar.gz", hash = "sha256:e6e9c6e019dd5484be5beafc775ae6c925f4c69a3487040ed09b45e13df2cb91"},
]
[package.dependencies]
-huggingface_hub = ">=0.16.4,<0.18"
+huggingface_hub = ">=0.16.4,<1.0"
[package.extras]
dev = ["tokenizers[testing]"]
@@ -4432,42 +4655,42 @@ files = [
[[package]]
name = "tomlkit"
-version = "0.12.1"
+version = "0.12.4"
description = "Style preserving TOML library"
optional = false
python-versions = ">=3.7"
files = [
- {file = "tomlkit-0.12.1-py3-none-any.whl", hash = "sha256:712cbd236609acc6a3e2e97253dfc52d4c2082982a88f61b640ecf0817eab899"},
- {file = "tomlkit-0.12.1.tar.gz", hash = "sha256:38e1ff8edb991273ec9f6181244a6a391ac30e9f5098e7535640ea6be97a7c86"},
+ {file = "tomlkit-0.12.4-py3-none-any.whl", hash = "sha256:5cd82d48a3dd89dee1f9d64420aa20ae65cfbd00668d6f094d7578a78efbb77b"},
+ {file = "tomlkit-0.12.4.tar.gz", hash = "sha256:7ca1cfc12232806517a8515047ba66a19369e71edf2439d0f5824f91032b6cc3"},
]
[[package]]
name = "torch"
-version = "2.1.1"
+version = "2.1.2"
description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
optional = false
python-versions = ">=3.8.0"
files = [
- {file = "torch-2.1.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:5ebc43f5355a9b7be813392b3fb0133991f0380f6f0fcc8218d5468dc45d1071"},
- {file = "torch-2.1.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:84fefd63356416c0cd20578637ccdbb82164993400ed17b57c951dd6376dcee8"},
- {file = "torch-2.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:0a7a9da0c324409bcb5a7bdad1b4e94e936d21c2590aaa7ac2f63968da8c62f7"},
- {file = "torch-2.1.1-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:1e1e5faddd43a8f2c0e0e22beacd1e235a2e447794d807483c94a9e31b54a758"},
- {file = "torch-2.1.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:e76bf3c5c354874f1da465c852a2fb60ee6cbce306e935337885760f080f9baa"},
- {file = "torch-2.1.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:98fea993639b0bb432dfceb7b538f07c0f1c33386d63f635219f49254968c80f"},
- {file = "torch-2.1.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:61b51b33c61737c287058b0c3061e6a9d3c363863e4a094f804bc486888a188a"},
- {file = "torch-2.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:1d70920da827e2276bf07f7ec46958621cad18d228c97da8f9c19638474dbd52"},
- {file = "torch-2.1.1-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:a70593806f1d7e6b53657d96810518da0f88ef2608c98a402955765b8c79d52c"},
- {file = "torch-2.1.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:e312f7e82e49565f7667b0bbf9559ab0c597063d93044740781c02acd5a87978"},
- {file = "torch-2.1.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:1e3cbecfa5a7314d828f4a37b0c286714dc9aa2e69beb7a22f7aca76567ed9f4"},
- {file = "torch-2.1.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:9ca0fcbf3d5ba644d6a8572c83a9abbdf5f7ff575bc38529ef6c185a3a71bde9"},
- {file = "torch-2.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:2dc9f312fc1fa0d61a565a0292ad73119d4b74c9f8b5031b55f8b4722abca079"},
- {file = "torch-2.1.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:d56b032176458e2af4709627bbd2c20fe2917eff8cd087a7fe313acccf5ce2f1"},
- {file = "torch-2.1.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:29e3b90a8c281f6660804a939d1f4218604c80162e521e1e6d8c8557325902a0"},
- {file = "torch-2.1.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:bd95cee8511584b67ddc0ba465c3f1edeb5708d833ee02af1206b4486f1d9096"},
- {file = "torch-2.1.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:b31230bd058424e56dba7f899280dbc6ac8b9948e43902e0c84a44666b1ec151"},
- {file = "torch-2.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:403f1095e665e4f35971b43797a920725b8b205723aa68254a4050c6beca29b6"},
- {file = "torch-2.1.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:715b50d8c1de5da5524a68287eb000f73e026e74d5f6b12bc450ef6995fcf5f9"},
- {file = "torch-2.1.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:db67e8725c76f4c7f4f02e7551bb16e81ba1a1912867bc35d7bb96d2be8c78b4"},
+ {file = "torch-2.1.2-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:3a871edd6c02dae77ad810335c0833391c1a4ce49af21ea8cf0f6a5d2096eea8"},
+ {file = "torch-2.1.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:bef6996c27d8f6e92ea4e13a772d89611da0e103b48790de78131e308cf73076"},
+ {file = "torch-2.1.2-cp310-cp310-win_amd64.whl", hash = "sha256:0e13034fd5fb323cbbc29e56d0637a3791e50dd589616f40c79adfa36a5a35a1"},
+ {file = "torch-2.1.2-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:d9b535cad0df3d13997dbe8bd68ac33e0e3ae5377639c9881948e40794a61403"},
+ {file = "torch-2.1.2-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:f9a55d55af02826ebfbadf4e9b682f0f27766bc33df8236b48d28d705587868f"},
+ {file = "torch-2.1.2-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:a6ebbe517097ef289cc7952783588c72de071d4b15ce0f8b285093f0916b1162"},
+ {file = "torch-2.1.2-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:8f32ce591616a30304f37a7d5ea80b69ca9e1b94bba7f308184bf616fdaea155"},
+ {file = "torch-2.1.2-cp311-cp311-win_amd64.whl", hash = "sha256:e0ee6cf90c8970e05760f898d58f9ac65821c37ffe8b04269ec787aa70962b69"},
+ {file = "torch-2.1.2-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:76d37967c31c99548ad2c4d3f2cf191db48476f2e69b35a0937137116da356a1"},
+ {file = "torch-2.1.2-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:e2d83f07b4aac983453ea5bf8f9aa9dacf2278a8d31247f5d9037f37befc60e4"},
+ {file = "torch-2.1.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:f41fe0c7ecbf903a568c73486139a75cfab287a0f6c17ed0698fdea7a1e8641d"},
+ {file = "torch-2.1.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:e3225f47d50bb66f756fe9196a768055d1c26b02154eb1f770ce47a2578d3aa7"},
+ {file = "torch-2.1.2-cp38-cp38-win_amd64.whl", hash = "sha256:33d59cd03cb60106857f6c26b36457793637512998666ee3ce17311f217afe2b"},
+ {file = "torch-2.1.2-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:8e221deccd0def6c2badff6be403e0c53491805ed9915e2c029adbcdb87ab6b5"},
+ {file = "torch-2.1.2-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:05b18594f60a911a0c4f023f38a8bda77131fba5fd741bda626e97dcf5a3dd0a"},
+ {file = "torch-2.1.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:9ca96253b761e9aaf8e06fb30a66ee301aecbf15bb5a303097de1969077620b6"},
+ {file = "torch-2.1.2-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:d93ba70f67b08c2ae5598ee711cbc546a1bc8102cef938904b8c85c2089a51a0"},
+ {file = "torch-2.1.2-cp39-cp39-win_amd64.whl", hash = "sha256:255b50bc0608db177e6a3cc118961d77de7e5105f07816585fa6f191f33a9ff3"},
+ {file = "torch-2.1.2-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:6984cd5057c0c977b3c9757254e989d3f1124f4ce9d07caa6cb637783c71d42a"},
+ {file = "torch-2.1.2-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:bc195d7927feabc0eb7c110e457c955ed2ab616f3c7c28439dd4188cf589699f"},
]
[package.dependencies]
@@ -4496,33 +4719,33 @@ opt-einsum = ["opt-einsum (>=3.3)"]
[[package]]
name = "tornado"
-version = "6.3.3"
+version = "6.4"
description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed."
optional = false
python-versions = ">= 3.8"
files = [
- {file = "tornado-6.3.3-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:502fba735c84450974fec147340016ad928d29f1e91f49be168c0a4c18181e1d"},
- {file = "tornado-6.3.3-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:805d507b1f588320c26f7f097108eb4023bbaa984d63176d1652e184ba24270a"},
- {file = "tornado-6.3.3-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1bd19ca6c16882e4d37368e0152f99c099bad93e0950ce55e71daed74045908f"},
- {file = "tornado-6.3.3-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7ac51f42808cca9b3613f51ffe2a965c8525cb1b00b7b2d56828b8045354f76a"},
- {file = "tornado-6.3.3-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:71a8db65160a3c55d61839b7302a9a400074c9c753040455494e2af74e2501f2"},
- {file = "tornado-6.3.3-cp38-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:ceb917a50cd35882b57600709dd5421a418c29ddc852da8bcdab1f0db33406b0"},
- {file = "tornado-6.3.3-cp38-abi3-musllinux_1_1_i686.whl", hash = "sha256:7d01abc57ea0dbb51ddfed477dfe22719d376119844e33c661d873bf9c0e4a16"},
- {file = "tornado-6.3.3-cp38-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:9dc4444c0defcd3929d5c1eb5706cbe1b116e762ff3e0deca8b715d14bf6ec17"},
- {file = "tornado-6.3.3-cp38-abi3-win32.whl", hash = "sha256:65ceca9500383fbdf33a98c0087cb975b2ef3bfb874cb35b8de8740cf7f41bd3"},
- {file = "tornado-6.3.3-cp38-abi3-win_amd64.whl", hash = "sha256:22d3c2fa10b5793da13c807e6fc38ff49a4f6e1e3868b0a6f4164768bb8e20f5"},
- {file = "tornado-6.3.3.tar.gz", hash = "sha256:e7d8db41c0181c80d76c982aacc442c0783a2c54d6400fe028954201a2e032fe"},
+ {file = "tornado-6.4-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:02ccefc7d8211e5a7f9e8bc3f9e5b0ad6262ba2fbb683a6443ecc804e5224ce0"},
+ {file = "tornado-6.4-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:27787de946a9cffd63ce5814c33f734c627a87072ec7eed71f7fc4417bb16263"},
+ {file = "tornado-6.4-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f7894c581ecdcf91666a0912f18ce5e757213999e183ebfc2c3fdbf4d5bd764e"},
+ {file = "tornado-6.4-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e43bc2e5370a6a8e413e1e1cd0c91bedc5bd62a74a532371042a18ef19e10579"},
+ {file = "tornado-6.4-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f0251554cdd50b4b44362f73ad5ba7126fc5b2c2895cc62b14a1c2d7ea32f212"},
+ {file = "tornado-6.4-cp38-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:fd03192e287fbd0899dd8f81c6fb9cbbc69194d2074b38f384cb6fa72b80e9c2"},
+ {file = "tornado-6.4-cp38-abi3-musllinux_1_1_i686.whl", hash = "sha256:88b84956273fbd73420e6d4b8d5ccbe913c65d31351b4c004ae362eba06e1f78"},
+ {file = "tornado-6.4-cp38-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:71ddfc23a0e03ef2df1c1397d859868d158c8276a0603b96cf86892bff58149f"},
+ {file = "tornado-6.4-cp38-abi3-win32.whl", hash = "sha256:6f8a6c77900f5ae93d8b4ae1196472d0ccc2775cc1dfdc9e7727889145c45052"},
+ {file = "tornado-6.4-cp38-abi3-win_amd64.whl", hash = "sha256:10aeaa8006333433da48dec9fe417877f8bcc21f48dda8d661ae79da357b2a63"},
+ {file = "tornado-6.4.tar.gz", hash = "sha256:72291fa6e6bc84e626589f1c29d90a5a6d593ef5ae68052ee2ef000dfd273dee"},
]
[[package]]
name = "tqdm"
-version = "4.66.1"
+version = "4.66.2"
description = "Fast, Extensible Progress Meter"
optional = false
python-versions = ">=3.7"
files = [
- {file = "tqdm-4.66.1-py3-none-any.whl", hash = "sha256:d302b3c5b53d47bce91fea46679d9c3c6508cf6332229aa1e7d8653723793386"},
- {file = "tqdm-4.66.1.tar.gz", hash = "sha256:d88e651f9db8d8551a62556d3cff9e3034274ca5d66e93197cf2490e2dcb69c7"},
+ {file = "tqdm-4.66.2-py3-none-any.whl", hash = "sha256:1ee4f8a893eb9bef51c6e35730cebf234d5d0b6bd112b0271e10ed7c24a02bd9"},
+ {file = "tqdm-4.66.2.tar.gz", hash = "sha256:6cd52cdf0fef0e0f543299cfc96fec90d7b8a7e88745f411ec33eb44d5ed3531"},
]
[package.dependencies]
@@ -4536,87 +4759,86 @@ telegram = ["requests"]
[[package]]
name = "traitlets"
-version = "5.12.0"
+version = "5.14.2"
description = "Traitlets Python configuration system"
optional = false
python-versions = ">=3.8"
files = [
- {file = "traitlets-5.12.0-py3-none-any.whl", hash = "sha256:81539f07f7aebcde2e4b5ab76727f53eabf18ad155c6ed7979a681411602fa47"},
- {file = "traitlets-5.12.0.tar.gz", hash = "sha256:833273bf645d8ce31dcb613c56999e2e055b1ffe6d09168a164bcd91c36d5d35"},
+ {file = "traitlets-5.14.2-py3-none-any.whl", hash = "sha256:fcdf85684a772ddeba87db2f398ce00b40ff550d1528c03c14dbf6a02003cd80"},
+ {file = "traitlets-5.14.2.tar.gz", hash = "sha256:8cdd83c040dab7d1dee822678e5f5d100b514f7b72b01615b26fc5718916fdf9"},
]
[package.extras]
docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"]
-test = ["argcomplete (>=3.0.3)", "mypy (>=1.6.0)", "pre-commit", "pytest (>=7.0,<7.5)", "pytest-mock", "pytest-mypy-testing"]
+test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<8.1)", "pytest-mock", "pytest-mypy-testing"]
[[package]]
name = "transformers"
-version = "4.34.1"
+version = "4.39.3"
description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow"
optional = false
python-versions = ">=3.8.0"
files = [
- {file = "transformers-4.34.1-py3-none-any.whl", hash = "sha256:d06ac09151d7b845e4a4acd6b143a591d946031ee67b4cbb20693b241920ffc0"},
- {file = "transformers-4.34.1.tar.gz", hash = "sha256:1d0258d5a18063b66005bbe1e3276ec5943d9ab4ab47f020db1fd485cc40ea22"},
+ {file = "transformers-4.39.3-py3-none-any.whl", hash = "sha256:7838034a12cca3168247f9d2d1dba6724c9de3ae0f73a108258c6b8fc5912601"},
+ {file = "transformers-4.39.3.tar.gz", hash = "sha256:2586e5ff4150f122716fc40f5530e92871befc051848fbe82600969c535b762d"},
]
[package.dependencies]
filelock = "*"
-huggingface-hub = ">=0.16.4,<1.0"
+huggingface-hub = ">=0.19.3,<1.0"
numpy = ">=1.17"
packaging = ">=20.0"
pyyaml = ">=5.1"
regex = "!=2019.12.17"
requests = "*"
-safetensors = ">=0.3.1"
-tokenizers = ">=0.14,<0.15"
+safetensors = ">=0.4.1"
+tokenizers = ">=0.14,<0.19"
tqdm = ">=4.27"
[package.extras]
-accelerate = ["accelerate (>=0.20.3)"]
-agents = ["Pillow (<10.0.0)", "accelerate (>=0.20.3)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.10,!=1.12.0)"]
-all = ["Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.15)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision"]
+accelerate = ["accelerate (>=0.21.0)"]
+agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch"]
+all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch", "torchaudio", "torchvision"]
audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
codecarbon = ["codecarbon (==1.2.0)"]
-deepspeed = ["accelerate (>=0.20.3)", "deepspeed (>=0.9.3)"]
-deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.20.3)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "timeout-decorator"]
-dev = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.15)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
-dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.14,<0.15)", "urllib3 (<2.0.0)"]
-dev-torch = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "accelerate (>=0.20.3)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.15)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
-docs = ["Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "hf-doc-builder", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.15)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision"]
+deepspeed = ["accelerate (>=0.21.0)", "deepspeed (>=0.9.3)"]
+deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.21.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
+dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
+dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.14,<0.19)", "urllib3 (<2.0.0)"]
+dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
+docs = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "hf-doc-builder", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch", "torchaudio", "torchvision"]
docs-specific = ["hf-doc-builder"]
-fairscale = ["fairscale (>0.3)"]
flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)"]
flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
ftfy = ["ftfy"]
-integrations = ["optuna", "ray[tune]", "sigopt"]
+integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"]
ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"]
modelcreation = ["cookiecutter (==1.7.3)"]
-natten = ["natten (>=0.14.6)"]
+natten = ["natten (>=0.14.6,<0.15.0)"]
onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"]
onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"]
optuna = ["optuna"]
-quality = ["GitPython (<3.1.19)", "black (>=23.1,<24.0)", "datasets (!=2.5.0)", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (>=0.0.241,<=0.0.259)", "urllib3 (<2.0.0)"]
-ray = ["ray[tune]"]
+quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (==0.1.5)", "urllib3 (<2.0.0)"]
+ray = ["ray[tune] (>=2.7.0)"]
retrieval = ["datasets (!=2.5.0)", "faiss-cpu"]
sagemaker = ["sagemaker (>=2.31.0)"]
sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"]
-serving = ["fastapi", "pydantic (<2)", "starlette", "uvicorn"]
+serving = ["fastapi", "pydantic", "starlette", "uvicorn"]
sigopt = ["sigopt"]
sklearn = ["scikit-learn"]
speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
-testing = ["GitPython (<3.1.19)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "parameterized", "protobuf", "psutil", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "timeout-decorator"]
-tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx"]
-tf-cpu = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx"]
+testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "tensorboard", "timeout-decorator"]
+tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"]
+tf-cpu = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"]
tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
timm = ["timm"]
-tokenizers = ["tokenizers (>=0.14,<0.15)"]
-torch = ["accelerate (>=0.20.3)", "torch (>=1.10,!=1.12.0)"]
+tokenizers = ["tokenizers (>=0.14,<0.19)"]
+torch = ["accelerate (>=0.21.0)", "torch"]
torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
-torch-vision = ["Pillow (<10.0.0)", "torchvision"]
-torchhub = ["filelock", "huggingface-hub (>=0.16.4,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.14,<0.15)", "torch (>=1.10,!=1.12.0)", "tqdm (>=4.27)"]
+torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"]
+torchhub = ["filelock", "huggingface-hub (>=0.19.3,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.14,<0.19)", "torch", "tqdm (>=4.27)"]
video = ["av (==9.2.0)", "decord (==0.6.0)"]
-vision = ["Pillow (<10.0.0)"]
+vision = ["Pillow (>=10.0.1,<=15.0)"]
[[package]]
name = "triton"
@@ -4645,18 +4867,18 @@ tutorials = ["matplotlib", "pandas", "tabulate"]
[[package]]
name = "typeguard"
-version = "4.1.5"
+version = "4.2.1"
description = "Run-time type checker for Python"
optional = false
python-versions = ">=3.8"
files = [
- {file = "typeguard-4.1.5-py3-none-any.whl", hash = "sha256:8923e55f8873caec136c892c3bed1f676eae7be57cdb94819281b3d3bc9c0953"},
- {file = "typeguard-4.1.5.tar.gz", hash = "sha256:ea0a113bbc111bcffc90789ebb215625c963411f7096a7e9062d4e4630c155fd"},
+ {file = "typeguard-4.2.1-py3-none-any.whl", hash = "sha256:7da3bd46e61f03e0852f8d251dcbdc2a336aa495d7daff01e092b55327796eb8"},
+ {file = "typeguard-4.2.1.tar.gz", hash = "sha256:c556a1b95948230510070ca53fa0341fb0964611bd05d598d87fb52115d65fee"},
]
[package.dependencies]
importlib-metadata = {version = ">=3.6", markers = "python_version < \"3.10\""}
-typing-extensions = {version = ">=4.7.0", markers = "python_version < \"3.12\""}
+typing-extensions = {version = ">=4.10.0", markers = "python_version < \"3.13\""}
[package.extras]
doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)"]
@@ -4664,45 +4886,41 @@ test = ["coverage[toml] (>=7)", "mypy (>=1.2.0)", "pytest (>=7)"]
[[package]]
name = "typer"
-version = "0.9.0"
+version = "0.12.1"
description = "Typer, build great CLIs. Easy to code. Based on Python type hints."
optional = false
-python-versions = ">=3.6"
+python-versions = ">=3.7"
files = [
- {file = "typer-0.9.0-py3-none-any.whl", hash = "sha256:5d96d986a21493606a358cae4461bd8cdf83cbf33a5aa950ae629ca3b51467ee"},
- {file = "typer-0.9.0.tar.gz", hash = "sha256:50922fd79aea2f4751a8e0408ff10d2662bd0c8bbfa84755a699f3bada2978b2"},
+ {file = "typer-0.12.1-py3-none-any.whl", hash = "sha256:43ebb23c8a358c3d623e31064359a65f50229d0bf73ae8dfd203f49d9126ae06"},
+ {file = "typer-0.12.1.tar.gz", hash = "sha256:72d218ef3c686aed9c6ff3ca25b238aee0474a1628b29c559b18b634cfdeca88"},
]
[package.dependencies]
-click = ">=7.1.1,<9.0.0"
+click = ">=8.0.0"
+rich = ">=10.11.0"
+shellingham = ">=1.3.0"
typing-extensions = ">=3.7.4.3"
-[package.extras]
-all = ["colorama (>=0.4.3,<0.5.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"]
-dev = ["autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "pre-commit (>=2.17.0,<3.0.0)"]
-doc = ["cairosvg (>=2.5.2,<3.0.0)", "mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pillow (>=9.3.0,<10.0.0)"]
-test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.910)", "pytest (>=4.4.0,<8.0.0)", "pytest-cov (>=2.10.0,<5.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<4.0.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"]
-
[[package]]
name = "types-python-dateutil"
-version = "2.8.19.14"
+version = "2.9.0.20240316"
description = "Typing stubs for python-dateutil"
optional = false
-python-versions = "*"
+python-versions = ">=3.8"
files = [
- {file = "types-python-dateutil-2.8.19.14.tar.gz", hash = "sha256:1f4f10ac98bb8b16ade9dbee3518d9ace017821d94b057a425b069f834737f4b"},
- {file = "types_python_dateutil-2.8.19.14-py3-none-any.whl", hash = "sha256:f977b8de27787639986b4e28963263fd0e5158942b3ecef91b9335c130cb1ce9"},
+ {file = "types-python-dateutil-2.9.0.20240316.tar.gz", hash = "sha256:5d2f2e240b86905e40944dd787db6da9263f0deabef1076ddaed797351ec0202"},
+ {file = "types_python_dateutil-2.9.0.20240316-py3-none-any.whl", hash = "sha256:6b8cb66d960771ce5ff974e9dd45e38facb81718cc1e208b10b1baccbfdbee3b"},
]
[[package]]
name = "typing-extensions"
-version = "4.8.0"
+version = "4.11.0"
description = "Backported and Experimental Type Hints for Python 3.8+"
optional = false
python-versions = ">=3.8"
files = [
- {file = "typing_extensions-4.8.0-py3-none-any.whl", hash = "sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0"},
- {file = "typing_extensions-4.8.0.tar.gz", hash = "sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef"},
+ {file = "typing_extensions-4.11.0-py3-none-any.whl", hash = "sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a"},
+ {file = "typing_extensions-4.11.0.tar.gz", hash = "sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0"},
]
[[package]]
@@ -4722,13 +4940,13 @@ typing-extensions = ">=3.7.4"
[[package]]
name = "tzdata"
-version = "2023.3"
+version = "2024.1"
description = "Provider of IANA time zone data"
optional = false
python-versions = ">=2"
files = [
- {file = "tzdata-2023.3-py2.py3-none-any.whl", hash = "sha256:7e65763eef3120314099b6939b5546db7adce1e7d6f2e179e3df563c70511eda"},
- {file = "tzdata-2023.3.tar.gz", hash = "sha256:11ef1e08e54acb0d4f95bdb1be05da659673de4acbd21bf9c69e94cc5e907a3a"},
+ {file = "tzdata-2024.1-py2.py3-none-any.whl", hash = "sha256:9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252"},
+ {file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"},
]
[[package]]
@@ -4747,30 +4965,30 @@ dev = ["flake8", "flake8-annotations", "flake8-bandit", "flake8-bugbear", "flake
[[package]]
name = "urllib3"
-version = "2.0.7"
+version = "2.2.1"
description = "HTTP library with thread-safe connection pooling, file post, and more."
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
files = [
- {file = "urllib3-2.0.7-py3-none-any.whl", hash = "sha256:fdb6d215c776278489906c2f8916e6e7d4f5a9b602ccbcfdf7f016fc8da0596e"},
- {file = "urllib3-2.0.7.tar.gz", hash = "sha256:c97dfde1f7bd43a71c8d2a58e369e9b2bf692d1334ea9f9cae55add7d0dd0f84"},
+ {file = "urllib3-2.2.1-py3-none-any.whl", hash = "sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d"},
+ {file = "urllib3-2.2.1.tar.gz", hash = "sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19"},
]
[package.extras]
brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"]
-secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"]
+h2 = ["h2 (>=4,<5)"]
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
zstd = ["zstandard (>=0.18.0)"]
[[package]]
name = "wandb"
-version = "0.15.12"
+version = "0.16.6"
description = "A CLI and library for interacting with the Weights & Biases API."
optional = false
-python-versions = ">=3.6"
+python-versions = ">=3.7"
files = [
- {file = "wandb-0.15.12-py3-none-any.whl", hash = "sha256:75c57b5bb8ddae21d45a02f644628585bdd112fea686de3177099a0996f1c41c"},
- {file = "wandb-0.15.12.tar.gz", hash = "sha256:c344d92fb8044b072a6138afd9adc5d3801ad050cf11378fe2af2fe899dcca84"},
+ {file = "wandb-0.16.6-py3-none-any.whl", hash = "sha256:5810019a3b981c796e98ea58557a7c380f18834e0c6bdaed15df115522e5616e"},
+ {file = "wandb-0.16.6.tar.gz", hash = "sha256:86f491e3012d715e0d7d7421a4d6de41abef643b7403046261f962f3e512fe1c"},
]
[package.dependencies]
@@ -4778,7 +4996,6 @@ appdirs = ">=1.4.3"
Click = ">=7.1,<8.0.0 || >8.0.0"
docker-pycreds = ">=0.4.0"
GitPython = ">=1.0.0,<3.1.29 || >3.1.29"
-pathtools = "*"
protobuf = [
{version = ">=3.12.0,<4.21.0 || >4.21.0,<5", markers = "python_version < \"3.9\" and sys_platform == \"linux\""},
{version = ">=3.19.0,<4.21.0 || >4.21.0,<5", markers = "python_version > \"3.9\" or sys_platform != \"linux\""},
@@ -4793,27 +5010,28 @@ setuptools = "*"
typing-extensions = {version = "*", markers = "python_version < \"3.10\""}
[package.extras]
-async = ["httpx (>=0.22.0)"]
+async = ["httpx (>=0.23.0)"]
aws = ["boto3"]
azure = ["azure-identity", "azure-storage-blob"]
gcp = ["google-cloud-storage"]
+importers = ["filelock", "mlflow", "polars", "rich", "tenacity"]
kubeflow = ["google-cloud-storage", "kubernetes", "minio", "sh"]
-launch = ["PyYAML (>=6.0.0)", "awscli", "azure-containerregistry", "azure-identity", "azure-storage-blob", "boto3", "botocore", "chardet", "google-auth", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "kubernetes", "nbconvert", "nbformat", "optuna", "typing-extensions"]
-media = ["bokeh", "moviepy", "numpy", "pillow", "plotly", "rdkit-pypi", "soundfile"]
+launch = ["PyYAML (>=6.0.0)", "awscli", "azure-containerregistry", "azure-identity", "azure-storage-blob", "boto3", "botocore", "chardet", "google-auth", "google-cloud-aiplatform", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "kubernetes", "kubernetes-asyncio", "nbconvert", "nbformat", "optuna", "pydantic", "tomli", "typing-extensions"]
+media = ["bokeh", "moviepy", "numpy", "pillow", "plotly (>=5.18.0)", "rdkit-pypi", "soundfile"]
models = ["cloudpickle"]
-nexus = ["wandb-core (>=0.16.0b1)"]
perf = ["orjson"]
+reports = ["pydantic (>=2.0.0)"]
sweeps = ["sweeps (>=0.2.0)"]
[[package]]
name = "wcwidth"
-version = "0.2.8"
+version = "0.2.13"
description = "Measures the displayed width of unicode strings in a terminal"
optional = false
python-versions = "*"
files = [
- {file = "wcwidth-0.2.8-py2.py3-none-any.whl", hash = "sha256:77f719e01648ed600dfa5402c347481c0992263b81a027344f3e1ba25493a704"},
- {file = "wcwidth-0.2.8.tar.gz", hash = "sha256:8705c569999ffbb4f6a87c6d1b80f324bd6db952f5eb0b95bc07517f4c1813d4"},
+ {file = "wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859"},
+ {file = "wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5"},
]
[[package]]
@@ -4844,13 +5062,13 @@ files = [
[[package]]
name = "websocket-client"
-version = "1.6.4"
+version = "1.7.0"
description = "WebSocket client for Python with low level API options"
optional = false
python-versions = ">=3.8"
files = [
- {file = "websocket-client-1.6.4.tar.gz", hash = "sha256:b3324019b3c28572086c4a319f91d1dcd44e6e11cd340232978c684a7650d0df"},
- {file = "websocket_client-1.6.4-py3-none-any.whl", hash = "sha256:084072e0a7f5f347ef2ac3d8698a5e0b4ffbfcab607628cadabc650fc9a83a24"},
+ {file = "websocket-client-1.7.0.tar.gz", hash = "sha256:10e511ea3a8c744631d3bd77e61eb17ed09304c413ad42cf6ddfa4c7787e8fe6"},
+ {file = "websocket_client-1.7.0-py3-none-any.whl", hash = "sha256:f4c3d22fec12a2461427a29957ff07d35098ee2d976d3ba244e688b8b4057588"},
]
[package.extras]
@@ -4860,13 +5078,13 @@ test = ["websockets"]
[[package]]
name = "widgetsnbextension"
-version = "4.0.9"
+version = "4.0.10"
description = "Jupyter interactive widgets for Jupyter Notebook"
optional = false
python-versions = ">=3.7"
files = [
- {file = "widgetsnbextension-4.0.9-py3-none-any.whl", hash = "sha256:91452ca8445beb805792f206e560c1769284267a30ceb1cec9f5bcc887d15175"},
- {file = "widgetsnbextension-4.0.9.tar.gz", hash = "sha256:3c1f5e46dc1166dfd40a42d685e6a51396fd34ff878742a3e47c6f0cc4a2a385"},
+ {file = "widgetsnbextension-4.0.10-py3-none-any.whl", hash = "sha256:d37c3724ec32d8c48400a435ecfa7d3e259995201fbefa37163124a9fcb393cc"},
+ {file = "widgetsnbextension-4.0.10.tar.gz", hash = "sha256:64196c5ff3b9a9183a8e699a4227fb0b7002f252c814098e66c4d1cd0644688f"},
]
[[package]]
@@ -4988,85 +5206,101 @@ files = [
[[package]]
name = "yarl"
-version = "1.9.2"
+version = "1.9.4"
description = "Yet another URL library"
optional = false
python-versions = ">=3.7"
files = [
- {file = "yarl-1.9.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:8c2ad583743d16ddbdf6bb14b5cd76bf43b0d0006e918809d5d4ddf7bde8dd82"},
- {file = "yarl-1.9.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:82aa6264b36c50acfb2424ad5ca537a2060ab6de158a5bd2a72a032cc75b9eb8"},
- {file = "yarl-1.9.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c0c77533b5ed4bcc38e943178ccae29b9bcf48ffd1063f5821192f23a1bd27b9"},
- {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee4afac41415d52d53a9833ebae7e32b344be72835bbb589018c9e938045a560"},
- {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9bf345c3a4f5ba7f766430f97f9cc1320786f19584acc7086491f45524a551ac"},
- {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a96c19c52ff442a808c105901d0bdfd2e28575b3d5f82e2f5fd67e20dc5f4ea"},
- {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:891c0e3ec5ec881541f6c5113d8df0315ce5440e244a716b95f2525b7b9f3608"},
- {file = "yarl-1.9.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c3a53ba34a636a256d767c086ceb111358876e1fb6b50dfc4d3f4951d40133d5"},
- {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:566185e8ebc0898b11f8026447eacd02e46226716229cea8db37496c8cdd26e0"},
- {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:2b0738fb871812722a0ac2154be1f049c6223b9f6f22eec352996b69775b36d4"},
- {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:32f1d071b3f362c80f1a7d322bfd7b2d11e33d2adf395cc1dd4df36c9c243095"},
- {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:e9fdc7ac0d42bc3ea78818557fab03af6181e076a2944f43c38684b4b6bed8e3"},
- {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:56ff08ab5df8429901ebdc5d15941b59f6253393cb5da07b4170beefcf1b2528"},
- {file = "yarl-1.9.2-cp310-cp310-win32.whl", hash = "sha256:8ea48e0a2f931064469bdabca50c2f578b565fc446f302a79ba6cc0ee7f384d3"},
- {file = "yarl-1.9.2-cp310-cp310-win_amd64.whl", hash = "sha256:50f33040f3836e912ed16d212f6cc1efb3231a8a60526a407aeb66c1c1956dde"},
- {file = "yarl-1.9.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:646d663eb2232d7909e6601f1a9107e66f9791f290a1b3dc7057818fe44fc2b6"},
- {file = "yarl-1.9.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:aff634b15beff8902d1f918012fc2a42e0dbae6f469fce134c8a0dc51ca423bb"},
- {file = "yarl-1.9.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a83503934c6273806aed765035716216cc9ab4e0364f7f066227e1aaea90b8d0"},
- {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b25322201585c69abc7b0e89e72790469f7dad90d26754717f3310bfe30331c2"},
- {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:22a94666751778629f1ec4280b08eb11815783c63f52092a5953faf73be24191"},
- {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8ec53a0ea2a80c5cd1ab397925f94bff59222aa3cf9c6da938ce05c9ec20428d"},
- {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:159d81f22d7a43e6eabc36d7194cb53f2f15f498dbbfa8edc8a3239350f59fe7"},
- {file = "yarl-1.9.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:832b7e711027c114d79dffb92576acd1bd2decc467dec60e1cac96912602d0e6"},
- {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:95d2ecefbcf4e744ea952d073c6922e72ee650ffc79028eb1e320e732898d7e8"},
- {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d4e2c6d555e77b37288eaf45b8f60f0737c9efa3452c6c44626a5455aeb250b9"},
- {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:783185c75c12a017cc345015ea359cc801c3b29a2966c2655cd12b233bf5a2be"},
- {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:b8cc1863402472f16c600e3e93d542b7e7542a540f95c30afd472e8e549fc3f7"},
- {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:822b30a0f22e588b32d3120f6d41e4ed021806418b4c9f0bc3048b8c8cb3f92a"},
- {file = "yarl-1.9.2-cp311-cp311-win32.whl", hash = "sha256:a60347f234c2212a9f0361955007fcf4033a75bf600a33c88a0a8e91af77c0e8"},
- {file = "yarl-1.9.2-cp311-cp311-win_amd64.whl", hash = "sha256:be6b3fdec5c62f2a67cb3f8c6dbf56bbf3f61c0f046f84645cd1ca73532ea051"},
- {file = "yarl-1.9.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:38a3928ae37558bc1b559f67410df446d1fbfa87318b124bf5032c31e3447b74"},
- {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac9bb4c5ce3975aeac288cfcb5061ce60e0d14d92209e780c93954076c7c4367"},
- {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3da8a678ca8b96c8606bbb8bfacd99a12ad5dd288bc6f7979baddd62f71c63ef"},
- {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:13414591ff516e04fcdee8dc051c13fd3db13b673c7a4cb1350e6b2ad9639ad3"},
- {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf74d08542c3a9ea97bb8f343d4fcbd4d8f91bba5ec9d5d7f792dbe727f88938"},
- {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6e7221580dc1db478464cfeef9b03b95c5852cc22894e418562997df0d074ccc"},
- {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:494053246b119b041960ddcd20fd76224149cfea8ed8777b687358727911dd33"},
- {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:52a25809fcbecfc63ac9ba0c0fb586f90837f5425edfd1ec9f3372b119585e45"},
- {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:e65610c5792870d45d7b68c677681376fcf9cc1c289f23e8e8b39c1485384185"},
- {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:1b1bba902cba32cdec51fca038fd53f8beee88b77efc373968d1ed021024cc04"},
- {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:662e6016409828ee910f5d9602a2729a8a57d74b163c89a837de3fea050c7582"},
- {file = "yarl-1.9.2-cp37-cp37m-win32.whl", hash = "sha256:f364d3480bffd3aa566e886587eaca7c8c04d74f6e8933f3f2c996b7f09bee1b"},
- {file = "yarl-1.9.2-cp37-cp37m-win_amd64.whl", hash = "sha256:6a5883464143ab3ae9ba68daae8e7c5c95b969462bbe42e2464d60e7e2698368"},
- {file = "yarl-1.9.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5610f80cf43b6202e2c33ba3ec2ee0a2884f8f423c8f4f62906731d876ef4fac"},
- {file = "yarl-1.9.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b9a4e67ad7b646cd6f0938c7ebfd60e481b7410f574c560e455e938d2da8e0f4"},
- {file = "yarl-1.9.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:83fcc480d7549ccebe9415d96d9263e2d4226798c37ebd18c930fce43dfb9574"},
- {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5fcd436ea16fee7d4207c045b1e340020e58a2597301cfbcfdbe5abd2356c2fb"},
- {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84e0b1599334b1e1478db01b756e55937d4614f8654311eb26012091be109d59"},
- {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3458a24e4ea3fd8930e934c129b676c27452e4ebda80fbe47b56d8c6c7a63a9e"},
- {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:838162460b3a08987546e881a2bfa573960bb559dfa739e7800ceeec92e64417"},
- {file = "yarl-1.9.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f4e2d08f07a3d7d3e12549052eb5ad3eab1c349c53ac51c209a0e5991bbada78"},
- {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:de119f56f3c5f0e2fb4dee508531a32b069a5f2c6e827b272d1e0ff5ac040333"},
- {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:149ddea5abf329752ea5051b61bd6c1d979e13fbf122d3a1f9f0c8be6cb6f63c"},
- {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:674ca19cbee4a82c9f54e0d1eee28116e63bc6fd1e96c43031d11cbab8b2afd5"},
- {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:9b3152f2f5677b997ae6c804b73da05a39daa6a9e85a512e0e6823d81cdad7cc"},
- {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5415d5a4b080dc9612b1b63cba008db84e908b95848369aa1da3686ae27b6d2b"},
- {file = "yarl-1.9.2-cp38-cp38-win32.whl", hash = "sha256:f7a3d8146575e08c29ed1cd287068e6d02f1c7bdff8970db96683b9591b86ee7"},
- {file = "yarl-1.9.2-cp38-cp38-win_amd64.whl", hash = "sha256:63c48f6cef34e6319a74c727376e95626f84ea091f92c0250a98e53e62c77c72"},
- {file = "yarl-1.9.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:75df5ef94c3fdc393c6b19d80e6ef1ecc9ae2f4263c09cacb178d871c02a5ba9"},
- {file = "yarl-1.9.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c027a6e96ef77d401d8d5a5c8d6bc478e8042f1e448272e8d9752cb0aff8b5c8"},
- {file = "yarl-1.9.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f3b078dbe227f79be488ffcfc7a9edb3409d018e0952cf13f15fd6512847f3f7"},
- {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:59723a029760079b7d991a401386390c4be5bfec1e7dd83e25a6a0881859e716"},
- {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b03917871bf859a81ccb180c9a2e6c1e04d2f6a51d953e6a5cdd70c93d4e5a2a"},
- {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c1012fa63eb6c032f3ce5d2171c267992ae0c00b9e164efe4d73db818465fac3"},
- {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a74dcbfe780e62f4b5a062714576f16c2f3493a0394e555ab141bf0d746bb955"},
- {file = "yarl-1.9.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8c56986609b057b4839968ba901944af91b8e92f1725d1a2d77cbac6972b9ed1"},
- {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:2c315df3293cd521033533d242d15eab26583360b58f7ee5d9565f15fee1bef4"},
- {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:b7232f8dfbd225d57340e441d8caf8652a6acd06b389ea2d3222b8bc89cbfca6"},
- {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:53338749febd28935d55b41bf0bcc79d634881195a39f6b2f767870b72514caf"},
- {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:066c163aec9d3d073dc9ffe5dd3ad05069bcb03fcaab8d221290ba99f9f69ee3"},
- {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8288d7cd28f8119b07dd49b7230d6b4562f9b61ee9a4ab02221060d21136be80"},
- {file = "yarl-1.9.2-cp39-cp39-win32.whl", hash = "sha256:b124e2a6d223b65ba8768d5706d103280914d61f5cae3afbc50fc3dfcc016623"},
- {file = "yarl-1.9.2-cp39-cp39-win_amd64.whl", hash = "sha256:61016e7d582bc46a5378ffdd02cd0314fb8ba52f40f9cf4d9a5e7dbef88dee18"},
- {file = "yarl-1.9.2.tar.gz", hash = "sha256:04ab9d4b9f587c06d801c2abfe9317b77cdf996c65a90d5e84ecc45010823571"},
+ {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a8c1df72eb746f4136fe9a2e72b0c9dc1da1cbd23b5372f94b5820ff8ae30e0e"},
+ {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a3a6ed1d525bfb91b3fc9b690c5a21bb52de28c018530ad85093cc488bee2dd2"},
+ {file = "yarl-1.9.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c38c9ddb6103ceae4e4498f9c08fac9b590c5c71b0370f98714768e22ac6fa66"},
+ {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9e09c9d74f4566e905a0b8fa668c58109f7624db96a2171f21747abc7524234"},
+ {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b8477c1ee4bd47c57d49621a062121c3023609f7a13b8a46953eb6c9716ca392"},
+ {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5ff2c858f5f6a42c2a8e751100f237c5e869cbde669a724f2062d4c4ef93551"},
+ {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:357495293086c5b6d34ca9616a43d329317feab7917518bc97a08f9e55648455"},
+ {file = "yarl-1.9.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:54525ae423d7b7a8ee81ba189f131054defdb122cde31ff17477951464c1691c"},
+ {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:801e9264d19643548651b9db361ce3287176671fb0117f96b5ac0ee1c3530d53"},
+ {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e516dc8baf7b380e6c1c26792610230f37147bb754d6426462ab115a02944385"},
+ {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:7d5aaac37d19b2904bb9dfe12cdb08c8443e7ba7d2852894ad448d4b8f442863"},
+ {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:54beabb809ffcacbd9d28ac57b0db46e42a6e341a030293fb3185c409e626b8b"},
+ {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bac8d525a8dbc2a1507ec731d2867025d11ceadcb4dd421423a5d42c56818541"},
+ {file = "yarl-1.9.4-cp310-cp310-win32.whl", hash = "sha256:7855426dfbddac81896b6e533ebefc0af2f132d4a47340cee6d22cac7190022d"},
+ {file = "yarl-1.9.4-cp310-cp310-win_amd64.whl", hash = "sha256:848cd2a1df56ddbffeb375535fb62c9d1645dde33ca4d51341378b3f5954429b"},
+ {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:35a2b9396879ce32754bd457d31a51ff0a9d426fd9e0e3c33394bf4b9036b099"},
+ {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c7d56b293cc071e82532f70adcbd8b61909eec973ae9d2d1f9b233f3d943f2c"},
+ {file = "yarl-1.9.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d8a1c6c0be645c745a081c192e747c5de06e944a0d21245f4cf7c05e457c36e0"},
+ {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b3c1ffe10069f655ea2d731808e76e0f452fc6c749bea04781daf18e6039525"},
+ {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:549d19c84c55d11687ddbd47eeb348a89df9cb30e1993f1b128f4685cd0ebbf8"},
+ {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7409f968456111140c1c95301cadf071bd30a81cbd7ab829169fb9e3d72eae9"},
+ {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e23a6d84d9d1738dbc6e38167776107e63307dfc8ad108e580548d1f2c587f42"},
+ {file = "yarl-1.9.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d8b889777de69897406c9fb0b76cdf2fd0f31267861ae7501d93003d55f54fbe"},
+ {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:03caa9507d3d3c83bca08650678e25364e1843b484f19986a527630ca376ecce"},
+ {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4e9035df8d0880b2f1c7f5031f33f69e071dfe72ee9310cfc76f7b605958ceb9"},
+ {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:c0ec0ed476f77db9fb29bca17f0a8fcc7bc97ad4c6c1d8959c507decb22e8572"},
+ {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:ee04010f26d5102399bd17f8df8bc38dc7ccd7701dc77f4a68c5b8d733406958"},
+ {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:49a180c2e0743d5d6e0b4d1a9e5f633c62eca3f8a86ba5dd3c471060e352ca98"},
+ {file = "yarl-1.9.4-cp311-cp311-win32.whl", hash = "sha256:81eb57278deb6098a5b62e88ad8281b2ba09f2f1147c4767522353eaa6260b31"},
+ {file = "yarl-1.9.4-cp311-cp311-win_amd64.whl", hash = "sha256:d1d2532b340b692880261c15aee4dc94dd22ca5d61b9db9a8a361953d36410b1"},
+ {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0d2454f0aef65ea81037759be5ca9947539667eecebca092733b2eb43c965a81"},
+ {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:44d8ffbb9c06e5a7f529f38f53eda23e50d1ed33c6c869e01481d3fafa6b8142"},
+ {file = "yarl-1.9.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aaaea1e536f98754a6e5c56091baa1b6ce2f2700cc4a00b0d49eca8dea471074"},
+ {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3777ce5536d17989c91696db1d459574e9a9bd37660ea7ee4d3344579bb6f129"},
+ {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9fc5fc1eeb029757349ad26bbc5880557389a03fa6ada41703db5e068881e5f2"},
+ {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea65804b5dc88dacd4a40279af0cdadcfe74b3e5b4c897aa0d81cf86927fee78"},
+ {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa102d6d280a5455ad6a0f9e6d769989638718e938a6a0a2ff3f4a7ff8c62cc4"},
+ {file = "yarl-1.9.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09efe4615ada057ba2d30df871d2f668af661e971dfeedf0c159927d48bbeff0"},
+ {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:008d3e808d03ef28542372d01057fd09168419cdc8f848efe2804f894ae03e51"},
+ {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:6f5cb257bc2ec58f437da2b37a8cd48f666db96d47b8a3115c29f316313654ff"},
+ {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:992f18e0ea248ee03b5a6e8b3b4738850ae7dbb172cc41c966462801cbf62cf7"},
+ {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:0e9d124c191d5b881060a9e5060627694c3bdd1fe24c5eecc8d5d7d0eb6faabc"},
+ {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3986b6f41ad22988e53d5778f91855dc0399b043fc8946d4f2e68af22ee9ff10"},
+ {file = "yarl-1.9.4-cp312-cp312-win32.whl", hash = "sha256:4b21516d181cd77ebd06ce160ef8cc2a5e9ad35fb1c5930882baff5ac865eee7"},
+ {file = "yarl-1.9.4-cp312-cp312-win_amd64.whl", hash = "sha256:a9bd00dc3bc395a662900f33f74feb3e757429e545d831eef5bb280252631984"},
+ {file = "yarl-1.9.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:63b20738b5aac74e239622d2fe30df4fca4942a86e31bf47a81a0e94c14df94f"},
+ {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7d7f7de27b8944f1fee2c26a88b4dabc2409d2fea7a9ed3df79b67277644e17"},
+ {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c74018551e31269d56fab81a728f683667e7c28c04e807ba08f8c9e3bba32f14"},
+ {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ca06675212f94e7a610e85ca36948bb8fc023e458dd6c63ef71abfd482481aa5"},
+ {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5aef935237d60a51a62b86249839b51345f47564208c6ee615ed2a40878dccdd"},
+ {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2b134fd795e2322b7684155b7855cc99409d10b2e408056db2b93b51a52accc7"},
+ {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d25039a474c4c72a5ad4b52495056f843a7ff07b632c1b92ea9043a3d9950f6e"},
+ {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:f7d6b36dd2e029b6bcb8a13cf19664c7b8e19ab3a58e0fefbb5b8461447ed5ec"},
+ {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:957b4774373cf6f709359e5c8c4a0af9f6d7875db657adb0feaf8d6cb3c3964c"},
+ {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:d7eeb6d22331e2fd42fce928a81c697c9ee2d51400bd1a28803965883e13cead"},
+ {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:6a962e04b8f91f8c4e5917e518d17958e3bdee71fd1d8b88cdce74dd0ebbf434"},
+ {file = "yarl-1.9.4-cp37-cp37m-win32.whl", hash = "sha256:f3bc6af6e2b8f92eced34ef6a96ffb248e863af20ef4fde9448cc8c9b858b749"},
+ {file = "yarl-1.9.4-cp37-cp37m-win_amd64.whl", hash = "sha256:ad4d7a90a92e528aadf4965d685c17dacff3df282db1121136c382dc0b6014d2"},
+ {file = "yarl-1.9.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ec61d826d80fc293ed46c9dd26995921e3a82146feacd952ef0757236fc137be"},
+ {file = "yarl-1.9.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8be9e837ea9113676e5754b43b940b50cce76d9ed7d2461df1af39a8ee674d9f"},
+ {file = "yarl-1.9.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:bef596fdaa8f26e3d66af846bbe77057237cb6e8efff8cd7cc8dff9a62278bbf"},
+ {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2d47552b6e52c3319fede1b60b3de120fe83bde9b7bddad11a69fb0af7db32f1"},
+ {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84fc30f71689d7fc9168b92788abc977dc8cefa806909565fc2951d02f6b7d57"},
+ {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4aa9741085f635934f3a2583e16fcf62ba835719a8b2b28fb2917bb0537c1dfa"},
+ {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:206a55215e6d05dbc6c98ce598a59e6fbd0c493e2de4ea6cc2f4934d5a18d130"},
+ {file = "yarl-1.9.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07574b007ee20e5c375a8fe4a0789fad26db905f9813be0f9fef5a68080de559"},
+ {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5a2e2433eb9344a163aced6a5f6c9222c0786e5a9e9cac2c89f0b28433f56e23"},
+ {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:6ad6d10ed9b67a382b45f29ea028f92d25bc0bc1daf6c5b801b90b5aa70fb9ec"},
+ {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:6fe79f998a4052d79e1c30eeb7d6c1c1056ad33300f682465e1b4e9b5a188b78"},
+ {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:a825ec844298c791fd28ed14ed1bffc56a98d15b8c58a20e0e08c1f5f2bea1be"},
+ {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8619d6915b3b0b34420cf9b2bb6d81ef59d984cb0fde7544e9ece32b4b3043c3"},
+ {file = "yarl-1.9.4-cp38-cp38-win32.whl", hash = "sha256:686a0c2f85f83463272ddffd4deb5e591c98aac1897d65e92319f729c320eece"},
+ {file = "yarl-1.9.4-cp38-cp38-win_amd64.whl", hash = "sha256:a00862fb23195b6b8322f7d781b0dc1d82cb3bcac346d1e38689370cc1cc398b"},
+ {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:604f31d97fa493083ea21bd9b92c419012531c4e17ea6da0f65cacdcf5d0bd27"},
+ {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8a854227cf581330ffa2c4824d96e52ee621dd571078a252c25e3a3b3d94a1b1"},
+ {file = "yarl-1.9.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ba6f52cbc7809cd8d74604cce9c14868306ae4aa0282016b641c661f981a6e91"},
+ {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a6327976c7c2f4ee6816eff196e25385ccc02cb81427952414a64811037bbc8b"},
+ {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8397a3817d7dcdd14bb266283cd1d6fc7264a48c186b986f32e86d86d35fbac5"},
+ {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e0381b4ce23ff92f8170080c97678040fc5b08da85e9e292292aba67fdac6c34"},
+ {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:23d32a2594cb5d565d358a92e151315d1b2268bc10f4610d098f96b147370136"},
+ {file = "yarl-1.9.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ddb2a5c08a4eaaba605340fdee8fc08e406c56617566d9643ad8bf6852778fc7"},
+ {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:26a1dc6285e03f3cc9e839a2da83bcbf31dcb0d004c72d0730e755b33466c30e"},
+ {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:18580f672e44ce1238b82f7fb87d727c4a131f3a9d33a5e0e82b793362bf18b4"},
+ {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:29e0f83f37610f173eb7e7b5562dd71467993495e568e708d99e9d1944f561ec"},
+ {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:1f23e4fe1e8794f74b6027d7cf19dc25f8b63af1483d91d595d4a07eca1fb26c"},
+ {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:db8e58b9d79200c76956cefd14d5c90af54416ff5353c5bfd7cbe58818e26ef0"},
+ {file = "yarl-1.9.4-cp39-cp39-win32.whl", hash = "sha256:c7224cab95645c7ab53791022ae77a4509472613e839dab722a72abe5a684575"},
+ {file = "yarl-1.9.4-cp39-cp39-win_amd64.whl", hash = "sha256:824d6c50492add5da9374875ce72db7a0733b29c2394890aef23d533106e2b15"},
+ {file = "yarl-1.9.4-py3-none-any.whl", hash = "sha256:928cecb0ef9d5a7946eb6ff58417ad2fe9375762382f1bf5c55e61645f2c43ad"},
+ {file = "yarl-1.9.4.tar.gz", hash = "sha256:566db86717cf8080b99b58b083b773a908ae40f06681e87e589a976faf8246bf"},
]
[package.dependencies]
@@ -5075,20 +5309,20 @@ multidict = ">=4.0"
[[package]]
name = "zipp"
-version = "3.17.0"
+version = "3.18.1"
description = "Backport of pathlib-compatible object wrapper for zip files"
optional = false
python-versions = ">=3.8"
files = [
- {file = "zipp-3.17.0-py3-none-any.whl", hash = "sha256:0e923e726174922dce09c53c59ad483ff7bbb8e572e00c7f7c46b88556409f31"},
- {file = "zipp-3.17.0.tar.gz", hash = "sha256:84e64a1c28cf7e91ed2078bb8cc8c259cb19b76942096c8d7b84947690cabaf0"},
+ {file = "zipp-3.18.1-py3-none-any.whl", hash = "sha256:206f5a15f2af3dbaee80769fb7dc6f249695e940acca08dfb2a4769fe61e538b"},
+ {file = "zipp-3.18.1.tar.gz", hash = "sha256:2884ed22e7d8961de1c9a05142eb69a247f120291bc0206a00a7642f09b5b715"},
]
[package.extras]
-docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"]
-testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"]
+docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
+testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8,<4.0"
-content-hash = "7741ae68f922bc6c30656483e5337f5544c5f57c51a4102e067965014c806604"
+content-hash = "0bc401f271115fc5955dccb3cca0d29c981bc204a8894378723ca738b4d0287e"
diff --git a/pyproject.toml b/pyproject.toml
index f89a4bcf3..62d8ab2d7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,112 +1,103 @@
[tool.poetry]
-name = "transformer-lens"
-version = "0.0.0" # This is automatically set by the CD pipeline on release
-description = "An implementation of transformers tailored for mechanistic interpretability."
-authors = ["Neel Nanda <77788841+neelnanda-io@users.noreply.github.com>"]
-license = "MIT"
-readme = "README.md"
-packages = [{include = "transformer_lens"}]
+ authors=["Neel Nanda <77788841+neelnanda-io@users.noreply.github.com>"]
+ description="An implementation of transformers tailored for mechanistic interpretability."
+ license="MIT"
+ name="transformer-lens"
+ packages=[{include="transformer_lens"}]
+ readme="README.md"
+ # Version is automatically set by the pipeline on release
+ version="0.0.0"
-[tool.poetry.scripts]
-build-docs = "docs.make_docs:build_docs"
-docs-hot-reload = "docs.make_docs:docs_hot_reload"
+ [tool.poetry.scripts]
+ build-docs="docs.make_docs:build_docs"
+ docs-hot-reload="docs.make_docs:docs_hot_reload"
-[tool.poetry.dependencies]
-python = ">=3.8,<4.0"
-einops = ">=0.6.0"
-numpy = [{ version = ">=1.20,<1.25", python = ">=3.8,<3.9" },
- { version = ">=1.24", python = ">=3.9,<3.12" },
- { version = ">=1.26", python = ">=3.12,<3.13" }]
-torch = ">=1.10,!=2.0,!=2.1.0"
-# See PyTorch 2 fix below. We pin >=2.1.1 due to MPS errors (See our Slack)
+ [tool.poetry.dependencies]
+ accelerate=">=0.23.0" # Needed for Llama Models
+ beartype="^0.14.1"
+ better-abc="^0.0.3"
+ datasets=">=2.7.1"
+ einops=">=0.6.0"
+ fancy-einsum=">=0.0.3"
+ jaxtyping=">=0.2.11"
+ numpy=[
+ {version=">=1.20,<1.25", python=">=3.8,<3.9"},
+ {version=">=1.24", python=">=3.9,<3.12"},
+ {version=">=1.26", python=">=3.12,<3.13"},
+ ]
+ pandas=">=1.1.5"
+ python=">=3.8,<4.0"
+ rich=">=12.6.0"
+ torch = [
+ {platform = "linux", version = ">=1.10"}, # We can use any torch version on Linux (e.g colab)
+ {platform = "!=linux", version = ">=1.10,!=2.0,!=2.1.0"}, # Pin >=2.1.1 on Apple devices due to known MPS errors on 2.1.0
+ ]
+ tqdm=">=4.64.1"
+ transformers=">=4.37.2"
+ typing-extensions="*"
+ wandb=">=0.13.5"
+ sentencepiece = "*"
-datasets = ">=2.7.1"
-transformers = ">=4.25.1"
-tqdm = ">=4.64.1"
-pandas = ">=1.1.5"
-wandb = ">=0.13.5"
-fancy-einsum = ">=0.0.3"
-rich = ">=12.6.0"
-jaxtyping = ">=0.2.11"
-beartype = "^0.14.1"
-accelerate = ">=0.23.0" # Needed for Llama Models
-typing-extensions = "*"
-# PyTorch 2.0 Bug Fix PyTorch didn't put their dependencies metadata into all wheels for 2.1.0, so
-# it doesn't work with Poetry. This is a known bug - the workaround is to place them manually here
-# (from the one wheel that did correctly list them). This was broken in 2.0.1 and the fix wasn't
-# made for 2.1.0, however Meta are aware of the issue and once it is fixed (and the torch version
-# requirement bumped) this should be removed. Note also the python version is used to specify that
-# this is only added where v2 torch is installed (as per the torch version requirement above).
-# https://github.com/pytorch/pytorch/issues/100974
-# https://github.com/python-poetry/poetry/issues/7902#issuecomment-1583078794
-nvidia-cuda-nvrtc-cu12 = { version = ">=12.1.105", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
-nvidia-cuda-runtime-cu12 = { version = ">=12.1.105", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
-nvidia-cuda-cupti-cu12 = { version = ">=12.1.105", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
-nvidia-cudnn-cu12 = { version = ">=8.9.2.26", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
-nvidia-cublas-cu12 = { version = ">=12.1.3.1", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
-nvidia-cufft-cu12 = { version = ">=11.0.2.54", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
-nvidia-curand-cu12 = { version = ">=10.3.2.106", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
-nvidia-cusolver-cu12 = { version = ">=11.4.5.107", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
-nvidia-cusparse-cu12 = { version = ">=12.1.0.106", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
-nvidia-nccl-cu12 = { version = ">=2.18.1", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
-nvidia-nvtx-cu12 = { version = ">=12.1.105", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
-triton = { version = ">=2.1.0", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
-# End PyTorch 2.1.0 Bug Fix
+ [tool.poetry.group]
+ [tool.poetry.group.dev.dependencies]
+ black="^23.3.0"
+ circuitsvis=">=1.38.1"
+ isort="5.8.0"
+ jupyter=">=1.0.0"
+ mypy=">=1.8.0"
+ nbval="^0.10.0"
+ plotly=">=5.12.0"
+ pycln="^2.1.3"
+ pytest=">=7.2.0"
+ pytest-cov=">=4.0.0"
+ pytest-doctestplus="^1.0.0"
-[tool.poetry.group.dev.dependencies]
-pytest = ">=7.2.0"
-pytest-cov = ">=4.0.0"
-mypy = ">=0.991"
-jupyter = ">=1.0.0"
-circuitsvis = ">=1.38.1"
-plotly = ">=5.12.0"
-isort = "5.8.0"
-black = "^23.3.0"
-pycln = "^2.1.3"
-pytest-doctestplus = "^1.0.0"
-nbval = "^0.10.0"
+ [tool.poetry.group.jupyter.dependencies]
+ ipywidgets="^8.1.1"
+ jupyterlab=">=3.5.0"
-[tool.poetry.group.jupyter.dependencies]
-jupyterlab = ">=3.5.0"
-ipywidgets = "^8.1.1"
+ [tool.poetry.group.docs.dependencies]
+ furo={version=">=2022.12.7"}
+ myst-parser={version=">=0.18.1"}
+ nbconvert="^7.9.2"
+ nbsphinx="^0.9.3"
+ pandoc="^2.3"
+ snowballstemmer="*"
+ sphinx={version="5.2.3"}
+ sphinx-autobuild={version=">=2021.3.14"}
+ sphinxcontrib-napoleon={version=">=0.7"}
+ tabulate={version=">=0.9.0"}
-[tool.poetry.group.docs.dependencies]
-sphinx = {version = "5.2.3" }
-sphinxcontrib-napoleon = {version = ">=0.7" }
-sphinx-autobuild = {version = ">=2021.3.14" }
-furo = {version = ">=2022.12.7" }
-myst-parser = {version = ">=0.18.1" }
-tabulate= {version = ">=0.9.0" }
-snowballstemmer = "*"
-nbsphinx = "^0.9.3"
-pandoc = "^2.3"
-nbconvert = "^7.9.2"
-
-[tool.pytest.ini_options]
-doctest_optionflags = "NORMALIZE_WHITESPACE ELLIPSIS FLOAT_CMP"
-filterwarnings = [
- "ignore:pkg_resources is deprecated as an API:DeprecationWarning",
- # Ignore numpy.distutils deprecation warning caused by pandas
- # More info: https://numpy.org/doc/stable/reference/distutils.html#module-numpy.distutils
- "ignore:distutils Version classes are deprecated:DeprecationWarning"
-]
-addopts = """--jaxtyping-packages=transformer_lens,beartype.beartype \
--W ignore::beartype.roar.BeartypeDecorHintPep585DeprecationWarning \
---deselect transformer_lens/utils.py::test_prompt \
---doctest-modules --doctest-plus \
---nbval"""
+[tool.pytest]
+ [tool.pytest.ini_options]
+ addopts=[
+ "--doctest-modules",
+ "--doctest-plus",
+ "--jaxtyping-packages=transformer_lens,beartype.beartype",
+ "--nbval",
+ "-W ignore::beartype.roar.BeartypeDecorHintPep585DeprecationWarning",
+ ]
+ doctest_optionflags="NORMALIZE_WHITESPACE ELLIPSIS FLOAT_CMP"
+ filterwarnings=[
+ "ignore:pkg_resources is deprecated as an API:DeprecationWarning",
+ # Ignore numpy.distutils deprecation warning caused by pandas
+ # More info: https://numpy.org/doc/stable/reference/distutils.html#module-numpy.distutils
+ "ignore:distutils Version classes are deprecated:DeprecationWarning",
+ ]
[tool.isort]
-profile = "black"
-extend_skip = ["__init__.py", ".venv/"]
+ extend_skip=[".venv/", "__init__.py"]
+ profile="black"
[tool.mypy]
-ignore_missing_imports = true
-check_untyped_defs = true
+ check_untyped_defs=true
+ exclude=[".venv/", "assets", "demos", "docs", "easy_transformer", "tests"]
+ ignore_missing_imports=true
[tool.black]
-# Exclude snapshot tests & .venv
-exclude = '''
+ line-length=100 # Set line length to 100 to match other tools
+ # Exclude snapshot tests & .venv
+ exclude='''
(
/snapshots/
| .venv/
@@ -115,21 +106,21 @@ exclude = '''
[tool.pylint]
[tool.pylint.TYPECHECK]
- # Fix for Pytorch member existence checks
- generated-members = "torch.*"
+ # Fix for Pytorch member existence checks
+ generated-members="torch.*"
[tool.pylint.DESIGN]
- max-args = 10
- max-locals = 30
+ max-args=10
+ max-locals=30
[tool.pylint."MESSAGES CONTROL"]
- disable = "redefined-builtin" # Disable redefined builtin functions
+ disable="redefined-builtin" # Disable redefined builtin functions
[tool.pylint.'MASTER']
- disable = [
- "C0103", # Disable invalid file name (as we use PascalCase for classes)
- ]
+ disable=[
+ "C0103", # Disable invalid file name (as we use PascalCase for classes)
+ ]
[build-system]
-requires = ["poetry-core"]
-build-backend = "poetry.core.masonry.api"
+ build-backend="poetry.core.masonry.api"
+ requires=["poetry-core"]
diff --git a/tests/acceptance/test_activation_cache.py b/tests/acceptance/test_activation_cache.py
index 452d5e37c..aa52cf6b7 100644
--- a/tests/acceptance/test_activation_cache.py
+++ b/tests/acceptance/test_activation_cache.py
@@ -64,18 +64,12 @@ def test_logit_attrs_matches_reference_code():
_, cache = model.run_with_cache(tokens)
# Get accumulated resid
- accumulated_residual = cache.accumulated_resid(
- layer=-1, incl_mid=True, pos_slice=-1
- )
+ accumulated_residual = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1)
# Get ref ave logit diffs (cribbed notebook code)
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)
- logit_diff_directions = (
- answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
- )
- scaled_residual_stack = cache.apply_ln_to_stack(
- accumulated_residual, layer=-1, pos_slice=-1
- )
+ logit_diff_directions = answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
+ scaled_residual_stack = cache.apply_ln_to_stack(accumulated_residual, layer=-1, pos_slice=-1)
ref_ave_logit_diffs = einsum(
"... batch d_model, batch d_model -> ...",
scaled_residual_stack,
@@ -111,12 +105,8 @@ def test_logit_attrs_works_for_all_input_shapes():
# Get ref logit diffs (cribbed notebook code)
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)
- logit_diff_directions = (
- answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
- )
- scaled_residual_stack = cache.apply_ln_to_stack(
- accumulated_residual, layer=-1, pos_slice=-1
- )
+ logit_diff_directions = answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
+ scaled_residual_stack = cache.apply_ln_to_stack(accumulated_residual, layer=-1, pos_slice=-1)
ref_logit_diffs = einsum(
"... d_model, ... d_model -> ...", scaled_residual_stack, logit_diff_directions
)
@@ -198,9 +188,7 @@ def test_accumulated_resid_with_apply_ln():
_, cache = model.run_with_cache(tokens)
# Get accumulated resid and apply ln seperately (cribbed notebook code)
- accumulated_residual = cache.accumulated_resid(
- layer=-1, incl_mid=True, pos_slice=-1
- )
+ accumulated_residual = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1)
ref_scaled_residual_stack = cache.apply_ln_to_stack(
accumulated_residual, layer=-1, pos_slice=-1
)
@@ -210,9 +198,7 @@ def test_accumulated_resid_with_apply_ln():
layer=-1, incl_mid=True, pos_slice=-1, apply_ln=True
)
- assert torch.isclose(
- ref_scaled_residual_stack, scaled_residual_stack, atol=1e-7
- ).all()
+ assert torch.isclose(ref_scaled_residual_stack, scaled_residual_stack, atol=1e-7).all()
@torch.set_grad_enabled(False)
@@ -227,16 +213,12 @@ def test_decompose_resid_with_apply_ln():
# Get decomposed resid and apply ln seperately (cribbed notebook code)
per_layer_residual = cache.decompose_resid(layer=-1, pos_slice=-1)
- ref_scaled_residual_stack = cache.apply_ln_to_stack(
- per_layer_residual, layer=-1, pos_slice=-1
- )
+ ref_scaled_residual_stack = cache.apply_ln_to_stack(per_layer_residual, layer=-1, pos_slice=-1)
# Get scaled_residual_stack using apply_ln parameter
scaled_residual_stack = cache.decompose_resid(layer=-1, pos_slice=-1, apply_ln=True)
- assert torch.isclose(
- ref_scaled_residual_stack, scaled_residual_stack, atol=1e-7
- ).all()
+ assert torch.isclose(ref_scaled_residual_stack, scaled_residual_stack, atol=1e-7).all()
@torch.set_grad_enabled(False)
@@ -251,18 +233,12 @@ def test_stack_head_results_with_apply_ln():
# Get per head resid stack and apply ln seperately (cribbed notebook code)
per_head_residual = cache.stack_head_results(layer=-1, pos_slice=-1)
- ref_scaled_residual_stack = cache.apply_ln_to_stack(
- per_head_residual, layer=-1, pos_slice=-1
- )
+ ref_scaled_residual_stack = cache.apply_ln_to_stack(per_head_residual, layer=-1, pos_slice=-1)
# Get scaled_residual_stack using apply_ln parameter
- scaled_residual_stack = cache.stack_head_results(
- layer=-1, pos_slice=-1, apply_ln=True
- )
+ scaled_residual_stack = cache.stack_head_results(layer=-1, pos_slice=-1, apply_ln=True)
- assert torch.isclose(
- ref_scaled_residual_stack, scaled_residual_stack, atol=1e-7
- ).all()
+ assert torch.isclose(ref_scaled_residual_stack, scaled_residual_stack, atol=1e-7).all()
@torch.set_grad_enabled(False)
@@ -277,15 +253,9 @@ def test_stack_neuron_results_with_apply_ln():
# Get neuron result stack and apply ln seperately
neuron_result_stack = cache.stack_neuron_results(layer=-1, pos_slice=-1)
- ref_scaled_residual_stack = cache.apply_ln_to_stack(
- neuron_result_stack, layer=-1, pos_slice=-1
- )
+ ref_scaled_residual_stack = cache.apply_ln_to_stack(neuron_result_stack, layer=-1, pos_slice=-1)
# Get scaled_residual_stack using apply_ln parameter
- scaled_residual_stack = cache.stack_neuron_results(
- layer=-1, pos_slice=-1, apply_ln=True
- )
+ scaled_residual_stack = cache.stack_neuron_results(layer=-1, pos_slice=-1, apply_ln=True)
- assert torch.isclose(
- ref_scaled_residual_stack, scaled_residual_stack, atol=1e-7
- ).all()
+ assert torch.isclose(ref_scaled_residual_stack, scaled_residual_stack, atol=1e-7).all()
diff --git a/tests/acceptance/test_hook_tokens.py b/tests/acceptance/test_hook_tokens.py
index 2280daef1..73690ae1a 100644
--- a/tests/acceptance/test_hook_tokens.py
+++ b/tests/acceptance/test_hook_tokens.py
@@ -30,9 +30,7 @@ def test_patch_tokens():
new_first_token = model.to_single_token("Hi")
# Define hook function to alter the first token
- def hook_fn(
- tokens: Int[t.Tensor, "batch seq"], hook: HookPoint, new_first_token: int
- ):
+ def hook_fn(tokens: Int[t.Tensor, "batch seq"], hook: HookPoint, new_first_token: int):
assert (
tokens[0, 0].item() != new_first_token
) # Need new_first_token to be different from original
@@ -43,9 +41,7 @@ def hook_fn(
out_from_hook = model.run_with_hooks(
prompt,
prepend_bos=False,
- fwd_hooks=[
- ("hook_tokens", functools.partial(hook_fn, new_first_token=new_first_token))
- ],
+ fwd_hooks=[("hook_tokens", functools.partial(hook_fn, new_first_token=new_first_token))],
)
out_direct = model(modified_prompt, prepend_bos=False)
diff --git a/tests/acceptance/test_hooked_encoder.py b/tests/acceptance/test_hooked_encoder.py
index 0859686ef..e8394d873 100644
--- a/tests/acceptance/test_hooked_encoder.py
+++ b/tests/acceptance/test_hooked_encoder.py
@@ -41,9 +41,7 @@ def test_full_model(our_bert, huggingface_bert, tokenizer):
input_ids = tokenized["input_ids"]
attention_mask = tokenized["attention_mask"]
- huggingface_bert_out = huggingface_bert(
- input_ids, attention_mask=attention_mask
- ).logits
+ huggingface_bert_out = huggingface_bert(input_ids, attention_mask=attention_mask).logits
our_bert_out = our_bert(input_ids, one_zero_attention_mask=attention_mask)
assert_close(huggingface_bert_out, our_bert_out, rtol=1.3e-6, atol=4e-5)
@@ -97,23 +95,17 @@ def test_bert_block(our_bert, huggingface_bert, hello_world_tokens):
def test_mlm_head(our_bert, huggingface_bert, hello_world_tokens):
- huggingface_bert_core_outputs = huggingface_bert.bert(
- hello_world_tokens
- ).last_hidden_state
+ huggingface_bert_core_outputs = huggingface_bert.bert(hello_world_tokens).last_hidden_state
our_mlm_head_out = our_bert.mlm_head(huggingface_bert_core_outputs)
our_unembed_out = our_bert.unembed(our_mlm_head_out)
- huggingface_predictions_out = huggingface_bert.cls.predictions(
- huggingface_bert_core_outputs
- )
+ huggingface_predictions_out = huggingface_bert.cls.predictions(huggingface_bert_core_outputs)
assert_close(our_unembed_out, huggingface_predictions_out, rtol=1.3e-6, atol=4e-5)
def test_unembed(our_bert, huggingface_bert, hello_world_tokens):
- huggingface_bert_core_outputs = huggingface_bert.bert(
- hello_world_tokens
- ).last_hidden_state
+ huggingface_bert_core_outputs = huggingface_bert.bert(hello_world_tokens).last_hidden_state
our_mlm_head_out = our_bert.mlm_head(huggingface_bert_core_outputs)
huggingface_predictions_out = huggingface_bert.cls.predictions.transform(
@@ -167,9 +159,7 @@ def test_half_precision(dtype):
def test_predictions(our_bert, huggingface_bert, tokenizer):
input_ids = tokenizer("The [MASK] sat on the mat", return_tensors="pt")["input_ids"]
- def get_predictions(
- logits: Float[torch.Tensor, "batch pos d_vocab"], positions: List[int]
- ):
+ def get_predictions(logits: Float[torch.Tensor, "batch pos d_vocab"], positions: List[int]):
logits_at_position = logits.squeeze(0)[positions]
predicted_tokens = F.softmax(logits_at_position, dim=-1).argmax(dim=-1)
return tokenizer.batch_decode(predicted_tokens)
diff --git a/tests/acceptance/test_hooked_transformer.py b/tests/acceptance/test_hooked_transformer.py
index e60639e57..9d9e2bb19 100644
--- a/tests/acceptance/test_hooked_transformer.py
+++ b/tests/acceptance/test_hooked_transformer.py
@@ -1,22 +1,25 @@
import gc
import os
+import pandas as pd
import pytest
import torch
-from transformers import AutoConfig, AutoModelForCausalLM
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer
from transformer_lens.components import LayerNormPre
-from transformer_lens.loading_from_pretrained import OFFICIAL_MODEL_NAMES
+from transformer_lens.HookedTransformer import DTYPE_FROM_STRING
+from transformer_lens.loading_from_pretrained import (
+ OFFICIAL_MODEL_NAMES,
+ get_official_model_name,
+)
from transformer_lens.utils import clear_huggingface_cache
TINY_STORIES_MODEL_NAMES = [
name for name in OFFICIAL_MODEL_NAMES if name.startswith("roneneldan/TinyStories")
]
-PYTHIA_MODEL_NAMES = [
- name for name in OFFICIAL_MODEL_NAMES if name.startswith("EleutherAI/pythia")
-]
+PYTHIA_MODEL_NAMES = [name for name in OFFICIAL_MODEL_NAMES if name.startswith("EleutherAI/pythia")]
model_names = [
"attn-only-demo",
@@ -33,6 +36,11 @@
"tiny-stories-33M",
"bloom-560m",
"santacoder",
+ "microsoft/phi-1",
+ "microsoft/phi-1_5",
+ "microsoft/phi-2",
+ "google/gemma-2b",
+ "google/gemma-7b",
]
text = "Hello world!"
"""
@@ -167,6 +175,223 @@ def test_from_pretrained_revision():
raise AssertionError("Should have raised an error")
+def check_norm_folding(
+ model_name,
+ hf_model=None,
+ tokenizer=None,
+ prompt="Hello, world!",
+ device=None,
+ dtype=None,
+):
+ """
+ Checks that loading a model with Layer/RMS Norm folding enabled does not (significantly) change its outputs.
+
+ Returns the maximum difference between the logits produced by the same model with and without norm folding enabled.
+
+ Also asserts that this difference is within some tolerance, although this is deliberately set to a high value
+ in order to account for lower precision models.
+ """
+
+ # If a device/dtype is not specified, and hf_model is provided, use its device/dtype
+ # Otherwise, default to cuda (if available)/float32
+ if device is None:
+ if hf_model:
+ device = hf_model.device
+ else:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ if dtype is None:
+ if hf_model:
+ dtype = hf_model.dtype
+ else:
+ dtype = "float32"
+
+ folded_model = HookedTransformer.from_pretrained(
+ model_name=model_name,
+ hf_model=hf_model,
+ device=device,
+ tokenizer=tokenizer,
+ dtype=dtype,
+ fold_ln=True,
+ center_writing_weights=False,
+ center_unembed=False,
+ )
+ tokens = folded_model.to_tokens(prompt)
+ folded_logits = folded_model(tokens).detach()
+ del folded_model
+ torch.cuda.empty_cache()
+
+ unfolded_model = HookedTransformer.from_pretrained(
+ model_name=model_name,
+ hf_model=hf_model,
+ device=device,
+ tokenizer=tokenizer,
+ dtype=dtype,
+ fold_ln=False,
+ center_writing_weights=False,
+ center_unembed=False,
+ )
+ unfolded_logits = unfolded_model(tokens).detach()
+ del unfolded_model
+ torch.cuda.empty_cache()
+
+ assert torch.allclose(
+ torch.softmax(folded_logits, dim=-1),
+ torch.softmax(unfolded_logits, dim=-1),
+ atol=1e-2,
+ )
+
+ return torch.max(
+ torch.abs(torch.softmax(folded_logits, dim=-1) - torch.softmax(unfolded_logits, dim=-1))
+ )
+
+
+def calculate_error(logits1, logits2):
+ t1 = torch.softmax(logits1, dim=-1).to("cpu")
+ t2 = torch.softmax(logits2, dim=-1).to("cpu")
+ err = torch.abs(t1 - t2)
+ return {
+ "max": torch.max(err).item(),
+ "mean": torch.mean(err).item(),
+ "median": torch.median(err).item(),
+ "std": torch.std(err).item(),
+ }
+
+
+def benchmark_model_options(
+ model_name: str,
+ hf_model=None,
+ tokenizer=None,
+ device="cuda",
+ n_devices=1,
+ dtype=torch.float16,
+ cache_in_cpu=True,
+):
+ options = {
+ "fold_ln": False,
+ "center_writing_weights": False,
+ "center_unembed": False,
+ "fold_value_biases": False,
+ }
+
+ prompts = [
+ "Hello, world!",
+ "This is a test.",
+ "What is it about?",
+ "I don't know.",
+ ]
+
+ model_name = get_official_model_name(model_name)
+
+ if hf_model is None:
+ hf_model = AutoModelForCausalLM.from_pretrained(
+ model_name, torch_dtype=dtype, device_map="auto"
+ )
+ if tokenizer is None:
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ tokens = tokenizer(prompts, return_tensors="pt", truncation=True, max_length=4).input_ids.to(
+ device
+ )
+
+ # hf_model = hf_model.to(device)
+ hf_logits = hf_model(tokens).logits.detach()
+ hf_logits = hf_logits.to("cpu")
+
+ if cache_in_cpu:
+ hf_model = hf_model.to("cpu")
+ else:
+ del hf_model
+ hf_model = None
+
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ results = {}
+
+ # Check the error when all processing options are disabled
+ tl_model = HookedTransformer.from_pretrained(
+ model_name,
+ hf_model=hf_model,
+ tokenizer=tokenizer,
+ device=device,
+ n_devices=n_devices,
+ dtype=dtype,
+ **options,
+ )
+ tl_logits = tl_model(tokens).detach().to("cpu")
+ results["no_options"] = calculate_error(hf_logits, tl_logits)
+ del tl_model, tl_logits
+ torch.cuda.empty_cache()
+
+ # Check the error when each processing option is enabled individually
+ for option in options:
+ gc.collect()
+ new_options = options.copy()
+ new_options[option] = True
+ tl_model = HookedTransformer.from_pretrained(
+ model_name,
+ hf_model=hf_model,
+ tokenizer=tokenizer,
+ device=device,
+ n_devices=n_devices,
+ dtype=dtype,
+ **new_options,
+ )
+ tl_logits = tl_model(tokens).detach().to("cpu")
+ results[option] = calculate_error(hf_logits, tl_logits)
+
+ del tl_model, tl_logits
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ # Check the error when all processing options are enabled
+ all_options = {k: True for k, v in options.items()}
+ tl_model = HookedTransformer.from_pretrained(
+ model_name,
+ hf_model=hf_model,
+ tokenizer=tokenizer,
+ device=device,
+ n_devices=n_devices,
+ dtype=dtype,
+ **all_options,
+ )
+ tl_logits = tl_model(tokens).detach().to("cpu")
+ results["all_options"] = calculate_error(hf_logits, tl_logits)
+
+ del tl_model, tl_logits
+
+ del hf_model
+ del tokens
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ return results
+
+
+def benchmark_models(models, device="cuda", n_devices=1, cache_in_cpu=True):
+ """
+ Benchmark the error introduced by different options and data types for a list of models.
+ :param models: A dict mapping model names to a list of dtypes to test
+ """
+ rows = []
+
+ for model in models:
+ dtypes = models[model]
+ for dtype in dtypes:
+ print(f"Testing {model} with dtype {dtype}")
+ results = benchmark_model_options(
+ model,
+ device=device,
+ n_devices=n_devices,
+ dtype=DTYPE_FROM_STRING[dtype],
+ cache_in_cpu=cache_in_cpu,
+ )
+ for option, result in results.items():
+ rows.append({"model": model, "dtype": dtype, "options": option, **result})
+
+ return pd.DataFrame(rows)
+
+
def check_similarity_with_hf_model(tl_model, hf_model, prompt="Hello, world!"):
"""
Check that the TransformerLens model and the HuggingFace model
@@ -209,9 +434,7 @@ def check_dtype(dtype, margin, no_processing=False):
for model_path in ["gpt2", "roneneldan/TinyStories-33M", "EleutherAI/pythia-70m"]:
if no_processing:
# For low precision, the processing is not advised.
- model = HookedTransformer.from_pretrained_no_processing(
- model_path, torch_dtype=dtype
- )
+ model = HookedTransformer.from_pretrained_no_processing(model_path, torch_dtype=dtype)
else:
model = HookedTransformer.from_pretrained(model_path, torch_dtype=dtype)
@@ -233,8 +456,12 @@ def check_dtype(dtype, margin, no_processing=False):
gc.collect()
+@pytest.mark.skipif(
+ torch.backends.mps.is_available() or not torch.cuda.is_available(),
+ reason="some operations unsupported by MPS: https://github.com/pytorch/pytorch/issues/77754 or no GPU",
+)
@pytest.mark.parametrize("dtype", [torch.float64, torch.float32])
-def test_dtypes(dtype):
+def test_dtype_float(dtype):
check_dtype(dtype, margin=5e-4)
@@ -266,9 +493,7 @@ def remove_pos_embed(z, hook):
z[:] = 0.0
return z
- _ = model.run_with_hooks(
- "Hello, world", fwd_hooks=[("hook_pos_embed", remove_pos_embed)]
- )
+ _ = model.run_with_hooks("Hello, world", fwd_hooks=[("hook_pos_embed", remove_pos_embed)])
# Check that pos embed has not been permanently changed
assert (model.W_pos == initial_W_pos).all()
diff --git a/tests/acceptance/test_multi_gpu.py b/tests/acceptance/test_multi_gpu.py
index 260344fcf..f5f082c33 100644
--- a/tests/acceptance/test_multi_gpu.py
+++ b/tests/acceptance/test_multi_gpu.py
@@ -19,9 +19,7 @@ def gpt2_medium_on_4_devices():
return model
-@pytest.mark.skipif(
- torch.cuda.device_count() < 4, reason="Requires at least 4 CUDA devices"
-)
+@pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Requires at least 4 CUDA devices")
def test_get_device_for_block_index(gpt2_medium_on_4_devices):
config = gpt2_medium_on_4_devices.cfg
n_layers = config.n_layers
@@ -44,19 +42,14 @@ def test_get_device_for_block_index(gpt2_medium_on_4_devices):
device_override_obj = torch.device("cuda")
for i in range(n_layers):
expected_device = torch.device(device_override_obj.type, i // layers_per_device)
- assert (
- get_device_for_block_index(i, config, device_override_obj)
- == expected_device
- )
+ assert get_device_for_block_index(i, config, device_override_obj) == expected_device
# Test when index is out of bounds
# with pytest.raises(IndexError):
# get_device_for_block_index(n_layers, config)
-@pytest.mark.skipif(
- torch.cuda.device_count() < 4, reason="Requires at least 4 CUDA devices"
-)
+@pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Requires at least 4 CUDA devices")
@pytest.mark.parametrize("n_devices", [1, 2, 3, 4])
def test_device_separation_and_cache(gpt2_medium_on_1_device, n_devices):
model_1_device = gpt2_medium_on_1_device
@@ -96,9 +89,7 @@ def test_device_separation_and_cache(gpt2_medium_on_1_device, n_devices):
cache_device = gpt2_cache_n_devices[f"blocks.{i}.mlp.hook_post"].device
assert cache_device == expected_device
- assert torch.allclose(
- gpt2_logits_1_device.to("cpu"), gpt2_logits_n_devices.to("cpu")
- )
+ assert torch.allclose(gpt2_logits_1_device.to("cpu"), gpt2_logits_n_devices.to("cpu"))
for key in gpt2_cache_1_device.keys():
assert torch.allclose(
gpt2_cache_1_device[key].to("cpu"), gpt2_cache_n_devices[key].to("cpu")
@@ -123,9 +114,7 @@ def test_device_separation_and_cache(gpt2_medium_on_1_device, n_devices):
)
-@pytest.mark.skipif(
- torch.cuda.device_count() < 2, reason="Requires at least 2 CUDA devices"
-)
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 CUDA devices")
def test_cache_device():
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda:1")
@@ -135,15 +124,11 @@ def test_cache_device():
)
logits, cache = model.run_with_cache("Hello there", device="cpu")
- assert norm_device(cache["blocks.0.mlp.hook_post"].device) == norm_device(
- torch.device("cpu")
- )
+ assert norm_device(cache["blocks.0.mlp.hook_post"].device) == norm_device(torch.device("cpu"))
model.to("cuda")
logits, cache = model.run_with_cache("Hello there")
- assert norm_device(cache["blocks.0.mlp.hook_post"].device) == norm_device(
- logits.device
- )
+ assert norm_device(cache["blocks.0.mlp.hook_post"].device) == norm_device(logits.device)
def norm_device(device):
diff --git a/tests/acceptance/test_tokenizer_special_tokens.py b/tests/acceptance/test_tokenizer_special_tokens.py
index 5453b9085..6e4a93a98 100644
--- a/tests/acceptance/test_tokenizer_special_tokens.py
+++ b/tests/acceptance/test_tokenizer_special_tokens.py
@@ -27,9 +27,7 @@ def test_d_vocab_from_tokenizer():
else:
tokenizer_name = loading.get_official_model_name(model_name)
- model = HookedTransformer(
- cfg=cfg, tokenizer=AutoTokenizer.from_pretrained(tokenizer_name)
- )
+ model = HookedTransformer(cfg=cfg, tokenizer=AutoTokenizer.from_pretrained(tokenizer_name))
tokens_with_bos = model.to_tokens(test_string)
tokens_without_bos = model.to_tokens(test_string, prepend_bos=False)
diff --git a/tests/manual_checks/manual_checks_type_annotations.py b/tests/manual_checks/manual_checks_type_annotations.py
index e89c1dcc4..a532a05ac 100644
--- a/tests/manual_checks/manual_checks_type_annotations.py
+++ b/tests/manual_checks/manual_checks_type_annotations.py
@@ -9,9 +9,7 @@
prompt = "Hello World!"
tokens = model.to_tokens(prompt, prepend_bos=False)
logits_tokens = model(tokens)
-logits_text: Float[torch.Tensor, "1 n_tokens d_vocab"] = model(
- prompt, prepend_bos=False
-)
+logits_text: Float[torch.Tensor, "1 n_tokens d_vocab"] = model(prompt, prepend_bos=False)
# n.b. that i used this file to see if my type annotations were working- they were! i occasionally
# changed one of the sizes and saw that the type checker caught it.
diff --git a/tests/unit/factored_matrix/test_multiply_by_matrix.py b/tests/unit/factored_matrix/test_multiply_by_matrix.py
index 85caff1a7..91e2dca4e 100644
--- a/tests/unit/factored_matrix/test_multiply_by_matrix.py
+++ b/tests/unit/factored_matrix/test_multiply_by_matrix.py
@@ -45,9 +45,7 @@ def test_left_multiply_when_both_have_leading_dim(self, a, b, matrix):
b_with_leading = repeat(b, "x y -> b x y", b=2)
matrix_with_leading = repeat(matrix, "x y -> b x y", b=2)
- product = self._test_multiply(
- a_with_leading, b_with_leading, matrix_with_leading
- )
+ product = self._test_multiply(a_with_leading, b_with_leading, matrix_with_leading)
assert product.A.shape[:-2] == (2,)
assert product.B.shape[:-2] == (2,)
diff --git a/tests/unit/factored_matrix/test_properties.py b/tests/unit/factored_matrix/test_properties.py
index d724ea854..c18b760cf 100644
--- a/tests/unit/factored_matrix/test_properties.py
+++ b/tests/unit/factored_matrix/test_properties.py
@@ -66,9 +66,7 @@ def test_transpose_property(self, factored_matrices):
def test_svd_property(self, factored_matrices):
for factored_matrix in factored_matrices:
U, S, Vh = factored_matrix.svd()
- assert torch.allclose(
- factored_matrix.AB, U @ torch.diag_embed(S) @ Vh.T, atol=1e-5
- )
+ assert torch.allclose(factored_matrix.AB, U @ torch.diag_embed(S) @ Vh.T, atol=1e-5)
# test that U and Vh are unitary
assert torch.allclose(U.T @ U, torch.eye(U.shape[-1]), atol=1e-5)
assert torch.allclose(Vh.T @ Vh, torch.eye(Vh.shape[-1]), atol=1e-5)
@@ -76,9 +74,7 @@ def test_svd_property(self, factored_matrices):
def test_svd_property_leading_ones(self, factored_matrices_leading_ones):
for factored_matrix in factored_matrices_leading_ones:
U, S, Vh = factored_matrix.svd()
- assert torch.allclose(
- factored_matrix.AB, U @ torch.diag_embed(S) @ Vh.mT, atol=1e-5
- )
+ assert torch.allclose(factored_matrix.AB, U @ torch.diag_embed(S) @ Vh.mT, atol=1e-5)
# test that U and Vh are unitary
assert torch.allclose(U.mT @ U, torch.eye(U.shape[-1]), atol=1e-5)
assert torch.allclose(Vh.mT @ Vh, torch.eye(Vh.shape[-1]), atol=1e-5)
@@ -123,9 +119,7 @@ def test_pair_property(self, factored_matrices, random_matrices):
def test_norm_property(self, factored_matrices):
for factored_matrix in factored_matrices:
- assert torch.allclose(
- factored_matrix.norm(), factored_matrix.AB.norm(), atol=1e-5
- )
+ assert torch.allclose(factored_matrix.norm(), factored_matrix.AB.norm(), atol=1e-5)
def test_get_corner(self, factored_matrices):
for factored_matrix in factored_matrices:
@@ -143,9 +137,7 @@ def test_ndim(self, factored_matrices):
def test_collapse_l(self, factored_matrices):
for factored_matrix in factored_matrices:
result = factored_matrix.collapse_l()
- expected = factored_matrix.S[..., :, None] * utils.transpose(
- factored_matrix.Vh
- )
+ expected = factored_matrix.S[..., :, None] * utils.transpose(factored_matrix.Vh)
assert torch.allclose(result, expected)
def test_collapse_r(self, factored_matrices):
diff --git a/tests/unit/test_attention_mask.py b/tests/unit/test_attention_mask.py
index df2c147ce..6b0951f5c 100644
--- a/tests/unit/test_attention_mask.py
+++ b/tests/unit/test_attention_mask.py
@@ -35,9 +35,7 @@ def attn_scores_hook(attn_scores, hook):
return attn_scores
def attn_hook(attn, hook):
- assert torch.all(
- attn[:, :, masked] == 0
- ), "Attention pattern attends outside the mask"
+ assert torch.all(attn[:, :, masked] == 0), "Attention pattern attends outside the mask"
return attn
diff --git a/tests/unit/test_cache_pos_slice.py b/tests/unit/test_cache_pos_slice.py
new file mode 100644
index 000000000..1b26c981d
--- /dev/null
+++ b/tests/unit/test_cache_pos_slice.py
@@ -0,0 +1,257 @@
+# %%
+
+import torch
+
+from transformer_lens import HookedTransformer
+
+MODEL = "tiny-stories-1M"
+
+prompt = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."
+model = HookedTransformer.from_pretrained(MODEL)
+# %%
+d_model = model.cfg.d_model
+d_head = model.cfg.d_head
+n_heads = model.cfg.n_heads
+n_layers = model.cfg.n_layers
+# %%
+
+
+def test_run_with_cache_pos_slice_keep_batch():
+ _, cache_no_slice = model.run_with_cache(prompt, return_type=None)
+ num_tokens = len(model.tokenizer.encode(prompt))
+
+ for i in range(-1, num_tokens + 1):
+ _, cache_with_slice = model.run_with_cache(prompt, return_type=None, pos_slice=i)
+
+ assert cache_with_slice["embed"].shape == torch.Size([1, 1, d_model])
+ assert cache_with_slice["q", 0].shape == torch.Size([1, 1, n_heads, d_head])
+
+ assert torch.equal(cache_no_slice["embed"][0, i, :], cache_with_slice["embed"][0, 0, :])
+ assert torch.equal(
+ cache_no_slice["pos_embed"][0, i, :], cache_with_slice["pos_embed"][0, 0, :]
+ )
+
+ for layer in range(n_layers):
+ assert torch.equal(
+ cache_no_slice["resid_pre", layer][0, i, :],
+ cache_with_slice["resid_pre", layer][0, 0, :],
+ )
+ assert torch.equal(
+ cache_no_slice["resid_post", layer][0, i, :],
+ cache_with_slice["resid_post", layer][0, 0, :],
+ )
+ assert torch.equal(
+ cache_no_slice["resid_mid", layer][0, i, :],
+ cache_with_slice["resid_mid", layer][0, 0, :],
+ )
+ assert torch.equal(
+ cache_no_slice["scale", layer, "ln1"][0, i, :],
+ cache_with_slice["scale", layer, "ln1"][0, 0, :],
+ )
+ assert torch.equal(
+ cache_no_slice["scale", layer, "ln2"][0, i, :],
+ cache_with_slice["scale", layer, "ln2"][0, 0, :],
+ )
+ assert torch.equal(
+ cache_no_slice["normalized", layer, "ln1"][0, i, :],
+ cache_with_slice["normalized", layer, "ln1"][0, 0, :],
+ )
+ assert torch.equal(
+ cache_no_slice["normalized", layer, "ln2"][0, i, :],
+ cache_with_slice["normalized", layer, "ln2"][0, 0, :],
+ )
+ assert torch.equal(
+ cache_no_slice[
+ "q",
+ layer,
+ ][0, i, :, :],
+ cache_with_slice[
+ "q",
+ layer,
+ ][0, 0, :, :],
+ )
+ assert torch.equal(
+ cache_no_slice[
+ "k",
+ layer,
+ ][0, i, :, :],
+ cache_with_slice[
+ "k",
+ layer,
+ ][0, 0, :, :],
+ )
+ assert torch.equal(
+ cache_no_slice[
+ "v",
+ layer,
+ ][0, i, :, :],
+ cache_with_slice[
+ "v",
+ layer,
+ ][0, 0, :, :],
+ )
+ assert torch.equal(
+ cache_no_slice[
+ "z",
+ layer,
+ ][0, i, :, :],
+ cache_with_slice[
+ "z",
+ layer,
+ ][0, 0, :, :],
+ )
+ assert torch.equal(
+ cache_no_slice[
+ "attn_scores",
+ layer,
+ ][0, :, i, :],
+ cache_with_slice[
+ "attn_scores",
+ layer,
+ ][0, :, 0, :],
+ )
+ assert torch.equal(
+ cache_no_slice[
+ "pattern",
+ layer,
+ ][0, :, i, :],
+ cache_with_slice[
+ "pattern",
+ layer,
+ ][0, :, 0, :],
+ )
+ assert torch.equal(
+ cache_no_slice["attn_out", layer][0, i, :],
+ cache_with_slice["attn_out", layer][0, 0, :],
+ )
+ assert torch.equal(
+ cache_no_slice["pre", layer][0, i, :],
+ cache_with_slice["pre", layer][0, 0, :],
+ )
+ assert torch.equal(
+ cache_no_slice["post", layer][0, i, :],
+ cache_with_slice["post", layer][0, 0, :],
+ )
+ assert torch.equal(
+ cache_no_slice["mlp_out", layer][0, i, :],
+ cache_with_slice["mlp_out", layer][0, 0, :],
+ )
+
+
+def test_run_with_cache_pos_slice_remove_batch():
+ _, cache_no_slice = model.run_with_cache(prompt, remove_batch_dim=True, return_type=None)
+ num_tokens = len(model.tokenizer.encode(prompt))
+
+ for i in range(-1, num_tokens + 1):
+ _, cache_with_slice = model.run_with_cache(prompt, remove_batch_dim=True, pos_slice=i)
+
+ assert cache_with_slice["embed"].shape == torch.Size([1, d_model])
+ assert cache_with_slice["q", 0].shape == torch.Size([1, n_heads, d_head])
+
+ assert torch.equal(cache_no_slice["embed"][i, :], cache_with_slice["embed"][0, :])
+ assert torch.equal(cache_no_slice["pos_embed"][i, :], cache_with_slice["pos_embed"][0, :])
+
+ for layer in range(n_layers):
+ assert torch.equal(
+ cache_no_slice["resid_pre", layer][i, :],
+ cache_with_slice["resid_pre", layer][0, :],
+ )
+ assert torch.equal(
+ cache_no_slice["resid_post", layer][i, :],
+ cache_with_slice["resid_post", layer][0, :],
+ )
+ assert torch.equal(
+ cache_no_slice["resid_mid", layer][i, :],
+ cache_with_slice["resid_mid", layer][0, :],
+ )
+ assert torch.equal(
+ cache_no_slice["scale", layer, "ln1"][i, :],
+ cache_with_slice["scale", layer, "ln1"][0, :],
+ )
+ assert torch.equal(
+ cache_no_slice["scale", layer, "ln2"][i, :],
+ cache_with_slice["scale", layer, "ln2"][0, :],
+ )
+ assert torch.equal(
+ cache_no_slice["normalized", layer, "ln1"][i, :],
+ cache_with_slice["normalized", layer, "ln1"][0, :],
+ )
+ assert torch.equal(
+ cache_no_slice["normalized", layer, "ln2"][i, :],
+ cache_with_slice["normalized", layer, "ln2"][0, :],
+ )
+ assert torch.equal(
+ cache_no_slice[
+ "q",
+ layer,
+ ][i, :, :],
+ cache_with_slice[
+ "q",
+ layer,
+ ][0, :, :],
+ )
+ assert torch.equal(
+ cache_no_slice[
+ "k",
+ layer,
+ ][i, :, :],
+ cache_with_slice[
+ "k",
+ layer,
+ ][0, :, :],
+ )
+ assert torch.equal(
+ cache_no_slice[
+ "v",
+ layer,
+ ][i, :, :],
+ cache_with_slice[
+ "v",
+ layer,
+ ][0, :, :],
+ )
+ assert torch.equal(
+ cache_no_slice[
+ "z",
+ layer,
+ ][i, :, :],
+ cache_with_slice[
+ "z",
+ layer,
+ ][0, :, :],
+ )
+ assert torch.equal(
+ cache_no_slice[
+ "attn_scores",
+ layer,
+ ][:, i, :],
+ cache_with_slice[
+ "attn_scores",
+ layer,
+ ][:, 0, :],
+ )
+ assert torch.equal(
+ cache_no_slice[
+ "pattern",
+ layer,
+ ][:, i, :],
+ cache_with_slice[
+ "pattern",
+ layer,
+ ][:, 0, :],
+ )
+ assert torch.equal(
+ cache_no_slice["attn_out", layer][i, :],
+ cache_with_slice["attn_out", layer][0, :],
+ )
+ assert torch.equal(
+ cache_no_slice["pre", layer][i, :], cache_with_slice["pre", layer][0, :]
+ )
+ assert torch.equal(
+ cache_no_slice["post", layer][i, :],
+ cache_with_slice["post", layer][0, :],
+ )
+ assert torch.equal(
+ cache_no_slice["mlp_out", layer][i, :],
+ cache_with_slice["mlp_out", layer][0, :],
+ )
diff --git a/tests/unit/test_create_hooked_encoder.py b/tests/unit/test_create_hooked_encoder.py
index a1adc7ef3..45e549aaa 100644
--- a/tests/unit/test_create_hooked_encoder.py
+++ b/tests/unit/test_create_hooked_encoder.py
@@ -6,9 +6,7 @@
@pytest.fixture
def cfg():
- return HookedTransformerConfig(
- d_head=4, d_model=12, n_ctx=5, n_layers=3, act_fn="gelu"
- )
+ return HookedTransformerConfig(d_head=4, d_model=12, n_ctx=5, n_layers=3, act_fn="gelu")
def test_pass_tokenizer(cfg):
diff --git a/tests/unit/test_grouped_query_attention.py b/tests/unit/test_grouped_query_attention.py
new file mode 100644
index 000000000..e5e603454
--- /dev/null
+++ b/tests/unit/test_grouped_query_attention.py
@@ -0,0 +1,94 @@
+import einops
+import torch
+
+from transformer_lens.components import Attention, GroupedQueryAttention
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+
+
+def test_grouped_query_attention_output_is_correct():
+    """Verifies that the grouped query attention (GQA) block behaves correctly - see https://arxiv.org/abs/2305.13245v2 for details on GQA.
+    A GQA block with h query heads, n key-value heads, key parameters _K and value parameters _V should have the same output as a regular attention block
+    with h heads, whose parameters K and V are _K and _V repeated h/n times respectively. This test uses torch.repeat_interleave, which is also used by
+    the GQA block internally, to generate K and V from _K and _V"""
+ d_model = 512
+ d_head = 32
+ n_heads = 16
+ n_ctx = 128
+ n_key_value_heads = 4
+ n_layers = 1
+
+ cfg = HookedTransformerConfig(
+ d_model=d_model,
+ d_head=d_head,
+ n_heads=n_heads,
+ n_ctx=n_ctx,
+ n_key_value_heads=n_key_value_heads,
+ n_layers=n_layers,
+ act_fn="silu",
+ )
+
+ regular_attention = Attention(cfg)
+ grouped_query_attention = GroupedQueryAttention(cfg)
+
+ W_Q = torch.rand((n_heads, d_model, d_head))
+ b_Q = torch.rand((n_heads, d_head))
+ _W_K = torch.rand((n_key_value_heads, d_model, d_head))
+ W_K = torch.repeat_interleave(_W_K, dim=0, repeats=n_heads // n_key_value_heads)
+ _b_K = torch.rand((n_key_value_heads, d_head))
+ b_K = torch.repeat_interleave(_b_K, dim=0, repeats=n_heads // n_key_value_heads)
+ _W_V = torch.rand((n_key_value_heads, d_model, d_head))
+ W_V = torch.repeat_interleave(_W_V, dim=0, repeats=n_heads // n_key_value_heads)
+ _b_V = torch.rand((n_key_value_heads, d_head))
+ b_V = torch.repeat_interleave(_b_V, dim=0, repeats=n_heads // n_key_value_heads)
+ W_O = torch.rand((n_heads, d_head, d_model))
+ b_O = torch.rand(d_model)
+
+ regular_attention_state_dict = {
+ "W_Q": W_Q,
+ "b_Q": b_Q,
+ "W_O": W_O,
+ "b_O": b_O,
+ "W_K": W_K,
+ "b_K": b_K,
+ "W_V": W_V,
+ "b_V": b_V,
+ "mask": regular_attention.state_dict()["mask"],
+ "IGNORE": regular_attention.state_dict()["IGNORE"],
+ }
+ grouped_query_attemtion_state_dict = {
+ "W_Q": W_Q,
+ "b_Q": b_Q,
+ "W_O": W_O,
+ "b_O": b_O,
+ "_W_K": _W_K,
+ "_b_K": _b_K,
+ "_W_V": _W_V,
+ "_b_V": _b_V,
+ "mask": grouped_query_attention.state_dict()["mask"],
+ "IGNORE": grouped_query_attention.state_dict()["IGNORE"],
+ }
+
+ regular_attention.load_state_dict(regular_attention_state_dict)
+ grouped_query_attention.load_state_dict(grouped_query_attemtion_state_dict)
+
+ query_input = torch.rand((1, 5, d_model))
+ key_input = torch.rand((1, 5, d_model))
+ value_input = torch.rand((1, 5, d_model))
+
+ regular_attn_output = regular_attention(query_input, key_input, value_input)
+ grouped_query_attn_output = grouped_query_attention(query_input, key_input, value_input)
+
+ assert torch.equal(regular_attn_output, grouped_query_attn_output)
+
+ # Test GQA behaves correctly when use_split_qkv_input is True
+ grouped_query_attention.cfg.use_split_qkv_input = True
+
+ split_query_input = einops.repeat(query_input, "b n d -> b n h d", h=n_heads).clone()
+ split_key_input = einops.repeat(key_input, "b n d -> b n h d", h=n_key_value_heads).clone()
+ split_value_input = einops.repeat(value_input, "b n d -> b n h d", h=n_key_value_heads).clone()
+
+ split_grouped_query_attn_output = grouped_query_attention(
+ split_query_input, split_key_input, split_value_input
+ )
+
+ assert torch.allclose(regular_attn_output, split_grouped_query_attn_output, rtol=1e-6)
diff --git a/tests/unit/test_head_detector.py b/tests/unit/test_head_detector.py
index a039602a2..e465f8805 100644
--- a/tests/unit/test_head_detector.py
+++ b/tests/unit/test_head_detector.py
@@ -284,9 +284,7 @@ def test_detect_head_exclude_bos(error_measure: ErrorMeasure, expected: torch.Te
("abs", expected_previous_exclude_current_token_match_abs),
),
)
-def test_detect_head_exclude_current_token(
- error_measure: ErrorMeasure, expected: torch.Tensor
-):
+def test_detect_head_exclude_current_token(error_measure: ErrorMeasure, expected: torch.Tensor):
assert torch.allclose(
detect_head(
model,
@@ -381,9 +379,7 @@ def test_detect_head_with_invalid_detection_pattern():
class Test_detect_head_non_lower_triangular_detection_pattern:
- detection_pattern = torch.tril(
- torch.ones(test_duplicated_seq_len, test_duplicated_seq_len)
- )
+ detection_pattern = torch.tril(torch.ones(test_duplicated_seq_len, test_duplicated_seq_len))
def test_no_error(self):
detect_head(
@@ -440,16 +436,14 @@ def test_allclose_abs(self):
def test_isclose_mul(self):
assert math.isclose(
torch.sum(self.match_abs),
- self.match_mul[0, 0].item()
- - (model.cfg.n_layers * model.cfg.n_heads - 1),
+ self.match_mul[0, 0].item() - (model.cfg.n_layers * model.cfg.n_heads - 1),
abs_tol=ATOL,
)
def test_isclose_abs(self):
assert math.isclose(
torch.sum(self.match_abs),
- self.match_abs[0, 0].item()
- - (model.cfg.n_layers * model.cfg.n_heads - 1),
+ self.match_abs[0, 0].item() - (model.cfg.n_layers * model.cfg.n_heads - 1),
abs_tol=ATOL,
)
@@ -486,16 +480,14 @@ def test_allclose_abs(self):
def test_isclose_mul(self):
assert math.isclose(
torch.sum(self.match_mul),
- self.match_mul[0, 0].item()
- - (model.cfg.n_layers * model.cfg.n_heads - 1),
+ self.match_mul[0, 0].item() - (model.cfg.n_layers * model.cfg.n_heads - 1),
abs_tol=ATOL,
)
def test_isclose_abs(self):
assert math.isclose(
torch.sum(self.match_abs),
- self.match_abs[0, 0].item()
- - (model.cfg.n_layers * model.cfg.n_heads - 1),
+ self.match_abs[0, 0].item() - (model.cfg.n_layers * model.cfg.n_heads - 1),
abs_tol=ATOL,
)
@@ -532,16 +524,14 @@ def test_allclose_abs(self):
def test_isclose_mul(self):
assert math.isclose(
torch.sum(self.match_mul),
- self.match_mul[0, 0].item()
- - (model.cfg.n_layers * model.cfg.n_heads - 1),
+ self.match_mul[0, 0].item() - (model.cfg.n_layers * model.cfg.n_heads - 1),
abs_tol=ATOL,
)
def test_isclose_abs(self):
assert math.isclose(
torch.sum(self.match_abs),
- self.match_abs[0, 0].item()
- - (model.cfg.n_layers * model.cfg.n_heads - 1),
+ self.match_abs[0, 0].item() - (model.cfg.n_layers * model.cfg.n_heads - 1),
abs_tol=ATOL,
)
@@ -578,16 +568,14 @@ def test_allclose_abs(self):
def test_isclose_mul(self):
assert math.isclose(
torch.sum(self.match_mul),
- self.match_mul[0, 0].item()
- - (model.cfg.n_layers * model.cfg.n_heads - 1),
+ self.match_mul[0, 0].item() - (model.cfg.n_layers * model.cfg.n_heads - 1),
abs_tol=ATOL,
)
def test_isclose_abs(self):
assert math.isclose(
torch.sum(self.match_abs),
- self.match_abs[0, 0].item()
- - (model.cfg.n_layers * model.cfg.n_heads - 1),
+ self.match_abs[0, 0].item() - (model.cfg.n_layers * model.cfg.n_heads - 1),
abs_tol=ATOL,
)
@@ -632,9 +620,7 @@ class Test_duplicate_token_head:
def test1(self):
assert (
- get_duplicate_token_head_detection_pattern(
- model.to_tokens(test_regular_sequence).cpu()
- )
+ get_duplicate_token_head_detection_pattern(model.to_tokens(test_regular_sequence).cpu())
== torch.zeros(4, 4)
).all()
@@ -655,9 +641,7 @@ class Test_induction_head_detection:
def test1(self):
assert (
- get_duplicate_token_head_detection_pattern(
- model.to_tokens(test_regular_sequence).cpu()
- )
+ get_duplicate_token_head_detection_pattern(model.to_tokens(test_regular_sequence).cpu())
== torch.zeros(4, 4)
).all()
diff --git a/tests/unit/test_hooked_sae.py b/tests/unit/test_hooked_sae.py
new file mode 100644
index 000000000..de311772e
--- /dev/null
+++ b/tests/unit/test_hooked_sae.py
@@ -0,0 +1,191 @@
+import einops
+import pytest
+import torch
+
+from transformer_lens import HookedSAE, HookedSAEConfig, HookedSAETransformer
+
+MODEL = "solu-1l"
+prompt = "Hello World!"
+
+
+class Counter:
+ def __init__(self):
+ self.count = 0
+
+ def inc(self, *args, **kwargs):
+ self.count += 1
+
+
+@pytest.fixture(scope="module")
+def model():
+ model = HookedSAETransformer.from_pretrained(MODEL)
+ yield model
+ model.reset_saes()
+
+
+def get_sae_config(model, act_name):
+ site_to_size = {
+ "hook_z": model.cfg.d_head * model.cfg.n_heads,
+ "hook_mlp_out": model.cfg.d_model,
+ "hook_resid_pre": model.cfg.d_model,
+ "hook_post": model.cfg.d_mlp,
+ }
+ site = act_name.split(".")[-1]
+ d_in = site_to_size[site]
+ return HookedSAEConfig(d_in=d_in, d_sae=d_in * 2, hook_name=act_name)
+
+
+@pytest.mark.parametrize(
+ "act_name",
+ [
+ "blocks.0.attn.hook_z",
+ "blocks.0.hook_mlp_out",
+ "blocks.0.mlp.hook_post",
+ "blocks.0.hook_resid_pre",
+ ],
+)
+def test_forward_reconstructs_input(model, act_name):
+    """Verify that the HookedSAE returns an output with the same shape as the input activations."""
+ sae_cfg = get_sae_config(model, act_name)
+ hooked_sae = HookedSAE(sae_cfg)
+
+ _, cache = model.run_with_cache(prompt, names_filter=act_name)
+ x = cache[act_name]
+
+ sae_output = hooked_sae(x)
+ assert sae_output.shape == x.shape
+
+
+@pytest.mark.parametrize(
+ "act_name",
+ [
+ "blocks.0.attn.hook_z",
+ "blocks.0.hook_mlp_out",
+ "blocks.0.mlp.hook_post",
+ "blocks.0.hook_resid_pre",
+ ],
+)
+def test_run_with_cache(model, act_name):
+ """Verifies that run_with_cache caches SAE activations"""
+ sae_cfg = get_sae_config(model, act_name)
+ hooked_sae = HookedSAE(sae_cfg)
+
+ _, cache = model.run_with_cache(prompt, names_filter=act_name)
+ x = cache[act_name]
+
+ sae_output, cache = hooked_sae.run_with_cache(x)
+ assert sae_output.shape == x.shape
+
+ assert "hook_sae_input" in cache
+ assert "hook_sae_acts_pre" in cache
+ assert "hook_sae_acts_post" in cache
+ assert "hook_sae_recons" in cache
+ assert "hook_sae_output" in cache
+
+
+@pytest.mark.parametrize(
+ "act_name",
+ [
+ "blocks.0.attn.hook_z",
+ "blocks.0.hook_mlp_out",
+ "blocks.0.mlp.hook_post",
+ "blocks.0.hook_resid_pre",
+ ],
+)
+def test_run_with_hooks(model, act_name):
+ """Verifies that run_with_hooks works with SAE activations"""
+ c = Counter()
+ sae_cfg = get_sae_config(model, act_name)
+ hooked_sae = HookedSAE(sae_cfg)
+
+ _, cache = model.run_with_cache(prompt, names_filter=act_name)
+ x = cache[act_name]
+
+ sae_hooks = [
+ "hook_sae_input",
+ "hook_sae_acts_pre",
+ "hook_sae_acts_post",
+ "hook_sae_recons",
+ "hook_sae_output",
+ ]
+
+ sae_output = hooked_sae.run_with_hooks(
+ x, fwd_hooks=[(sae_hook_name, c.inc) for sae_hook_name in sae_hooks]
+ )
+ assert sae_output.shape == x.shape
+
+ assert c.count == len(sae_hooks)
+
+
+@pytest.mark.parametrize(
+ "act_name",
+ [
+ "blocks.0.attn.hook_z",
+ "blocks.0.hook_mlp_out",
+ "blocks.0.mlp.hook_post",
+ "blocks.0.hook_resid_pre",
+ ],
+)
+def test_error_term(model, act_name):
+    """Verifies that if we use error_terms, HookedSAE returns an output that is equal to the input activations."""
+ sae_cfg = get_sae_config(model, act_name)
+ sae_cfg.use_error_term = True
+ hooked_sae = HookedSAE(sae_cfg)
+
+ _, cache = model.run_with_cache(prompt, names_filter=act_name)
+ x = cache[act_name]
+
+ sae_output = hooked_sae(x)
+ assert sae_output.shape == x.shape
+ assert torch.allclose(sae_output, x, atol=1e-6)
+
+
+# %%
+@pytest.mark.parametrize(
+ "act_name",
+ [
+ "blocks.0.attn.hook_z",
+ "blocks.0.hook_mlp_out",
+ "blocks.0.mlp.hook_post",
+ "blocks.0.hook_resid_pre",
+ ],
+)
+def test_feature_grads_with_error_term(model, act_name):
+ """Verifies that pytorch backward computes the correct feature gradients when using error_terms. Motivated by the need to compute feature gradients for attribution patching."""
+
+ # Load SAE
+ sae_cfg = get_sae_config(model, act_name)
+ sae_cfg.use_error_term = True
+ hooked_sae = HookedSAE(sae_cfg)
+
+ # Get input activations
+ _, cache = model.run_with_cache(prompt, names_filter=act_name)
+ x = cache[act_name]
+
+ # Cache gradients with respect to feature acts
+ hooked_sae.reset_hooks()
+ grad_cache = {}
+
+ def backward_cache_hook(act, hook):
+ grad_cache[hook.name] = act.detach()
+
+ hooked_sae.add_hook("hook_sae_acts_post", backward_cache_hook, "bwd")
+ hooked_sae.add_hook("hook_sae_output", backward_cache_hook, "bwd")
+
+ sae_output = hooked_sae(x)
+ assert torch.allclose(sae_output, x, atol=1e-6)
+ value = sae_output.sum()
+ value.backward()
+ hooked_sae.reset_hooks()
+
+ # Compute gradient analytically
+ if act_name.endswith("hook_z"):
+ reshaped_output_grad = einops.rearrange(
+ grad_cache["hook_sae_output"], "... n_heads d_head -> ... (n_heads d_head)"
+ )
+ analytic_grad = reshaped_output_grad @ hooked_sae.W_dec.T
+ else:
+ analytic_grad = grad_cache["hook_sae_output"] @ hooked_sae.W_dec.T
+
+ # Compare analytic gradient with pytorch computed gradient
+ assert torch.allclose(grad_cache["hook_sae_acts_post"], analytic_grad, atol=1e-6)
diff --git a/tests/unit/test_hooked_sae_transformer.py b/tests/unit/test_hooked_sae_transformer.py
new file mode 100644
index 000000000..bfb428c8e
--- /dev/null
+++ b/tests/unit/test_hooked_sae_transformer.py
@@ -0,0 +1,515 @@
+import pytest
+import torch
+
+from transformer_lens import (
+ HookedSAE,
+ HookedSAEConfig,
+ HookedSAETransformer,
+ HookedTransformer,
+)
+from transformer_lens.ActivationCache import ActivationCache
+from transformer_lens.hook_points import HookPoint # Hooking utilities
+from transformer_lens.HookedSAETransformer import get_deep_attr
+
+MODEL = "solu-1l"
+prompt = "Hello World!"
+
+
+class Counter:
+ def __init__(self):
+ self.count = 0
+
+ def inc(self, *args, **kwargs):
+ self.count += 1
+
+
+@pytest.fixture(scope="module")
+def original_logits():
+ original_model = HookedTransformer.from_pretrained(MODEL)
+ return original_model(prompt)
+
+
+@pytest.fixture(scope="module")
+def model():
+ model = HookedSAETransformer.from_pretrained(MODEL)
+ yield model
+ model.reset_saes()
+
+
+def get_sae_config(model, act_name):
+ site_to_size = {
+ "hook_z": model.cfg.d_head * model.cfg.n_heads,
+ "hook_mlp_out": model.cfg.d_model,
+ "hook_resid_pre": model.cfg.d_model,
+ "hook_post": model.cfg.d_mlp,
+ }
+ site = act_name.split(".")[-1]
+ d_in = site_to_size[site]
+ return HookedSAEConfig(d_in=d_in, d_sae=d_in * 2, hook_name=act_name)
+
+
+def test_model_with_no_saes_matches_original_model(model, original_logits):
+ """Verifies that HookedSAETransformer behaves like a normal HookedTransformer model when no SAEs are attached."""
+ assert len(model.acts_to_saes) == 0
+ logits = model(prompt)
+ assert torch.allclose(original_logits, logits)
+
+
+@pytest.mark.parametrize(
+ "act_name",
+ [
+ "blocks.0.attn.hook_z",
+ "blocks.0.hook_mlp_out",
+ "blocks.0.mlp.hook_post",
+ "blocks.0.hook_resid_pre",
+ ],
+)
+def test_model_with_saes_does_not_match_original_model(model, act_name, original_logits):
+    """Verifies that the attached (and turned on) SAEs actually affect the model's output logits"""
+ assert len(model.acts_to_saes) == 0
+ sae_cfg = get_sae_config(model, act_name)
+ hooked_sae = HookedSAE(sae_cfg)
+ model.add_sae(hooked_sae)
+ assert len(model.acts_to_saes) == 1
+ logits_with_saes = model(prompt)
+ assert not torch.allclose(original_logits, logits_with_saes)
+ model.reset_saes()
+
+
+@pytest.mark.parametrize(
+ "act_name",
+ [
+ "blocks.0.attn.hook_z",
+ "blocks.0.hook_mlp_out",
+ "blocks.0.mlp.hook_post",
+ "blocks.0.hook_resid_pre",
+ ],
+)
+def test_add_sae(model, act_name):
+ """Verifies that add_sae correctly updates the model's acts_to_saes dictionary and replaces the HookPoint."""
+ sae_cfg = get_sae_config(model, act_name)
+ hooked_sae = HookedSAE(sae_cfg)
+ model.add_sae(hooked_sae)
+ assert len(model.acts_to_saes) == 1
+ assert model.acts_to_saes[act_name] == hooked_sae
+ assert get_deep_attr(model, act_name) == hooked_sae
+ model.reset_saes()
+
+
+@pytest.mark.parametrize(
+ "act_name",
+ [
+ "blocks.0.attn.hook_z",
+ "blocks.0.hook_mlp_out",
+ "blocks.0.mlp.hook_post",
+ "blocks.0.hook_resid_pre",
+ ],
+)
+def test_add_sae_overwrites_prev_sae(model, act_name):
+    """Verifies that add_sae overwrites an SAE that was previously attached at the same activation name."""
+ prev_sae_cfg = get_sae_config(model, act_name)
+ prev_hooked_sae = HookedSAE(prev_sae_cfg)
+ model.add_sae(prev_hooked_sae)
+ assert len(model.acts_to_saes) == 1
+ assert model.acts_to_saes[act_name] == prev_hooked_sae
+ assert get_deep_attr(model, act_name) == prev_hooked_sae
+
+ sae_cfg = get_sae_config(model, act_name)
+ hooked_sae = HookedSAE(sae_cfg)
+ model.add_sae(hooked_sae)
+ assert len(model.acts_to_saes) == 1
+ assert model.acts_to_saes[act_name] == hooked_sae
+ assert get_deep_attr(model, act_name) == hooked_sae
+ model.reset_saes()
+
+
+@pytest.mark.parametrize(
+ "act_name",
+ [
+ "blocks.0.attn.hook_z",
+ "blocks.0.hook_mlp_out",
+ "blocks.0.mlp.hook_post",
+ "blocks.0.hook_resid_pre",
+ ],
+)
+def test_reset_sae_removes_sae_by_default(model, act_name):
+ """Verifies that reset_sae correctly removes the SAE from the model's acts_to_saes dictionary and replaces the HookedSAE with a HookPoint."""
+ sae_cfg = get_sae_config(model, act_name)
+ hooked_sae = HookedSAE(sae_cfg)
+ model.add_sae(hooked_sae)
+ assert len(model.acts_to_saes) == 1
+ assert model.acts_to_saes[act_name] == hooked_sae
+ assert get_deep_attr(model, act_name) == hooked_sae
+ model._reset_sae(act_name)
+ assert len(model.acts_to_saes) == 0
+ assert isinstance(get_deep_attr(model, act_name), HookPoint)
+ model.reset_saes()
+
+
+@pytest.mark.parametrize(
+ "act_name",
+ [
+ "blocks.0.attn.hook_z",
+ "blocks.0.hook_mlp_out",
+ "blocks.0.mlp.hook_post",
+ "blocks.0.hook_resid_pre",
+ ],
+)
+def test_reset_sae_replaces_sae(model, act_name):
+    """Verifies that _reset_sae can replace the current SAE with a provided previous SAE instead of a plain HookPoint."""
+ sae_cfg = get_sae_config(model, act_name)
+ hooked_sae = HookedSAE(sae_cfg)
+
+ prev_sae_cfg = get_sae_config(model, act_name)
+ prev_sae = HookedSAE(prev_sae_cfg)
+
+ model.add_sae(hooked_sae)
+ assert len(model.acts_to_saes) == 1
+ assert model.acts_to_saes[act_name] == hooked_sae
+ assert get_deep_attr(model, act_name) == hooked_sae
+ model._reset_sae(act_name, prev_sae)
+ assert len(model.acts_to_saes) == 1
+ assert get_deep_attr(model, act_name) == prev_sae
+ model.reset_saes()
+
+
+@pytest.mark.parametrize(
+ "act_names",
+ [
+ ["blocks.0.attn.hook_z"],
+ ["blocks.0.hook_mlp_out"],
+ ["blocks.0.mlp.hook_post"],
+ ["blocks.0.hook_resid_pre"],
+ [
+ "blocks.0.attn.hook_z",
+ "blocks.0.hook_mlp_out",
+ "blocks.0.mlp.hook_post",
+ "blocks.0.hook_resid_pre",
+ ],
+ ],
+)
+def test_reset_saes_removes_all_saes_by_default(model, act_names):
+ """Verifies that reset_saes correctly removes all SAEs from the model's acts_to_saes dictionary and replaces the HookedSAEs with HookPoints."""
+ sae_cfgs = [get_sae_config(model, act_name) for act_name in act_names]
+ hooked_saes = [HookedSAE(sae_cfg) for sae_cfg in sae_cfgs]
+ for hooked_sae in hooked_saes:
+ model.add_sae(hooked_sae)
+ assert len(model.acts_to_saes) == len(act_names)
+ for act_name, hooked_sae in zip(act_names, hooked_saes):
+ assert model.acts_to_saes[act_name] == hooked_sae
+ assert get_deep_attr(model, act_name) == hooked_sae
+ model.reset_saes()
+ assert len(model.acts_to_saes) == 0
+ for act_name in act_names:
+ assert isinstance(get_deep_attr(model, act_name), HookPoint)
+ model.reset_saes()
+
+
+@pytest.mark.parametrize(
+ "act_names",
+ [
+ ["blocks.0.attn.hook_z"],
+ ["blocks.0.hook_mlp_out"],
+ ["blocks.0.mlp.hook_post"],
+ ["blocks.0.hook_resid_pre"],
+ [
+ "blocks.0.attn.hook_z",
+ "blocks.0.hook_mlp_out",
+ "blocks.0.mlp.hook_post",
+ "blocks.0.hook_resid_pre",
+ ],
+ ],
+)
+def test_reset_saes_replaces_saes(model, act_names):
+    """Verifies that reset_saes can replace the attached SAEs with a provided list of previous SAEs."""
+ sae_cfgs = [get_sae_config(model, act_name) for act_name in act_names]
+ hooked_saes = [HookedSAE(sae_cfg) for sae_cfg in sae_cfgs]
+ for hooked_sae in hooked_saes:
+ model.add_sae(hooked_sae)
+
+ prev_sae_cfgs = [get_sae_config(model, act_name) for act_name in act_names]
+ prev_hooked_saes = [HookedSAE(prev_sae_cfg) for prev_sae_cfg in prev_sae_cfgs]
+
+ assert len(model.acts_to_saes) == len(act_names)
+ for act_name, hooked_sae in zip(act_names, hooked_saes):
+ assert model.acts_to_saes[act_name] == hooked_sae
+ assert get_deep_attr(model, act_name) == hooked_sae
+ model.reset_saes(act_names, prev_hooked_saes)
+ assert len(model.acts_to_saes) == len(prev_hooked_saes)
+ for act_name, prev_hooked_sae in zip(act_names, prev_hooked_saes):
+ assert get_deep_attr(model, act_name) == prev_hooked_sae
+ model.reset_saes()
+
+
+@pytest.mark.parametrize(
+ "act_names",
+ [
+ ["blocks.0.attn.hook_z"],
+ ["blocks.0.hook_mlp_out"],
+ ["blocks.0.mlp.hook_post"],
+ ["blocks.0.hook_resid_pre"],
+ [
+ "blocks.0.attn.hook_z",
+ "blocks.0.hook_mlp_out",
+ "blocks.0.mlp.hook_post",
+ "blocks.0.hook_resid_pre",
+ ],
+ ],
+)
+def test_saes_context_manager_removes_saes_after(model, act_names):
+    """Verifies that the model.saes context manager attaches the SAEs for the specified activation names inside the context and removes them after the context manager exits."""
+ sae_cfgs = [get_sae_config(model, act_name) for act_name in act_names]
+ hooked_saes = [HookedSAE(sae_cfg) for sae_cfg in sae_cfgs]
+ assert len(model.acts_to_saes) == 0
+ for act_name in act_names:
+ assert isinstance(get_deep_attr(model, act_name), HookPoint)
+ with model.saes(saes=hooked_saes):
+ for act_name, hooked_sae in zip(act_names, hooked_saes):
+ assert model.acts_to_saes[act_name] == hooked_sae
+ assert isinstance(get_deep_attr(model, act_name), HookedSAE)
+ assert get_deep_attr(model, act_name) == hooked_sae
+ model.forward(prompt)
+ assert len(model.acts_to_saes) == 0
+ for act_name in act_names:
+ assert isinstance(get_deep_attr(model, act_name), HookPoint)
+ model.reset_saes()
+
+
+@pytest.mark.parametrize(
+ "act_names",
+ [
+ ["blocks.0.attn.hook_z"],
+ ["blocks.0.hook_mlp_out"],
+ ["blocks.0.mlp.hook_post"],
+ ["blocks.0.hook_resid_pre"],
+ [
+ "blocks.0.attn.hook_z",
+ "blocks.0.hook_mlp_out",
+ "blocks.0.mlp.hook_post",
+ "blocks.0.hook_resid_pre",
+ ],
+ ],
+)
+def test_saes_context_manager_restores_previous_sae_state(model, act_names):
+    """Verifies that the model.saes context manager temporarily attaches the given SAEs and restores the previously attached SAEs after the context manager exits."""
+ # First add SAEs statefully
+ prev_sae_cfgs = [get_sae_config(model, act_name) for act_name in act_names]
+ prev_hooked_saes = [HookedSAE(sae_cfg) for sae_cfg in prev_sae_cfgs]
+ for act_name, prev_hooked_sae in zip(act_names, prev_hooked_saes):
+ model.add_sae(prev_hooked_sae)
+ assert get_deep_attr(model, act_name) == prev_hooked_sae
+ assert len(model.acts_to_saes) == len(prev_hooked_saes)
+
+ # Now temporarily run with new SAEs
+ sae_cfgs = [get_sae_config(model, act_name) for act_name in act_names]
+ hooked_saes = [HookedSAE(sae_cfg) for sae_cfg in sae_cfgs]
+ with model.saes(saes=hooked_saes):
+ for act_name, hooked_sae in zip(act_names, hooked_saes):
+ assert model.acts_to_saes[act_name] == hooked_sae
+ assert isinstance(get_deep_attr(model, act_name), HookedSAE)
+ assert get_deep_attr(model, act_name) == hooked_sae
+ model.forward(prompt)
+
+ # Check that the previously attached SAEs have been restored
+ assert len(model.acts_to_saes) == len(prev_hooked_saes)
+ for act_name, prev_hooked_sae in zip(act_names, prev_hooked_saes):
+ assert isinstance(get_deep_attr(model, act_name), HookedSAE)
+ assert get_deep_attr(model, act_name) == prev_hooked_sae
+ model.reset_saes()
+
+
+@pytest.mark.parametrize(
+ "act_names",
+ [
+ ["blocks.0.attn.hook_z"],
+ ["blocks.0.hook_mlp_out"],
+ ["blocks.0.mlp.hook_post"],
+ ["blocks.0.hook_resid_pre"],
+ [
+ "blocks.0.attn.hook_z",
+ "blocks.0.hook_mlp_out",
+ "blocks.0.mlp.hook_post",
+ "blocks.0.hook_resid_pre",
+ ],
+ ],
+)
+def test_saes_context_manager_run_with_cache(model, act_names):
+ """Verifies that the model.run_with_cache method works correctly in the context manager."""
+ sae_cfgs = [get_sae_config(model, act_name) for act_name in act_names]
+ hooked_saes = [HookedSAE(sae_cfg) for sae_cfg in sae_cfgs]
+ assert len(model.acts_to_saes) == 0
+ for act_name in act_names:
+ assert isinstance(get_deep_attr(model, act_name), HookPoint)
+ with model.saes(saes=hooked_saes):
+ for act_name, hooked_sae in zip(act_names, hooked_saes):
+ assert model.acts_to_saes[act_name] == hooked_sae
+ assert isinstance(get_deep_attr(model, act_name), HookedSAE)
+ assert get_deep_attr(model, act_name) == hooked_sae
+ model.run_with_cache(prompt)
+ assert len(model.acts_to_saes) == 0
+ for act_name in act_names:
+ assert isinstance(get_deep_attr(model, act_name), HookPoint)
+ model.reset_saes()
+
+
+@pytest.mark.parametrize(
+ "act_names",
+ [
+ ["blocks.0.attn.hook_z"],
+ ["blocks.0.hook_mlp_out"],
+ ["blocks.0.mlp.hook_post"],
+ ["blocks.0.hook_resid_pre"],
+ [
+ "blocks.0.attn.hook_z",
+ "blocks.0.hook_mlp_out",
+ "blocks.0.mlp.hook_post",
+ "blocks.0.hook_resid_pre",
+ ],
+ ],
+)
+def test_run_with_saes(model, act_names, original_logits):
+ """Verifies that the model.run_with_saes method works correctly. The logits with SAEs should be different from the original logits, but the SAE should be removed immediately after the forward pass."""
+ sae_cfgs = [get_sae_config(model, act_name) for act_name in act_names]
+ hooked_saes = [HookedSAE(sae_cfg) for sae_cfg in sae_cfgs]
+ assert len(model.acts_to_saes) == 0
+ logits_with_saes = model.run_with_saes(prompt, saes=hooked_saes)
+ assert not torch.allclose(logits_with_saes, original_logits)
+ assert len(model.acts_to_saes) == 0
+ for act_name in act_names:
+ assert isinstance(get_deep_attr(model, act_name), HookPoint)
+ model.reset_saes()
+
+
+@pytest.mark.parametrize(
+ "act_names",
+ [
+ ["blocks.0.attn.hook_z"],
+ ["blocks.0.hook_mlp_out"],
+ ["blocks.0.mlp.hook_post"],
+ ["blocks.0.hook_resid_pre"],
+ [
+ "blocks.0.attn.hook_z",
+ "blocks.0.hook_mlp_out",
+ "blocks.0.mlp.hook_post",
+ "blocks.0.hook_resid_pre",
+ ],
+ ],
+)
+def test_run_with_cache(model, act_names, original_logits):
+ """Verifies that the model.run_with_cache method works correctly. The logits with SAEs should be different from the original logits and the cache should contain SAE activations for the attached SAE."""
+ sae_cfgs = [get_sae_config(model, act_name) for act_name in act_names]
+ hooked_saes = [HookedSAE(sae_cfg) for sae_cfg in sae_cfgs]
+ for hooked_sae in hooked_saes:
+ model.add_sae(hooked_sae)
+ assert len(model.acts_to_saes) == len(hooked_saes)
+ logits_with_saes, cache = model.run_with_cache(prompt)
+ assert not torch.allclose(logits_with_saes, original_logits)
+ assert isinstance(cache, ActivationCache)
+ for act_name, hooked_sae in zip(act_names, hooked_saes):
+ assert act_name + ".hook_sae_acts_post" in cache
+ assert isinstance(get_deep_attr(model, act_name), HookedSAE)
+ assert get_deep_attr(model, act_name) == hooked_sae
+ model.reset_saes()
+
+
+@pytest.mark.parametrize(
+ "act_names",
+ [
+ ["blocks.0.attn.hook_z"],
+ ["blocks.0.hook_mlp_out"],
+ ["blocks.0.mlp.hook_post"],
+ ["blocks.0.hook_resid_pre"],
+ [
+ "blocks.0.attn.hook_z",
+ "blocks.0.hook_mlp_out",
+ "blocks.0.mlp.hook_post",
+ "blocks.0.hook_resid_pre",
+ ],
+ ],
+)
+def test_run_with_cache_with_saes(model, act_names, original_logits):
+ """Verifies that the model.run_with_cache_with_saes method works correctly. The logits with SAEs should be different from the original logits and the cache should contain SAE activations for the attached SAE."""
+ sae_cfgs = [get_sae_config(model, act_name) for act_name in act_names]
+ hooked_saes = [HookedSAE(sae_cfg) for sae_cfg in sae_cfgs]
+ logits_with_saes, cache = model.run_with_cache_with_saes(prompt, saes=hooked_saes)
+ assert not torch.allclose(logits_with_saes, original_logits)
+ assert isinstance(cache, ActivationCache)
+
+ assert len(model.acts_to_saes) == 0
+ for act_name, hooked_sae in zip(act_names, hooked_saes):
+ assert act_name + ".hook_sae_acts_post" in cache
+ assert isinstance(get_deep_attr(model, act_name), HookPoint)
+ model.reset_saes()
+
+
+@pytest.mark.parametrize(
+ "act_names",
+ [
+ ["blocks.0.attn.hook_z"],
+ ["blocks.0.hook_mlp_out"],
+ ["blocks.0.mlp.hook_post"],
+ ["blocks.0.hook_resid_pre"],
+ [
+ "blocks.0.attn.hook_z",
+ "blocks.0.hook_mlp_out",
+ "blocks.0.mlp.hook_post",
+ "blocks.0.hook_resid_pre",
+ ],
+ ],
+)
+def test_run_with_hooks(model, act_names, original_logits):
+ """Verifies that the model.run_with_hooks method works correctly when SAEs are attached. The count should be incremented by 1 when the hooked SAE is called, and the SAE should stay attached after the forward pass"""
+ c = Counter()
+ sae_cfgs = [get_sae_config(model, act_name) for act_name in act_names]
+ hooked_saes = [HookedSAE(sae_cfg) for sae_cfg in sae_cfgs]
+
+ for hooked_sae in hooked_saes:
+ model.add_sae(hooked_sae)
+
+ logits_with_saes = model.run_with_hooks(
+ prompt, fwd_hooks=[(act_name + ".hook_sae_acts_post", c.inc) for act_name in act_names]
+ )
+ assert not torch.allclose(logits_with_saes, original_logits)
+
+ for act_name, hooked_sae in zip(act_names, hooked_saes):
+ assert isinstance(get_deep_attr(model, act_name), HookedSAE)
+ assert get_deep_attr(model, act_name) == hooked_sae
+ assert c.count == len(act_names)
+ model.reset_saes()
+ model.remove_all_hook_fns(including_permanent=True)
+
+
+@pytest.mark.parametrize(
+ "act_names",
+ [
+ ["blocks.0.attn.hook_z"],
+ ["blocks.0.hook_mlp_out"],
+ ["blocks.0.mlp.hook_post"],
+ ["blocks.0.hook_resid_pre"],
+ [
+ "blocks.0.attn.hook_z",
+ "blocks.0.hook_mlp_out",
+ "blocks.0.mlp.hook_post",
+ "blocks.0.hook_resid_pre",
+ ],
+ ],
+)
+def test_run_with_hooks_with_saes(model, act_names, original_logits):
+ """Verifies that the model.run_with_hooks_with_saes method works correctly when SAEs are attached. The count should be incremented by 1 when the hooked SAE is called, but the SAE should be removed immediately after the forward pass."""
+ c = Counter()
+ sae_cfgs = [get_sae_config(model, act_name) for act_name in act_names]
+ hooked_saes = [HookedSAE(sae_cfg) for sae_cfg in sae_cfgs]
+
+ logits_with_saes = model.run_with_hooks_with_saes(
+ prompt,
+ saes=hooked_saes,
+ fwd_hooks=[(act_name + ".hook_sae_acts_post", c.inc) for act_name in act_names],
+ )
+ assert not torch.allclose(logits_with_saes, original_logits)
+ assert c.count == len(act_names)
+
+ assert len(model.acts_to_saes) == 0
+ for act_name in act_names:
+ assert isinstance(get_deep_attr(model, act_name), HookPoint)
+ model.reset_saes()
+ model.remove_all_hook_fns(including_permanent=True)
diff --git a/tests/unit/test_hooks.py b/tests/unit/test_hooks.py
index 1c41b8e24..231a57150 100644
--- a/tests/unit/test_hooks.py
+++ b/tests/unit/test_hooks.py
@@ -116,9 +116,7 @@ def test_remove_hook():
model.add_perma_hook(embed, c.inc)
assert len(model.hook_dict["hook_embed"].fwd_hooks) == 1 # 1 after adding
model.remove_all_hook_fns()
- assert (
- len(model.hook_dict["hook_embed"].fwd_hooks) == 1
- ) # permanent not removed without flag
+ assert len(model.hook_dict["hook_embed"].fwd_hooks) == 1 # permanent not removed without flag
model.remove_all_hook_fns(including_permanent=True)
assert len(model.hook_dict["hook_embed"].fwd_hooks) == 0 # removed now
model.run_with_hooks(prompt, fwd_hooks=[])
@@ -182,11 +180,7 @@ def identity_hook(z, hook):
@pytest.mark.parametrize(
"zero_attach_pos,prepend",
- [
- (zero_attach_pos, prepend)
- for zero_attach_pos in range(2)
- for prepend in [True, False]
- ],
+ [(zero_attach_pos, prepend) for zero_attach_pos in range(2) for prepend in [True, False]],
)
def test_prepending_hooks(zero_attach_pos, prepend):
"""Add two hooks to a model: one that sets last layer activations to all 0s
diff --git a/tests/unit/test_kv_cache.py b/tests/unit/test_kv_cache.py
index b69b6b3b9..435d7fa42 100644
--- a/tests/unit/test_kv_cache.py
+++ b/tests/unit/test_kv_cache.py
@@ -69,9 +69,7 @@ def test_multiple_new_tokens(pretrained):
past_kv_cache=past_kv_cache,
)
assert t.allclose(no_cache_logits[:, -1], with_cache_logits[:, -1], atol=atol)
- assert t.allclose(
- no_cache_logits[:, -new_tokens_len:], with_cache_logits, atol=atol
- )
+ assert t.allclose(no_cache_logits[:, -new_tokens_len:], with_cache_logits, atol=atol)
@pytest.mark.parametrize("pre_padding", ["left", "right", None])
@@ -95,17 +93,13 @@ def test_multi_token_batch(pretrained, pre_padding, post_padding):
" by the candidate",
]
- first_post_prompt_tokens = model.to_tokens(
- padded_batch_post_prompts[0], prepend_bos=False
- )
+ first_post_prompt_tokens = model.to_tokens(padded_batch_post_prompts[0], prepend_bos=False)
first_full_prompt_tokens = t.cat(
[model.to_tokens(padded_batch_pre_prompts[0]), first_post_prompt_tokens], dim=-1
)
first_post_prompt_len = first_post_prompt_tokens.shape[-1]
first_prompt_no_cache_logits = model(first_full_prompt_tokens)
- first_post_prompt_no_cache_logits = first_prompt_no_cache_logits[
- 0, -first_post_prompt_len:
- ]
+ first_post_prompt_no_cache_logits = first_prompt_no_cache_logits[0, -first_post_prompt_len:]
if pre_padding is None:
batch_pre_prompt_tokens = model.to_tokens(unpadded_batch_pre_prompts)
@@ -116,9 +110,7 @@ def test_multi_token_batch(pretrained, pre_padding, post_padding):
)
if post_padding is None:
- batch_post_prompt_tokens = model.to_tokens(
- unpadded_batch_post_prompts, prepend_bos=False
- )
+ batch_post_prompt_tokens = model.to_tokens(unpadded_batch_post_prompts, prepend_bos=False)
else:
assert post_padding == "left" or post_padding == "right"
batch_post_prompt_tokens = model.to_tokens(
@@ -130,9 +122,7 @@ def test_multi_token_batch(pretrained, pre_padding, post_padding):
past_kv_cache = HookedTransformerKeyValueCache.init_cache(
model.cfg, model.cfg.device, batch_pre_prompt_tokens.shape[0]
)
- model(
- batch_pre_prompt_tokens, past_kv_cache=past_kv_cache, padding_side=pre_padding
- )
+ model(batch_pre_prompt_tokens, past_kv_cache=past_kv_cache, padding_side=pre_padding)
past_kv_cache.freeze()
with_cache_logits = model(
batch_post_prompt_tokens,
@@ -141,14 +131,10 @@ def test_multi_token_batch(pretrained, pre_padding, post_padding):
prepend_bos=False,
)
if post_padding == "left" or post_padding is None:
- first_post_prompt_with_cache_logits = with_cache_logits[
- 0, -first_post_prompt_len:
- ]
+ first_post_prompt_with_cache_logits = with_cache_logits[0, -first_post_prompt_len:]
else:
assert post_padding == "right"
- first_post_prompt_with_cache_logits = with_cache_logits[
- 0, :first_post_prompt_len
- ]
+ first_post_prompt_with_cache_logits = with_cache_logits[0, :first_post_prompt_len]
no_cache_probs = t.softmax(first_post_prompt_no_cache_logits, dim=-1)
with_cache_probs = t.softmax(first_post_prompt_with_cache_logits, dim=-1)
@@ -249,9 +235,7 @@ def test_kv_cache_and_start_at_layer(pretrained):
_, toks, shortformer_pos_embed, attn_mask = model.input_to_embed(
single_new_token, past_kv_cache=past_kv_cache
)
- _, cache = model.run_with_cache(
- single_new_token, stop_at_layer=4, past_kv_cache=past_kv_cache
- )
+ _, cache = model.run_with_cache(single_new_token, stop_at_layer=4, past_kv_cache=past_kv_cache)
resid_3 = cache["blocks.3.hook_resid_pre"]
with_cache_logits = model(
resid_3,
diff --git a/tests/unit/test_left_padding.py b/tests/unit/test_left_padding.py
index b40f97fba..a4961dc74 100644
--- a/tests/unit/test_left_padding.py
+++ b/tests/unit/test_left_padding.py
@@ -89,9 +89,7 @@ def test_pos_embed(self, model, padding_side, prepend_bos):
attended_output_pos_embed = output_pos_embed[attention_mask.bool()]
- assert torch.allclose(
- attended_output_pos_embed, target_output_pos_embed, atol=1e-4
- )
+ assert torch.allclose(attended_output_pos_embed, target_output_pos_embed, atol=1e-4)
# padded positions should have zero pos_embed
assert output_pos_embed[~attention_mask.bool()].sum() == 0
@@ -117,9 +115,7 @@ def test_pos_embed_with_cache(self, model, padding_side, prepend_bos):
model.tokenizer, tokens, prepend_bos
) # [batch pos]
past_kv_cache.append_attention_mask(attention_mask)
- attention_mask_2 = utils.get_attention_mask(
- model.tokenizer, tokens_2, False
- ) # [batch pos]
+ attention_mask_2 = utils.get_attention_mask(model.tokenizer, tokens_2, False) # [batch pos]
cached_attention_mask = past_kv_cache.append_attention_mask(attention_mask_2)
output_pos_embed = model.pos_embed(
@@ -141,9 +137,7 @@ def test_pos_embed_with_cache(self, model, padding_side, prepend_bos):
attended_output_pos_embed = output_pos_embed[attention_mask_2.bool()]
- assert torch.allclose(
- attended_output_pos_embed, target_output_pos_embed, atol=1e-4
- )
+ assert torch.allclose(attended_output_pos_embed, target_output_pos_embed, atol=1e-4)
# padded positions should have zero pos_embed
assert output_pos_embed[~attention_mask_2.bool()].sum() == 0
diff --git a/tests/unit/test_make_docs.py b/tests/unit/test_make_docs.py
index fdb421962..47f5afe96 100644
--- a/tests/unit/test_make_docs.py
+++ b/tests/unit/test_make_docs.py
@@ -1,8 +1,9 @@
"""Make Docs Tests."""
+
import pytest
from docs.make_docs import get_config, get_property
-from transformer_lens import HookedTransformerConfig
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
def test_get_config():
diff --git a/tests/unit/test_only_tokenizer.py b/tests/unit/test_only_tokenizer.py
index 867224d3d..fa2642b1b 100644
--- a/tests/unit/test_only_tokenizer.py
+++ b/tests/unit/test_only_tokenizer.py
@@ -30,17 +30,13 @@ def __init__(
elif self.cfg.tokenizer_name is not None:
# If we have a tokenizer name, we can load it from HuggingFace
self.set_tokenizer(
- AutoTokenizer.from_pretrained(
- self.cfg.tokenizer_name, add_bos_token=True
- ),
+ AutoTokenizer.from_pretrained(self.cfg.tokenizer_name, add_bos_token=True),
default_padding_side=default_padding_side,
)
else:
# If no tokenizer name is provided, we assume we're training on an algorithmic task and will pass in tokens
# directly. In this case, we don't need a tokenizer.
- assert (
- self.cfg.d_vocab != -1
- ), "Must provide a tokenizer if d_vocab is not provided"
+ assert self.cfg.d_vocab != -1, "Must provide a tokenizer if d_vocab is not provided"
self.tokenizer = None
if default_padding_side != "right":
logging.warning(
@@ -101,9 +97,7 @@ class TestTokenizer:
# helper functions
def get_num_tokens_in_prompt(self, model, prompt, intended_prepend_bos):
- tokenizer = AutoTokenizer.from_pretrained(
- model.tokenizer.name_or_path, add_bos_token=False
- )
+ tokenizer = AutoTokenizer.from_pretrained(model.tokenizer.name_or_path, add_bos_token=False)
tokens = tokenizer(
prompt,
)["input_ids"]
@@ -126,9 +120,7 @@ def check_tokens_length(self, model, str_tokens, tokens, intended_prepend_bos):
assert len(str_tokens) == tokens.shape[1] == expected_num_tokens
def check_prompt(self, model, intended_prepend_bos, overriding_prepend_bos=None):
- str_tokens = model.to_str_tokens(
- self.prompt, prepend_bos=overriding_prepend_bos
- )
+ str_tokens = model.to_str_tokens(self.prompt, prepend_bos=overriding_prepend_bos)
tokens = model.to_tokens(self.prompt, prepend_bos=overriding_prepend_bos)
self.check_first_token(model, str_tokens, tokens, intended_prepend_bos)
@@ -164,9 +156,7 @@ def check_prompts(
if model.tokenizer.pad_token_id != model.tokenizer.bos_token_id:
if intended_prepend_bos:
- assert (tokens == model.tokenizer.bos_token_id).sum() == tokens.shape[
- 0
- ], tokens
+ assert (tokens == model.tokenizer.bos_token_id).sum() == tokens.shape[0], tokens
else:
assert (tokens == model.tokenizer.bos_token_id).sum() == 0, tokens
@@ -220,9 +210,7 @@ def test_given_defaults(self, model_name):
@pytest.mark.parametrize("intended_prepend_bos", [True, False])
@pytest.mark.parametrize("intended_padding_side", ["left", "right"])
- def test_changing_defaults(
- self, model, intended_prepend_bos, intended_padding_side
- ):
+ def test_changing_defaults(self, model, intended_prepend_bos, intended_padding_side):
model.tokenizer.padding_side = intended_padding_side
model.cfg.default_prepend_bos = intended_prepend_bos
@@ -231,9 +219,7 @@ def test_changing_defaults(
@pytest.mark.parametrize("intended_prepend_bos", [True, False])
@pytest.mark.parametrize("intended_padding_side", ["left", "right"])
- def test_overriding_defaults(
- self, model, intended_prepend_bos, intended_padding_side
- ):
+ def test_overriding_defaults(self, model, intended_prepend_bos, intended_padding_side):
self.check_prompt(model, intended_prepend_bos, intended_prepend_bos)
self.check_prompts(
model,
diff --git a/tests/unit/test_prepend_bos.py b/tests/unit/test_prepend_bos.py
index 949939936..afb85d933 100644
--- a/tests/unit/test_prepend_bos.py
+++ b/tests/unit/test_prepend_bos.py
@@ -9,9 +9,7 @@ class TestPrependBos:
# helper functions
def get_num_tokens_in_prompt(self, model, prompt, intended_prepend_bos):
- tokenizer = AutoTokenizer.from_pretrained(
- model.tokenizer.name_or_path, add_bos_token=False
- )
+ tokenizer = AutoTokenizer.from_pretrained(model.tokenizer.name_or_path, add_bos_token=False)
tokens = tokenizer(
prompt,
)["input_ids"]
@@ -26,15 +24,11 @@ def check_first_token(self, model, str_tokens, tokens, intended_prepend_bos):
assert str_tokens[0] != model.tokenizer.bos_token
assert tokens[0][0] != model.tokenizer.bos_token_id
- def check_tokens_length(
- self, model, logits, str_tokens, tokens, intended_prepend_bos
- ):
+ def check_tokens_length(self, model, logits, str_tokens, tokens, intended_prepend_bos):
expected_num_tokens = self.get_num_tokens_in_prompt(
model, self.prompt, intended_prepend_bos
)
- assert (
- logits.shape[1] == len(str_tokens) == tokens.shape[1] == expected_num_tokens
- )
+ assert logits.shape[1] == len(str_tokens) == tokens.shape[1] == expected_num_tokens
# fixtures
@pytest.fixture(scope="class", params=["gpt2", "facebook/opt-125m"])
@@ -59,13 +53,9 @@ def test_default_prepend_bos(self, model_name):
tokens = model.to_tokens(self.prompt) # [batch pos]
self.check_first_token(model, str_tokens, tokens, intended_prepend_bos)
- self.check_tokens_length(
- model, logits, str_tokens, tokens, intended_prepend_bos
- )
+ self.check_tokens_length(model, logits, str_tokens, tokens, intended_prepend_bos)
- bos_position = model.get_token_position(
- model.tokenizer.bos_token_id, self.prompt
- )
+ bos_position = model.get_token_position(model.tokenizer.bos_token_id, self.prompt)
assert bos_position == 0
def test_default_prepend_bos_to_false(self, model_name):
@@ -80,34 +70,24 @@ def test_default_prepend_bos_to_false(self, model_name):
tokens = model.to_tokens(self.prompt)
self.check_first_token(model, str_tokens, tokens, intended_prepend_bos)
- self.check_tokens_length(
- model, logits, str_tokens, tokens, intended_prepend_bos
- )
+ self.check_tokens_length(model, logits, str_tokens, tokens, intended_prepend_bos)
@pytest.mark.parametrize("intended_prepend_bos", [True, False])
def test_override_prepend_bos(self, model, intended_prepend_bos):
for default_prepend_bos in [True, False]:
model.cfg.default_prepend_bos = default_prepend_bos
- logits = model(
- self.prompt, prepend_bos=intended_prepend_bos
- ) # [batch pos d_vocab]
- str_tokens = model.to_str_tokens(
- self.prompt, prepend_bos=intended_prepend_bos
- )
+ logits = model(self.prompt, prepend_bos=intended_prepend_bos) # [batch pos d_vocab]
+ str_tokens = model.to_str_tokens(self.prompt, prepend_bos=intended_prepend_bos)
tokens = model.to_tokens(self.prompt, prepend_bos=intended_prepend_bos)
self.check_first_token(model, str_tokens, tokens, intended_prepend_bos)
- self.check_tokens_length(
- model, logits, str_tokens, tokens, intended_prepend_bos
- )
+ self.check_tokens_length(model, logits, str_tokens, tokens, intended_prepend_bos)
def test_prepend_bos_with_get_token_position(self, model_name):
model = HookedTransformer.from_pretrained(model_name)
- bos_position = model.get_token_position(
- model.tokenizer.bos_token_id, self.prompt
- )
+ bos_position = model.get_token_position(model.tokenizer.bos_token_id, self.prompt)
assert bos_position == 0
with pytest.raises(AssertionError):
@@ -117,9 +97,7 @@ def test_prepend_bos_with_get_token_position(self, model_name):
model.cfg.default_prepend_bos = False
with pytest.raises(AssertionError):
- bos_position = model.get_token_position(
- model.tokenizer.bos_token_id, self.prompt
- )
+ bos_position = model.get_token_position(model.tokenizer.bos_token_id, self.prompt)
bos_position = model.get_token_position(
model.tokenizer.bos_token_id, self.prompt, prepend_bos=True
diff --git a/tests/unit/test_split_qkv.py b/tests/unit/test_split_qkv.py
new file mode 100644
index 000000000..e8586305b
--- /dev/null
+++ b/tests/unit/test_split_qkv.py
@@ -0,0 +1,73 @@
+import torch
+
+from transformer_lens import HookedTransformer
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+
+
+def test_split_qkv_normal_attn_correct():
+ """Verifies that the split_qkv_input flag does not change the output for models with normal attention."""
+ d_model = 128
+ d_head = 8
+ n_heads = 16
+ n_ctx = 128
+ n_layers = 1
+ d_vocab = 10
+
+ cfg = HookedTransformerConfig(
+ d_model=d_model,
+ d_head=d_head,
+ n_heads=n_heads,
+ n_ctx=n_ctx,
+ n_layers=n_layers,
+ attn_only=True,
+ d_vocab=d_vocab,
+ )
+
+ model = HookedTransformer(cfg)
+ assert model.cfg.use_split_qkv_input is False
+
+ x = torch.arange(1, 9).unsqueeze(0)
+ normal_output = model(x)
+
+ model.set_use_split_qkv_input(True)
+ assert model.cfg.use_split_qkv_input is True
+
+ split_output = model(x)
+
+ assert torch.allclose(normal_output, split_output, atol=1e-6)
+
+
+def test_split_qkv_grouped_query_attn_correct():
+ """Verifies that the split_qkv_input flag does not change the output for models with grouped query attention."""
+
+ d_model = 128
+ d_head = 8
+ n_heads = 16
+ n_ctx = 128
+ n_key_value_heads = 2
+ n_layers = 1
+ d_vocab = 10
+
+ cfg = HookedTransformerConfig(
+ d_model=d_model,
+ d_head=d_head,
+ n_heads=n_heads,
+ n_ctx=n_ctx,
+ n_key_value_heads=n_key_value_heads,
+ n_layers=n_layers,
+ attn_only=True,
+ d_vocab=d_vocab,
+ )
+
+ model = HookedTransformer(cfg)
+ assert model.cfg.use_split_qkv_input is False
+
+ x = torch.arange(1, 9).unsqueeze(0)
+ normal_output = model(x)
+
+ model.set_use_split_qkv_input(True)
+ assert model.cfg.use_split_qkv_input is True
+
+ split_output = model(x)
+
+ assert torch.allclose(normal_output, split_output, atol=1e-6)
diff --git a/tests/unit/test_start_at_layer.py b/tests/unit/test_start_at_layer.py
index c87779bca..f1d007829 100644
--- a/tests/unit/test_start_at_layer.py
+++ b/tests/unit/test_start_at_layer.py
@@ -124,9 +124,7 @@ def test_no_start_logit_output(setup_data: Dict[str, Any]):
def test_no_start_none_output(setup_data: Dict[str, Any]):
model, rand_input = setup_data["model"], setup_data["rand_input"]
- output, cache = model.run_with_cache(
- rand_input, start_at_layer=None, return_type=None
- )
+ output, cache = model.run_with_cache(rand_input, start_at_layer=None, return_type=None)
assert output is None
assert "hook_embed" in cache.keys()
@@ -183,11 +181,7 @@ def test_start_at_layer_kwargs():
shortformer_pos_embed,
attention_mask,
) = model.input_to_embed(input)
- assert (
- tokens is not None
- and shortformer_pos_embed is not None
- and attention_mask is not None
- )
+ assert tokens is not None and shortformer_pos_embed is not None and attention_mask is not None
start_at_layer_output = model(
rand_embed,
diff --git a/tests/unit/test_stop_at_layer.py b/tests/unit/test_stop_at_layer.py
index 3bbae6cd4..2692c8f49 100644
--- a/tests/unit/test_stop_at_layer.py
+++ b/tests/unit/test_stop_at_layer.py
@@ -220,9 +220,7 @@ def test_no_stop_no_output():
)
rand_input = torch.randint(0, 20, (2, 10))
- output, cache = model.run_with_cache(
- rand_input, stop_at_layer=None, return_type=None
- )
+ output, cache = model.run_with_cache(rand_input, stop_at_layer=None, return_type=None)
assert output is None
assert "hook_embed" in cache.keys()
diff --git a/tests/unit/test_svd_interpreter.py b/tests/unit/test_svd_interpreter.py
index d23d30ac5..face0643b 100644
--- a/tests/unit/test_svd_interpreter.py
+++ b/tests/unit/test_svd_interpreter.py
@@ -114,9 +114,7 @@ def test_svd_interpreter_returns_different_answers_for_different_models():
def test_svd_interpreter_fails_on_invalid_vector_type():
svd_interpreter = SVDInterpreter(model)
with pytest.raises(BeartypeCallHintParamViolation) as e:
- svd_interpreter.get_singular_vectors(
- "test", layer_index=0, num_vectors=4, head_index=0
- )
+ svd_interpreter.get_singular_vectors("test", layer_index=0, num_vectors=4, head_index=0)
def test_svd_interpreter_fails_on_not_passing_required_head_index():
@@ -130,9 +128,7 @@ def test_svd_interpreter_fails_on_invalid_layer_index():
svd_interpreter = SVDInterpreter(model)
for vector in VECTOR_TYPES:
with pytest.raises(AssertionError) as e:
- svd_interpreter.get_singular_vectors(
- vector, layer_index=2, num_vectors=4, head_index=0
- )
+ svd_interpreter.get_singular_vectors(vector, layer_index=2, num_vectors=4, head_index=0)
assert str(e.value) == "Layer index must be between 0 and 1 but got 2"
@@ -140,7 +136,5 @@ def test_svd_interpreter_fails_on_invalid_head_index():
# Only OV uses head index.
svd_interpreter = SVDInterpreter(model)
with pytest.raises(AssertionError) as e:
- svd_interpreter.get_singular_vectors(
- "OV", layer_index=0, num_vectors=4, head_index=8
- )
+ svd_interpreter.get_singular_vectors("OV", layer_index=0, num_vectors=4, head_index=8)
assert str(e.value) == "Head index must be between 0 and 7 but got 8"
diff --git a/tests/unit/test_tokenization_methods.py b/tests/unit/test_tokenization_methods.py
index 83e189c03..acba3ebd7 100644
--- a/tests/unit/test_tokenization_methods.py
+++ b/tests/unit/test_tokenization_methods.py
@@ -58,9 +58,7 @@ def test_to_tokens_device():
s = "Hello, world!"
tokens1 = model.to_tokens(s, move_to_device=False)
tokens2 = model.to_tokens(s, move_to_device=True)
- assert equal(
- tokens1, tokens2
- ), "move to device has no effect when running tests on CPU"
+ assert equal(tokens1, tokens2), "move to device has no effect when running tests on CPU"
def test_to_tokens_truncate():
@@ -125,9 +123,7 @@ def test_get_token_position_not_found():
input = "There were some biomolecules"
with pytest.raises(AssertionError) as exc_info:
model.get_token_position(single, input)
- assert (
- str(exc_info.value) == "The token does not occur in the prompt"
- ), "assertion error"
+ assert str(exc_info.value) == "The token does not occur in the prompt", "assertion error"
def test_get_token_position_str():
diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py
index 6995a7f00..7feec34a0 100644
--- a/tests/unit/test_utils.py
+++ b/tests/unit/test_utils.py
@@ -3,6 +3,7 @@
import numpy as np
import pytest
import torch
+from torch import nn
import transformer_lens.utils as utils
from transformer_lens import HookedTransformer
@@ -273,21 +274,13 @@ def test_test_prompt(
def test_override_or_use_default_value():
# Case when override is not None
assert utils.override_or_use_default_value(default_flag=True, override=True) == True
- assert (
- utils.override_or_use_default_value(default_flag=True, override=False) == False
- )
- assert (
- utils.override_or_use_default_value(default_flag=False, override=True) == True
- )
- assert (
- utils.override_or_use_default_value(default_flag=False, override=False) == False
- )
+ assert utils.override_or_use_default_value(default_flag=True, override=False) == False
+ assert utils.override_or_use_default_value(default_flag=False, override=True) == True
+ assert utils.override_or_use_default_value(default_flag=False, override=False) == False
# Case when override is None
assert utils.override_or_use_default_value(default_flag=True, override=None) == True
- assert (
- utils.override_or_use_default_value(default_flag=False, override=None) == False
- )
+ assert utils.override_or_use_default_value(default_flag=False, override=None) == False
# Case when override is not passed
assert utils.override_or_use_default_value(default_flag=True) == True
@@ -321,9 +314,7 @@ def model(self, model_name):
@pytest.mark.parametrize("padding_side", ["left", "right"])
@pytest.mark.parametrize("prepend_bos", [True, False])
@pytest.mark.parametrize("prompts_with_sep", [True, False])
- def test_get_attention_mask(
- self, model, padding_side, prepend_bos, prompts_with_sep
- ):
+ def test_get_attention_mask(self, model, padding_side, prepend_bos, prompts_with_sep):
# setup
model.tokenizer.padding_side = padding_side
model.tokenizer.sep_token_id = model.tokenizer.pad_token_id
@@ -377,3 +368,178 @@ def test_get_attention_mask(
else:
# otherwise, there should be no attended but non-pad token
assert attended_but_non_pad_mask.sum() == 0
+
+
+def test_calc_fan_in_fan_out():
+ """
+ Test for the calc_fan_in_and_fan_out function in the utils module.
+ """
+ # Test for the case when the tensor is 1D
+ tensor_1d = torch.tensor([1, 2, 3, 4, 5])
+ fan_in, fan_out = utils.calc_fan_in_and_fan_out(tensor_1d)
+ assert fan_in == 1
+ assert fan_out == 5
+
+ # Test for the case when the tensor is 2D
+ tensor_2d = torch.tensor([[1, 2, 3], [4, 5, 6]])
+ fan_in, fan_out = utils.calc_fan_in_and_fan_out(tensor_2d)
+ assert fan_in == 2
+ assert fan_out == 3
+
+ # Test for the case when the tensor is 3D
+ tensor_3d = nn.Parameter(torch.rand(2, 25, 5)) # shape (2, 25, 5): fan_in = middle dim = 25, fan_out = 2 * 5 = 10
+ fan_in, fan_out = utils.calc_fan_in_and_fan_out(tensor_3d)
+ assert fan_in == 25
+ assert fan_out == 10
+
+ # Test for the case when the tensor is 4D (should raise a ValueError)
+ tensor_4d = torch.tensor([[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]])
+ with pytest.raises(ValueError):
+ fan_in, fan_out = utils.calc_fan_in_and_fan_out(tensor_4d)
+
+ # Test for the case when the tensor is 0D (also should raise a ValueError)
+ tensor_0d = torch.tensor(1)
+ with pytest.raises(ValueError):
+ fan_in, fan_out = utils.calc_fan_in_and_fan_out(tensor_0d)
+
+
+class TestInitKaiming:
+ """Test cases for kaiming init."""
+
+ @pytest.mark.parametrize(
+ "d_model", [4096, 10_000]
+ ) # this needs to be large so std and min/max estimates are accurate
+ @pytest.mark.parametrize("d_mlp", [256, 512])
+ @pytest.mark.parametrize("nonlinearity", ["linear", "relu"])
+ def test_init_kaiming_uniform(self, d_model, d_mlp, nonlinearity):
+ """
+ Test init_kaiming_uniform function in the utils module on 3/2/1D tensors.
+ """
+ torch.manual_seed(1234)
+
+ gain = np.sqrt(2.0) if nonlinearity == "relu" else 1.0
+
+ x = nn.Parameter(torch.empty(2, d_model, 137)) # n_head and d_head don't matter
+ utils.init_kaiming_uniform_(x, nonlinearity=nonlinearity)
+ std = gain / np.sqrt(d_model)
+ assert np.isclose(x.std().detach().numpy(), std, rtol=1e-2)
+ # for uniform distributions, min/max is sqrt(3) times the std
+ assert np.isclose(x.max().detach().numpy(), np.sqrt(3) * std, rtol=1e-2)
+ assert np.isclose(x.min().detach().numpy(), -np.sqrt(3) * std, rtol=1e-2)
+
+ y = nn.Parameter(torch.empty(d_mlp, d_model))
+ utils.init_kaiming_uniform_(y, nonlinearity=nonlinearity)
+ std = gain / np.sqrt(d_mlp)
+ assert np.isclose(y.std().detach().numpy(), std, rtol=1e-2)
+ # for uniform distributions, min/max is sqrt(3) times the std
+ assert np.isclose(y.max().detach().numpy(), np.sqrt(3) * std, rtol=1e-2)
+ assert np.isclose(y.min().detach().numpy(), -np.sqrt(3) * std, rtol=1e-2)
+
+ z = nn.Parameter(torch.empty(d_model * 123))
+ utils.init_kaiming_uniform_(z, nonlinearity=nonlinearity)
+ std = gain # 1-D (bias-like) params have fan_in = 1, so std = gain / sqrt(1) = gain
+ assert np.isclose(z.std().detach().numpy(), std, rtol=1e-2)
+ # for uniform distributions, min/max is sqrt(3) times the std
+ assert np.isclose(z.max().detach().numpy(), np.sqrt(3) * std, rtol=1e-2)
+ assert np.isclose(z.min().detach().numpy(), -np.sqrt(3) * std, rtol=1e-2)
+
+ torch.manual_seed(1234)
+ x_new = nn.Parameter(torch.empty(2, d_model, 137))
+ utils.init_kaiming_uniform_(x_new, nonlinearity=nonlinearity)
+ assert torch.allclose(x_new, x, rtol=1e-2)
+
+ @pytest.mark.parametrize("d_model", [4096, 10_000])
+ @pytest.mark.parametrize("d_mlp", [256, 512])
+ @pytest.mark.parametrize("nonlinearity", ["linear", "relu"])
+ def test_init_kaiming_normal(self, d_model, d_mlp, nonlinearity):
+ """
+ Test init_kaiming_normal function in the utils module on 3/2/1D tensors.
+ """
+ torch.manual_seed(1234)
+
+ gain = np.sqrt(2.0) if nonlinearity == "relu" else 1.0
+
+ x = nn.Parameter(torch.empty(2, d_model, 137))
+ utils.init_kaiming_normal_(x, nonlinearity=nonlinearity)
+ std = gain / np.sqrt(d_model)
+ assert np.isclose(x.std().detach().numpy(), std, rtol=1e-2)
+
+ y = nn.Parameter(torch.empty(d_mlp, d_model))
+ utils.init_kaiming_normal_(y, nonlinearity=nonlinearity)
+ std = gain / np.sqrt(d_mlp)
+ assert np.isclose(y.std().detach().numpy(), std, rtol=1e-2)
+
+ z = nn.Parameter(torch.empty(d_model * 123))
+ utils.init_kaiming_normal_(z, nonlinearity=nonlinearity)
+ std = gain # 1-D (bias-like) params have fan_in = 1, so std = gain / sqrt(1) = gain
+ assert np.isclose(z.std().detach().numpy(), std, rtol=1e-2)
+
+ torch.manual_seed(1234)
+ x_new = nn.Parameter(torch.empty(2, d_model, 137))
+ utils.init_kaiming_normal_(x_new, nonlinearity=nonlinearity)
+ assert torch.allclose(x_new, x, rtol=1e-2)
+
+
+class TestInitXavier:
+ """Test cases for Xavier init. Std of distribution should be scaled to sqrt(2/(fan_in + fan_out))."""
+
+ @pytest.mark.parametrize("d_model", [4096, 10_000])
+ @pytest.mark.parametrize("d_mlp", [256, 512])
+ def test_init_xavier_uniform(self, d_model, d_mlp):
+ """Test init_xavier_uniform function in the utils module on 3/2/1D tensors."""
+ torch.manual_seed(1234)
+
+ x = nn.Parameter(torch.empty(2, d_model, 137))
+ utils.init_xavier_uniform_(x)
+ std = np.sqrt(2 / (d_model + 137 * 2))
+ assert np.isclose(x.std().detach().numpy(), std, rtol=1e-2)
+ # for uniform distributions, min/max is sqrt(3) times the std
+ assert np.isclose(x.max().detach().numpy(), np.sqrt(3) * std, rtol=1e-2)
+ assert np.isclose(x.min().detach().numpy(), -np.sqrt(3) * std, rtol=1e-2)
+
+ y = nn.Parameter(torch.empty(d_mlp, d_model))
+ utils.init_xavier_uniform_(y)
+ std = np.sqrt(2 / (d_mlp + d_model))
+ assert np.isclose(y.std().detach().numpy(), std, rtol=1e-2)
+ # for uniform distributions, min/max is sqrt(3) times the std
+ assert np.isclose(y.max().detach().numpy(), np.sqrt(3) * std, rtol=1e-2)
+ assert np.isclose(y.min().detach().numpy(), -np.sqrt(3) * std, rtol=1e-2)
+
+ z = nn.Parameter(torch.empty(d_model * 123))
+ utils.init_xavier_uniform_(z)
+ std = np.sqrt(2 / (1 + d_model * 123))
+ assert np.isclose(z.std().detach().numpy(), std, rtol=1e-2)
+ # for uniform distributions, min/max is sqrt(3) times the std
+ assert np.isclose(z.max().detach().numpy(), np.sqrt(3) * std, rtol=1e-2)
+ assert np.isclose(z.min().detach().numpy(), -np.sqrt(3) * std, rtol=1e-2)
+
+ torch.manual_seed(1234)
+ x_new = nn.Parameter(torch.empty(2, d_model, 137))
+ utils.init_xavier_uniform_(x_new)
+ assert torch.allclose(x_new, x, rtol=1e-2)
+
+ @pytest.mark.parametrize("d_model", [4096, 10_000])
+ @pytest.mark.parametrize("d_mlp", [256, 512])
+ def test_init_xavier_normal(self, d_model, d_mlp):
+ """Test init_xavier_normal function in the utils module on 3/2/1D tensors."""
+ torch.manual_seed(1234)
+
+ x = nn.Parameter(torch.empty(2, d_model, 137))
+ utils.init_xavier_normal_(x)
+ std = np.sqrt(2 / (d_model + 137 * 2))
+ assert np.isclose(x.std().detach().numpy(), std, rtol=1e-2)
+
+ y = nn.Parameter(torch.empty(d_mlp, d_model))
+ utils.init_xavier_normal_(y)
+ std = np.sqrt(2 / (d_mlp + d_model))
+ assert np.isclose(y.std().detach().numpy(), std, rtol=1e-2)
+
+ z = nn.Parameter(torch.empty(d_model * 123)) # need to make this larger so std is accurate
+ utils.init_xavier_normal_(z)
+ std = np.sqrt(2 / (1 + d_model * 123))
+ assert np.isclose(z.std().detach().numpy(), std, rtol=1e-2)
+
+ torch.manual_seed(1234)
+ x_new = nn.Parameter(torch.empty(2, d_model, 137))
+ utils.init_xavier_normal_(x_new)
+ assert torch.allclose(x_new, x, rtol=1e-2)
diff --git a/transformer_lens/ActivationCache.py b/transformer_lens/ActivationCache.py
index 7ee732005..caff121c3 100644
--- a/transformer_lens/ActivationCache.py
+++ b/transformer_lens/ActivationCache.py
@@ -10,11 +10,12 @@
class first, including the examples, and then skimming the available methods. You can then refer
back to these docs depending on what you need to do.
"""
+
from __future__ import annotations
import logging
import warnings
-from typing import Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
import einops
import numpy as np
@@ -113,9 +114,7 @@ class ActivationCache:
Whether the activations have a batch dimension.
"""
- def __init__(
- self, cache_dict: Dict[str, torch.Tensor], model, has_batch_dim: bool = True
- ):
+ def __init__(self, cache_dict: Dict[str, torch.Tensor], model, has_batch_dim: bool = True):
self.cache_dict = cache_dict
self.model = model
self.has_batch_dim = has_batch_dim
@@ -137,9 +136,7 @@ def remove_batch_dim(self) -> ActivationCache:
self.cache_dict[key] = self.cache_dict[key][0]
self.has_batch_dim = False
else:
- logging.warning(
- "Tried removing batch dimension after already having removed it."
- )
+ logging.warning("Tried removing batch dimension after already having removed it.")
return self
def __repr__(self) -> str:
@@ -206,9 +203,7 @@ def to(self, device: Union[str, torch.device], move_model=False) -> ActivationCa
DeprecationWarning,
)
- self.cache_dict = {
- key: value.to(device) for key, value in self.cache_dict.items()
- }
+ self.cache_dict = {key: value.to(device) for key, value in self.cache_dict.items()}
if move_model:
self.model.to(device)
@@ -276,7 +271,7 @@ def items(self):
"""
return self.cache_dict.items()
- def __iter__(self) -> Iterator[Tuple[str, torch.Tensor]]:
+ def __iter__(self) -> Iterator[str]:
"""ActivationCache Iterator.
Special method that returns an iterator over the ActivationCache. Allows looping over the
@@ -300,9 +295,7 @@ def __iter__(self) -> Iterator[Tuple[str, torch.Tensor]]:
"""
return self.cache_dict.__iter__()
- def apply_slice_to_batch_dim(
- self, batch_slice: Union[Slice, SliceInput]
- ) -> ActivationCache:
+ def apply_slice_to_batch_dim(self, batch_slice: Union[Slice, SliceInput]) -> ActivationCache:
"""Apply a Slice to the Batch Dimension.
Args:
@@ -314,31 +307,27 @@ def apply_slice_to_batch_dim(
"""
if not isinstance(batch_slice, Slice):
batch_slice = Slice(batch_slice)
+ batch_slice = cast(Slice, batch_slice) # mypy can't seem to infer this
assert (
self.has_batch_dim or batch_slice.mode == "empty"
), "Cannot index into a cache without a batch dim"
still_has_batch_dim = (batch_slice.mode != "int") and self.has_batch_dim
new_cache_dict = {
- name: batch_slice.apply(param, dim=0)
- for name, param in self.cache_dict.items()
+ name: batch_slice.apply(param, dim=0) for name, param in self.cache_dict.items()
}
- return ActivationCache(
- new_cache_dict, self.model, has_batch_dim=still_has_batch_dim
- )
+ return ActivationCache(new_cache_dict, self.model, has_batch_dim=still_has_batch_dim)
def accumulated_resid(
self,
layer: Optional[int] = None,
- incl_mid: Optional[bool] = False,
- apply_ln: Optional[bool] = False,
+ incl_mid: bool = False,
+ apply_ln: bool = False,
pos_slice: Optional[Union[Slice, SliceInput]] = None,
- mlp_input: Optional[bool] = False,
- return_labels: Optional[bool] = False,
+ mlp_input: bool = False,
+ return_labels: bool = False,
) -> Union[
Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"],
- Tuple[
- Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]
- ],
+ Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]],
]:
"""Accumulated Residual Stream.
@@ -438,19 +427,19 @@ def accumulated_resid(
layer = self.model.cfg.n_layers
assert isinstance(layer, int)
labels = []
- components = []
+ components_list = []
for l in range(layer + 1):
if l == self.model.cfg.n_layers:
- components.append(self[("resid_post", self.model.cfg.n_layers - 1)])
+ components_list.append(self[("resid_post", self.model.cfg.n_layers - 1)])
labels.append("final_post")
continue
- components.append(self[("resid_pre", l)])
+ components_list.append(self[("resid_pre", l)])
labels.append(f"{l}_pre")
if (incl_mid and l < layer) or (mlp_input and l == layer):
- components.append(self[("resid_mid", l)])
+ components_list.append(self[("resid_mid", l)])
labels.append(f"{l}_mid")
- components = [pos_slice.apply(c, dim=-2) for c in components]
- components = torch.stack(components, dim=0)
+ components_list = [pos_slice.apply(c, dim=-2) for c in components_list]
+ components = torch.stack(components_list, dim=0)
if apply_ln:
components = self.apply_ln_to_stack(
components, layer, pos_slice=pos_slice, mlp_input=mlp_input
@@ -462,9 +451,7 @@ def accumulated_resid(
def logit_attrs(
self,
- residual_stack: Float[
- torch.Tensor, "num_components *batch_and_pos_dims d_model"
- ],
+ residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
tokens: Union[
str,
int,
@@ -545,9 +532,7 @@ def logit_attrs(
if incorrect_tokens is not None:
if isinstance(incorrect_tokens, str):
- incorrect_tokens = torch.as_tensor(
- self.model.to_single_token(incorrect_tokens)
- )
+ incorrect_tokens = torch.as_tensor(self.model.to_single_token(incorrect_tokens))
elif isinstance(incorrect_tokens, int):
incorrect_tokens = torch.as_tensor(incorrect_tokens)
@@ -560,9 +545,8 @@ def logit_attrs(
)
# If incorrect_tokens was provided, take the logit difference
- logit_directions = (
- logit_directions
- - self.model.tokens_to_residual_directions(incorrect_tokens)
+ logit_directions = logit_directions - self.model.tokens_to_residual_directions(
+ incorrect_tokens
)
scaled_residual_stack = self.apply_ln_to_stack(
@@ -590,9 +574,7 @@ def decompose_resid(
return_labels: bool = False,
) -> Union[
Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"],
- Tuple[
- Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]
- ],
+ Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]],
]:
"""Decompose the Residual Stream.
@@ -607,9 +589,6 @@ def decompose_resid(
layer==n_layers means to return all layer outputs incl in the final layer, layer==0
means just embed and pos_embed. The indices are taken such that this gives the
accumulated streams up to the input to layer l
- incl_mid:
- Whether to return resid_mid for all previous
- layers.
mlp_input:
Whether to include attn_out for the current
layer - essentially decomposing the residual stream that's input to the MLP input
@@ -635,6 +614,7 @@ def decompose_resid(
"""
if not isinstance(pos_slice, Slice):
pos_slice = Slice(pos_slice)
+ pos_slice = cast(Slice, pos_slice) # mypy can't seem to infer this
if layer is None or layer == -1:
# Default to the residual stream immediately pre unembed
layer = self.model.cfg.n_layers
@@ -642,28 +622,28 @@ def decompose_resid(
incl_attn = mode != "mlp"
incl_mlp = mode != "attn" and not self.model.cfg.attn_only
- components = []
+ components_list = []
labels = []
if incl_embeds:
if self.has_embed:
- components = [self["hook_embed"]]
+ components_list = [self["hook_embed"]]
labels.append("embed")
if self.has_pos_embed:
- components.append(self["hook_pos_embed"])
+ components_list.append(self["hook_pos_embed"])
labels.append("pos_embed")
for l in range(layer):
if incl_attn:
- components.append(self[("attn_out", l)])
+ components_list.append(self[("attn_out", l)])
labels.append(f"{l}_attn_out")
if incl_mlp:
- components.append(self[("mlp_out", l)])
+ components_list.append(self[("mlp_out", l)])
labels.append(f"{l}_mlp_out")
if mlp_input and incl_attn:
- components.append(self[("attn_out", layer)])
+ components_list.append(self[("attn_out", layer)])
labels.append(f"{layer}_attn_out")
- components = [pos_slice.apply(c, dim=-2) for c in components]
- components = torch.stack(components, dim=0)
+ components_list = [pos_slice.apply(c, dim=-2) for c in components_list]
+ components = torch.stack(components_list, dim=0)
if apply_ln:
components = self.apply_ln_to_stack(
components, layer, pos_slice=pos_slice, mlp_input=mlp_input
@@ -684,9 +664,7 @@ def compute_head_results(
be useful if you forget.
"""
if "blocks.0.attn.hook_result" in self.cache_dict:
- logging.warning(
- "Tried to compute head results when they were already cached"
- )
+ logging.warning("Tried to compute head results when they were already cached")
return
for l in range(self.model.cfg.n_layers):
# Note that we haven't enabled set item on this object so we need to edit the underlying
@@ -727,6 +705,7 @@ def stack_head_results(
"""
if not isinstance(pos_slice, Slice):
pos_slice = Slice(pos_slice)
+ pos_slice = cast(Slice, pos_slice) # mypy can't seem to infer this
if layer is None or layer == -1:
# Default to the residual stream immediately pre unembed
layer = self.model.cfg.n_layers
@@ -737,7 +716,7 @@ def stack_head_results(
)
self.compute_head_results()
- components = []
+ components: Any = []
labels = []
for l in range(layer):
# Note that this has shape batch x pos x head_index x d_model
@@ -773,7 +752,7 @@ def stack_head_results(
components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice)
if return_labels:
- return components, labels
+ return components, labels # type: ignore # TODO: fix this properly
else:
return components
@@ -832,11 +811,9 @@ def get_neuron_results(
Returns:
Tensor of the results.
"""
- if type(neuron_slice) is not Slice:
- assert isinstance(neuron_slice, SliceInput)
+ if not isinstance(neuron_slice, Slice):
neuron_slice = Slice(neuron_slice)
- if type(pos_slice) is not Slice:
- assert isinstance(pos_slice, SliceInput)
+ if not isinstance(pos_slice, Slice):
pos_slice = Slice(pos_slice)
neuron_acts = self[("post", layer, "mlp")]
@@ -860,9 +837,7 @@ def stack_neuron_results(
apply_ln: bool = False,
) -> Union[
Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
- Tuple[
- Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]
- ],
+ Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]],
]:
"""Stack Neuron Results
@@ -894,7 +869,7 @@ def stack_neuron_results(
# Default to the residual stream immediately pre unembed
layer = self.model.cfg.n_layers
- components = []
+ components: Any = [] # TODO: fix typing properly
labels = []
if not isinstance(neuron_slice, Slice):
@@ -902,15 +877,15 @@ def stack_neuron_results(
if not isinstance(pos_slice, Slice):
pos_slice = Slice(pos_slice)
- neuron_labels = neuron_slice.apply(torch.arange(self.model.cfg.d_mlp), dim=0)
+ neuron_labels: torch.Tensor | np.ndarray = neuron_slice.apply(
+ torch.arange(self.model.cfg.d_mlp), dim=0
+ )
if type(neuron_labels) == int:
neuron_labels = np.array([neuron_labels])
for l in range(layer):
# Note that this has shape batch x pos x head_index x d_model
components.append(
- self.get_neuron_results(
- l, pos_slice=pos_slice, neuron_slice=neuron_slice
- )
+ self.get_neuron_results(l, pos_slice=pos_slice, neuron_slice=neuron_slice)
)
labels.extend([f"L{l}N{h}" for h in neuron_labels])
if components:
@@ -944,9 +919,7 @@ def stack_neuron_results(
def apply_ln_to_stack(
self,
- residual_stack: Float[
- torch.Tensor, "num_components *batch_and_pos_dims d_model"
- ],
+ residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
layer: Optional[int] = None,
mlp_input: bool = False,
pos_slice: Union[Slice, SliceInput] = None,
@@ -1059,6 +1032,7 @@ def get_full_resid_decomposition(
if layer is None or layer == -1:
# Default to the residual stream immediately pre unembed
layer = self.model.cfg.n_layers
+ assert layer is not None # keep mypy happy
if not isinstance(pos_slice, Slice):
pos_slice = Slice(pos_slice)
@@ -1096,9 +1070,7 @@ def get_full_resid_decomposition(
labels.append("pos_embed")
components.append(pos_slice.apply(self["pos_embed"], -2)[None])
# If we didn't expand the neurons, the MLP biases are already included in the MLP outputs.
- bias = self.model.accumulated_bias(
- layer, mlp_input, include_mlp_biases=expand_neurons
- )
+ bias = self.model.accumulated_bias(layer, mlp_input, include_mlp_biases=expand_neurons)
bias = bias.expand((1,) + head_stack.shape[1:])
labels.append("bias")
components.append(bias)
@@ -1109,6 +1081,6 @@ def get_full_resid_decomposition(
)
if return_labels:
- return residual_stack, labels
+ return residual_stack, labels # type: ignore # TODO: fix this properly
else:
return residual_stack
diff --git a/transformer_lens/FactoredMatrix.py b/transformer_lens/FactoredMatrix.py
index c097646c4..1e1c813a6 100644
--- a/transformer_lens/FactoredMatrix.py
+++ b/transformer_lens/FactoredMatrix.py
@@ -3,10 +3,11 @@
Utilities for representing a matrix as a product of two matrices, and for efficient calculation of
eigenvalues, norm and SVD.
"""
+
from __future__ import annotations
from functools import lru_cache
-from typing import List, Tuple, Union
+from typing import List, Tuple, Union, overload
import torch
from jaxtyping import Float
@@ -40,6 +41,23 @@ def __init__(
self.A = self.A.broadcast_to(self.shape[:-2] + (self.ldim, self.mdim))
self.B = self.B.broadcast_to(self.shape[:-2] + (self.mdim, self.rdim))
+ @overload
+ def __matmul__(
+ self,
+ other: Union[
+ Float[torch.Tensor, "... rdim new_rdim"],
+ "FactoredMatrix",
+ ],
+ ) -> "FactoredMatrix":
+ ...
+
+ @overload
+ def __matmul__( # type: ignore
+ self,
+ other: Float[torch.Tensor, "rdim"],
+ ) -> Float[torch.Tensor, "... ldim"]:
+ ...
+
def __matmul__(
self,
other: Union[
@@ -64,7 +82,24 @@ def __matmul__(
elif isinstance(other, FactoredMatrix):
return (self @ other.A) @ other.B
- def __rmatmul__(
+ @overload
+ def __rmatmul__( # type: ignore
+ self,
+ other: Union[
+ Float[torch.Tensor, "... new_rdim ldim"],
+ "FactoredMatrix",
+ ],
+ ) -> "FactoredMatrix":
+ ...
+
+ @overload
+ def __rmatmul__( # type: ignore
+ self,
+ other: Float[torch.Tensor, "ldim"],
+ ) -> Float[torch.Tensor, "... rdim"]:
+ ...
+
+ def __rmatmul__( # type: ignore
self,
other: Union[
Float[torch.Tensor, "... new_rdim ldim"],
@@ -96,7 +131,7 @@ def __mul__(self, scalar: Union[int, float, torch.Tensor]) -> FactoredMatrix:
), f"Tensor must be a scalar for use with * but was of shape {scalar.shape}. For matrix multiplication, use @ instead."
return FactoredMatrix(self.A * scalar, self.B)
- def __rmul__(self, scalar: Union[int, float, torch.Tensor]) -> FactoredMatrix:
+ def __rmul__(self, scalar: Union[int, float, torch.Tensor]) -> FactoredMatrix: # type: ignore
"""
Right scalar multiplication. For scalar multiplication from the right, we can reuse the __mul__ method.
"""
@@ -183,9 +218,7 @@ def __getitem__(self, idx: Union[int, Tuple]) -> FactoredMatrix:
elif length == len(self.shape):
idx = self._convert_to_slice(idx, -1)
idx = self._convert_to_slice(idx, -2)
- return FactoredMatrix(
- self.A[idx[:-1]], self.B[idx[:-2] + (slice(None), idx[-1])]
- )
+ return FactoredMatrix(self.A[idx[:-1]], self.B[idx[:-2] + (slice(None), idx[-1])])
else:
raise ValueError(
f"{idx} is too long an index for a FactoredMatrix with shape {self.shape}"
diff --git a/transformer_lens/HookedEncoder.py b/transformer_lens/HookedEncoder.py
index 99cdc1cdd..59ede19af 100644
--- a/transformer_lens/HookedEncoder.py
+++ b/transformer_lens/HookedEncoder.py
@@ -3,9 +3,11 @@
Contains a BERT style model. This is separate from :class:`transformer_lens.HookedTransformer`
because it has a significantly different architecture to e.g. GPT style transformers.
"""
+
from __future__ import annotations
import logging
+import os
from typing import Dict, List, Optional, Tuple, Union, cast, overload
import torch
@@ -16,9 +18,11 @@
from typing_extensions import Literal
import transformer_lens.loading_from_pretrained as loading
-from transformer_lens import ActivationCache, FactoredMatrix, HookedTransformerConfig
+from transformer_lens.ActivationCache import ActivationCache
from transformer_lens.components import BertBlock, BertEmbed, BertMLMHead, Unembed
+from transformer_lens.FactoredMatrix import FactoredMatrix
from transformer_lens.hook_points import HookedRootModule, HookPoint
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.utilities import devices
@@ -45,29 +49,27 @@ def __init__(self, cfg, tokenizer=None, move_to_device=True, **kwargs):
)
self.cfg = cfg
- assert (
- self.cfg.n_devices == 1
- ), "Multiple devices not supported for HookedEncoder"
+ assert self.cfg.n_devices == 1, "Multiple devices not supported for HookedEncoder"
if tokenizer is not None:
self.tokenizer = tokenizer
elif self.cfg.tokenizer_name is not None:
- self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.tokenizer_name)
+ huggingface_token = os.environ.get("HF_TOKEN", None)
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ self.cfg.tokenizer_name,
+ token=huggingface_token,
+ )
else:
self.tokenizer = None
if self.cfg.d_vocab == -1:
# If we have a tokenizer, vocab size can be inferred from it.
- assert (
- self.tokenizer is not None
- ), "Must provide a tokenizer if d_vocab is not provided"
+ assert self.tokenizer is not None, "Must provide a tokenizer if d_vocab is not provided"
self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1
if self.cfg.d_vocab_out == -1:
self.cfg.d_vocab_out = self.cfg.d_vocab
self.embed = BertEmbed(self.cfg)
- self.blocks = nn.ModuleList(
- [BertBlock(self.cfg) for _ in range(self.cfg.n_layers)]
- )
+ self.blocks = nn.ModuleList([BertBlock(self.cfg) for _ in range(self.cfg.n_layers)])
self.mlm_head = BertMLMHead(cfg)
self.unembed = Unembed(self.cfg)
@@ -130,9 +132,7 @@ def forward(
else None
)
additive_attention_mask = (
- torch.where(mask == 1, large_negative_number, 0)
- if mask is not None
- else None
+ torch.where(mask == 1, large_negative_number, 0) if mask is not None else None
)
for block in self.blocks:
@@ -140,7 +140,7 @@ def forward(
resid = self.mlm_head(resid)
if return_type is None:
- return
+ return None
logits = self.unembed(resid)
return logits
@@ -153,7 +153,7 @@ def run_with_cache(
@overload
def run_with_cache(
- self, *model_args, return_cache_object: Literal[False] = False, **kwargs
+ self, *model_args, return_cache_object: Literal[False], **kwargs
) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]:
...
@@ -174,14 +174,12 @@ def run_with_cache(
*model_args, remove_batch_dim=remove_batch_dim, **kwargs
)
if return_cache_object:
- cache = ActivationCache(
- cache_dict, self, has_batch_dim=not remove_batch_dim
- )
+ cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim)
return out, cache
else:
return out, cache_dict
- def to(
+ def to( # type: ignore
self,
device_or_dtype: Union[torch.device, str, torch.dtype],
print_details: bool = True,
@@ -270,6 +268,9 @@ def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]:
@property
def b_U(self) -> Float[torch.Tensor, "d_vocab"]:
+ """
+ Convenience to get the unembedding bias
+ """
return self.unembed.b_U
@property
@@ -296,98 +297,74 @@ def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]:
@property
def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
"""Stacks the key weights across all layers"""
- return torch.stack(
- [cast(BertBlock, block).attn.W_K for block in self.blocks], dim=0
- )
+ return torch.stack([cast(BertBlock, block).attn.W_K for block in self.blocks], dim=0)
@property
def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
"""Stacks the query weights across all layers"""
- return torch.stack(
- [cast(BertBlock, block).attn.W_Q for block in self.blocks], dim=0
- )
+ return torch.stack([cast(BertBlock, block).attn.W_Q for block in self.blocks], dim=0)
@property
def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
"""Stacks the value weights across all layers"""
- return torch.stack(
- [cast(BertBlock, block).attn.W_V for block in self.blocks], dim=0
- )
+ return torch.stack([cast(BertBlock, block).attn.W_V for block in self.blocks], dim=0)
@property
def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]:
"""Stacks the attn output weights across all layers"""
- return torch.stack(
- [cast(BertBlock, block).attn.W_O for block in self.blocks], dim=0
- )
+ return torch.stack([cast(BertBlock, block).attn.W_O for block in self.blocks], dim=0)
@property
def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]:
"""Stacks the MLP input weights across all layers"""
- return torch.stack(
- [cast(BertBlock, block).mlp.W_in for block in self.blocks], dim=0
- )
+ return torch.stack([cast(BertBlock, block).mlp.W_in for block in self.blocks], dim=0)
@property
def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]:
"""Stacks the MLP output weights across all layers"""
- return torch.stack(
- [cast(BertBlock, block).mlp.W_out for block in self.blocks], dim=0
- )
+ return torch.stack([cast(BertBlock, block).mlp.W_out for block in self.blocks], dim=0)
@property
def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
"""Stacks the key biases across all layers"""
- return torch.stack(
- [cast(BertBlock, block).attn.b_K for block in self.blocks], dim=0
- )
+ return torch.stack([cast(BertBlock, block).attn.b_K for block in self.blocks], dim=0)
@property
def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
"""Stacks the query biases across all layers"""
- return torch.stack(
- [cast(BertBlock, block).attn.b_Q for block in self.blocks], dim=0
- )
+ return torch.stack([cast(BertBlock, block).attn.b_Q for block in self.blocks], dim=0)
@property
def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
"""Stacks the value biases across all layers"""
- return torch.stack(
- [cast(BertBlock, block).attn.b_V for block in self.blocks], dim=0
- )
+ return torch.stack([cast(BertBlock, block).attn.b_V for block in self.blocks], dim=0)
@property
def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]:
"""Stacks the attn output biases across all layers"""
- return torch.stack(
- [cast(BertBlock, block).attn.b_O for block in self.blocks], dim=0
- )
+ return torch.stack([cast(BertBlock, block).attn.b_O for block in self.blocks], dim=0)
@property
def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]:
"""Stacks the MLP input biases across all layers"""
- return torch.stack(
- [cast(BertBlock, block).mlp.b_in for block in self.blocks], dim=0
- )
+ return torch.stack([cast(BertBlock, block).mlp.b_in for block in self.blocks], dim=0)
@property
def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]:
"""Stacks the MLP output biases across all layers"""
- return torch.stack(
- [cast(BertBlock, block).mlp.b_out for block in self.blocks], dim=0
- )
+ return torch.stack([cast(BertBlock, block).mlp.b_out for block in self.blocks], dim=0)
@property
def QK(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model]
+ """Returns a FactoredMatrix object with the product of the Q and K matrices for each layer and head.
+ Useful for visualizing attention patterns."""
return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1))
@property
def OV(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model]
+ """Returns a FactoredMatrix object with the product of the O and V matrices for each layer and head."""
return FactoredMatrix(self.W_V, self.W_O)
def all_head_labels(self) -> List[str]:
- return [
- f"L{l}H{h}"
- for l in range(self.cfg.n_layers)
- for h in range(self.cfg.n_heads)
- ]
+ """Returns a list of strings with the format "L{l}H{h}", where l is the layer index and h is the head index."""
+ return [f"L{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)]
diff --git a/transformer_lens/HookedSAE.py b/transformer_lens/HookedSAE.py
new file mode 100644
index 000000000..df9b29d05
--- /dev/null
+++ b/transformer_lens/HookedSAE.py
@@ -0,0 +1,118 @@
+from typing import Dict, Union
+
+import einops
+import torch
+import torch.nn.functional as F
+from jaxtyping import Float
+from torch import nn
+
+from transformer_lens.hook_points import ( # Hooking utilities
+ HookedRootModule,
+ HookPoint,
+)
+from transformer_lens.HookedSAEConfig import HookedSAEConfig
+
+
+class HookedSAE(HookedRootModule):
+ """Hooked SAE.
+
+ Implements a standard SAE with TransformerLens hooks for SAE activations
+
+ Designed for inference / analysis, not training. For training, see Joseph Bloom's SAELens (https://github.com/jbloomAus/SAELens)
+
+ Note that HookedSAETransformer is fairly modular, and doesn't make strong assumptions about the architecture of the SAEs that get attached. We provide HookedSAE as a useful default class, but if you want to eg experiment with other SAE architectures, you can just copy the HookedSAE code into a notebook, edit it, and add instances of the new SAE class to a HookedSAETransformer (e.g. with HookedSAETransformer.add_sae(sae))
+ """
+
+ def __init__(self, cfg: Union[HookedSAEConfig, Dict]):
+ super().__init__()
+ if isinstance(cfg, Dict):
+ cfg = HookedSAEConfig(**cfg)
+ elif isinstance(cfg, str):
+ raise ValueError("Please pass in a config dictionary or HookedSAEConfig object.")
+ self.cfg = cfg
+
+ self.W_enc = nn.Parameter(
+ torch.nn.init.kaiming_uniform_(
+ torch.empty(self.cfg.d_in, self.cfg.d_sae, dtype=self.cfg.dtype)
+ )
+ )
+ self.W_dec = nn.Parameter(
+ torch.nn.init.kaiming_uniform_(
+ torch.empty(self.cfg.d_sae, self.cfg.d_in, dtype=self.cfg.dtype)
+ )
+ )
+ self.b_enc = nn.Parameter(torch.zeros(self.cfg.d_sae, dtype=self.cfg.dtype))
+ self.b_dec = nn.Parameter(torch.zeros(self.cfg.d_in, dtype=self.cfg.dtype))
+
+ self.hook_sae_input = HookPoint()
+ self.hook_sae_acts_pre = HookPoint()
+ self.hook_sae_acts_post = HookPoint()
+ self.hook_sae_recons = HookPoint()
+ self.hook_sae_error = HookPoint()
+ self.hook_sae_output = HookPoint()
+
+ self.to(self.cfg.device)
+ self.setup()
+
+ def forward(self, input: Float[torch.Tensor, "... d_in"]) -> Float[torch.Tensor, "... d_in"]:
+ """SAE Forward Pass.
+
+ Args:
+ input: The input tensor of activations to the SAE. Shape [..., d_in].
+ Also supports hook_z activations of shape [..., n_heads, d_head], where n_heads * d_head = d_in, for attention output (hook_z) SAEs.
+
+ Returns:
+ output: The reconstructed output tensor from the SAE, with the error term optionally added. Same shape as input (eg [..., d_in])
+ """
+ self.hook_sae_input(input)
+ if input.shape[-1] == self.cfg.d_in:
+ x = input
+ else:
+ # Assume this is an attention output (hook_z) SAE
+ assert self.cfg.hook_name.endswith(
+ "_z"
+ ), f"You passed in an input shape {input.shape} does not match SAE input size {self.cfg.d_in} for hook_name {self.cfg.hook_name}. This is only supported for attn output (hook_z) SAEs."
+ x = einops.rearrange(input, "... n_heads d_head -> ... (n_heads d_head)")
+ assert (
+ x.shape[-1] == self.cfg.d_in
+ ), f"Input shape {x.shape} does not match SAE input size {self.cfg.d_in}"
+
+ x_cent = x - self.b_dec
+ # WARNING: if editing this block of code, also edit the error computation inside `if self.cfg.use_error_term`
+ sae_acts_pre = self.hook_sae_acts_pre(
+ einops.einsum(x_cent, self.W_enc, "... d_in, d_in d_sae -> ... d_sae")
+ + self.b_enc # [..., d_sae]
+ )
+ sae_acts_post = self.hook_sae_acts_post(F.relu(sae_acts_pre)) # [..., d_sae]
+ x_reconstruct = self.hook_sae_recons(
+ (
+ einops.einsum(sae_acts_post, self.W_dec, "... d_sae, d_sae d_in -> ... d_in")
+ + self.b_dec
+ ).reshape(input.shape)
+ )
+ # END WARNING
+
+ if self.cfg.use_error_term:
+ with torch.no_grad():
+ # Recompute everything without hooks to get true error term
+ # Otherwise, the output with error term will always equal input, even for causal interventions that affect x_reconstruct
+ # This is in a no_grad context to detach the error, so we can compute SAE feature gradients (eg for attribution patching). See A.3 in https://arxiv.org/pdf/2403.19647.pdf for more detail
+ # NOTE: we can't just use `sae_error = input - x_reconstruct.detach()` or something simpler, since then interventions on features (eg ablating a feature) would still yield a perfect reconstruction of the input.
+ sae_acts_pre_clean = (
+ einops.einsum(x_cent, self.W_enc, "... d_in, d_in d_sae -> ... d_sae")
+ + self.b_enc
+ ) # [..., d_sae]
+ sae_acts_post_clean = F.relu(sae_acts_pre_clean)
+ x_reconstruct_clean = (
+ einops.einsum(
+ sae_acts_post_clean,
+ self.W_dec,
+ "... d_sae, d_sae d_in -> ... d_in",
+ )
+ + self.b_dec
+ ).reshape(input.shape)
+
+ sae_error = self.hook_sae_error(input - x_reconstruct_clean)
+ return self.hook_sae_output(x_reconstruct + sae_error)
+
+ return self.hook_sae_output(x_reconstruct)
diff --git a/transformer_lens/HookedSAEConfig.py b/transformer_lens/HookedSAEConfig.py
new file mode 100644
index 000000000..2892329e4
--- /dev/null
+++ b/transformer_lens/HookedSAEConfig.py
@@ -0,0 +1,64 @@
+from __future__ import annotations
+
+import pprint
+import random
+from dataclasses import dataclass
+from typing import Any, Dict, Optional
+
+import numpy as np
+import torch
+
+from transformer_lens import utils
+
+
+@dataclass
+class HookedSAEConfig:
+ """
+ Configuration class to store the configuration of a HookedSAE model.
+
+ Args:
+ d_sae (int): The size of the dictionary.
+ d_in (int): The dimension of the input activations.
+ hook_name (str): The hook name of the activation the SAE was trained on (eg. blocks.0.attn.hook_z)
+ use_error_term (bool): Whether to add the SAE error term to the output in the forward pass (so the output matches the input except under causal interventions). Defaults to False.
+ dtype (torch.dtype, *optional*): The SAE's dtype. Defaults to torch.float32.
+ seed (int, *optional*): The seed to use for the SAE.
+ Used to set sources of randomness (Python, PyTorch and
+ NumPy) and to initialize weights. Defaults to None. We recommend setting a seed, so your experiments are reproducible.
+ device(str): The device to use for the SAE. Defaults to 'cuda' if
+ available, else 'cpu'.
+ """
+
+ d_sae: int
+ d_in: int
+ hook_name: str
+ use_error_term: bool = False
+ dtype: torch.dtype = torch.float32
+ seed: Optional[int] = None
+ device: Optional[str] = None
+
+ def __post_init__(self):
+ if self.seed is not None:
+ self.set_seed_everywhere(self.seed)
+
+ if self.device is None:
+ self.device = utils.get_device()
+
+ @classmethod
+ def from_dict(cls, config_dict: Dict[str, Any]) -> HookedSAEConfig:
+ """
+ Instantiates a `HookedSAEConfig` from a Python dictionary of
+ parameters.
+ """
+ return cls(**config_dict)
+
+ def to_dict(self):
+ return self.__dict__
+
+ def __repr__(self):
+ return "HookedSAEConfig:\n" + pprint.pformat(self.to_dict())
+
+ def set_seed_everywhere(self, seed: int):
+ torch.manual_seed(seed)
+ random.seed(seed)
+ np.random.seed(seed)
diff --git a/transformer_lens/HookedSAETransformer.py b/transformer_lens/HookedSAETransformer.py
new file mode 100644
index 000000000..47e88ebb9
--- /dev/null
+++ b/transformer_lens/HookedSAETransformer.py
@@ -0,0 +1,290 @@
+import logging
+from contextlib import contextmanager
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from jaxtyping import Float
+
+from transformer_lens.ActivationCache import ActivationCache
+from transformer_lens.hook_points import HookPoint # Hooking utilities
+from transformer_lens.HookedSAE import HookedSAE
+from transformer_lens.HookedTransformer import HookedTransformer
+
+SingleLoss = Float[torch.Tensor, ""] # Type alias for a single element tensor
+LossPerToken = Float[torch.Tensor, "batch pos-1"]
+Loss = Union[SingleLoss, LossPerToken]
+
+
+def get_deep_attr(obj: Any, path: str):
+ """Helper function to get a nested attribute from an object.
+ In practice used to access HookedTransformer HookPoints (eg model.blocks[0].attn.hook_z)
+
+ Args:
+ obj: Any object. In practice, this is a HookedTransformer (or subclass)
+ path: str. The path to the attribute you want to access. (eg "blocks.0.attn.hook_z")
+
+ returns:
+ Any. The attribute at the end of the path
+ """
+ parts = path.split(".")
+ # Navigate to the last component in the path
+ for part in parts:
+ if part.isdigit(): # This is a list index
+ obj = obj[int(part)]
+ else: # This is an attribute
+ obj = getattr(obj, part)
+ return obj
+
+
+def set_deep_attr(obj: Any, path: str, value: Any):
+ """Helper function to change the value of a nested attribute of an object.
+ In practice used to swap HookedTransformer HookPoints (eg model.blocks[0].attn.hook_z) with HookedSAEs and vice versa
+
+ Args:
+ obj: Any object. In practice, this is a HookedTransformer (or subclass)
+ path: str. The path to the attribute you want to access. (eg "blocks.0.attn.hook_z")
+ value: Any. The value you want to set the attribute to (eg a HookedSAE object)
+ """
+ parts = path.split(".")
+ # Navigate to the last component in the path
+ for part in parts[:-1]:
+ if part.isdigit(): # This is a list index
+ obj = obj[int(part)]
+ else: # This is an attribute
+ obj = getattr(obj, part)
+ # Set the value on the final attribute
+ setattr(obj, parts[-1], value)
+
+
+class HookedSAETransformer(HookedTransformer):
+ def __init__(
+ self,
+ *model_args,
+ **model_kwargs,
+ ):
+ """Model initialization. Just HookedTransformer init, but adds a dictionary to keep track of attached SAEs.
+
+ Note that if you want to load the model from pretrained weights, you should use
+ :meth:`from_pretrained` instead.
+
+ Args:
+ *model_args: Positional arguments for HookedTransformer initialization
+ **model_kwargs: Keyword arguments for HookedTransformer initialization
+ """
+ super().__init__(*model_args, **model_kwargs)
+ self.acts_to_saes: Dict[str, HookedSAE] = {}
+
+ def add_sae(self, sae: HookedSAE):
+ """Attaches an SAE to the model
+
+ WARNING: This sae will be permanently attached until you remove it with reset_saes. This function will also overwrite any existing SAE attached to the same hook point.
+
+ Args:
+ sae: HookedSAE. The SAE to attach to the model
+ """
+ act_name = sae.cfg.hook_name
+ if (act_name not in self.acts_to_saes) and (act_name not in self.hook_dict):
+ logging.warning(
+ f"No hook found for {act_name}. Skipping. Check model.hook_dict for available hooks."
+ )
+ return
+
+ self.acts_to_saes[act_name] = sae
+ set_deep_attr(self, act_name, sae)
+ self.setup()
+
+ def _reset_sae(self, act_name: str, prev_sae: Optional[HookedSAE] = None):
+ """Resets an SAE that was attached to the model
+
+ By default will remove the SAE from that hook_point.
+ If prev_sae is provided, will replace the current SAE with the provided one.
+ This is mainly used to restore previously attached SAEs after temporarily running with different SAEs (eg with run_with_saes)
+
+ Args:
+ act_name: str. The hook_name of the SAE to reset
+ prev_sae: Optional[HookedSAE]. The SAE to replace the current one with. If None, will just remove the SAE from this hook point. Defaults to None
+ """
+ if act_name not in self.acts_to_saes:
+ logging.warning(f"No SAE is attached to {act_name}. There's nothing to reset.")
+ return
+
+ if prev_sae:
+ set_deep_attr(self, act_name, prev_sae)
+ self.acts_to_saes[act_name] = prev_sae
+ else:
+ set_deep_attr(self, act_name, HookPoint())
+ del self.acts_to_saes[act_name]
+
+ def reset_saes(
+ self,
+ act_names: Optional[Union[str, List[str]]] = None,
+ prev_saes: Optional[List[Union[HookedSAE, None]]] = None,
+ ):
+ """Reset the SAEs attached to the model
+
+ If act_names are provided will just reset SAEs attached to those hooks. Otherwise will reset all SAEs attached to the model.
+ Optionally can provide a list of prev_saes to reset to. This is mainly used to restore previously attached SAEs after temporarily running with different SAEs (eg with run_with_saes).
+
+ Args:
+ act_names (Optional[Union[str, List[str]]): The act_names of the SAEs to reset. If None, will reset all SAEs attached to the model. Defaults to None.
+ prev_saes (Optional[List[Union[HookedSAE, None]]]): List of SAEs to replace the current ones with. If None, will just remove the SAEs. Defaults to None.
+ """
+ if isinstance(act_names, str):
+ act_names = [act_names]
+ elif act_names is None:
+ act_names = list(self.acts_to_saes.keys())
+
+ if prev_saes:
+ assert len(act_names) == len(
+ prev_saes
+ ), "act_names and prev_saes must have the same length"
+ else:
+ prev_saes = [None] * len(act_names)
+
+ for act_name, prev_sae in zip(act_names, prev_saes):
+ self._reset_sae(act_name, prev_sae)
+
+ self.setup()
+
+ def run_with_saes(
+ self,
+ *model_args,
+ saes: Union[HookedSAE, List[HookedSAE]] = [],
+ reset_saes_end: bool = True,
+ **model_kwargs,
+ ) -> Union[
+ None,
+ Float[torch.Tensor, "batch pos d_vocab"],
+ Loss,
+ Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
+ ]:
+ """Wrapper around HookedTransformer forward pass.
+
+ Runs the model with the given SAEs attached for one forward pass, then removes them. By default, will reset all SAEs to original state after.
+
+ Args:
+ *model_args: Positional arguments for the model forward pass
+ saes: (Union[HookedSAE, List[HookedSAE]]) The SAEs to be attached for this forward pass
+ reset_saes_end (bool): If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. Default is True.
+ **model_kwargs: Keyword arguments for the model forward pass
+ """
+ with self.saes(saes=saes, reset_saes_end=reset_saes_end):
+ return self(*model_args, **model_kwargs)
+
+ def run_with_cache_with_saes(
+ self,
+ *model_args,
+ saes: Union[HookedSAE, List[HookedSAE]] = [],
+ reset_saes_end: bool = True,
+ return_cache_object=True,
+ remove_batch_dim=False,
+ **kwargs,
+ ) -> Tuple[
+ Union[
+ None,
+ Float[torch.Tensor, "batch pos d_vocab"],
+ Loss,
+ Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
+ ],
+ Union[ActivationCache, Dict[str, torch.Tensor]],
+ ]:
+ """Wrapper around 'run_with_cache' in HookedTransformer.
+
+ Attaches given SAEs before running the model with cache and then removes them.
+ By default, will reset all SAEs to original state after.
+
+ Args:
+ *model_args: Positional arguments for the model forward pass
+ saes: (Union[HookedSAE, List[HookedSAE]]) The SAEs to be attached for this forward pass
+ reset_saes_end: (bool) If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. Default is True.
+ return_cache_object: (bool) if True, this will return an ActivationCache object, with a bunch of
+ useful HookedTransformer specific methods, otherwise it will return a dictionary of
+ activations as in HookedRootModule.
+ remove_batch_dim: (bool) Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
+ **kwargs: Keyword arguments for the model forward pass
+ """
+ with self.saes(saes=saes, reset_saes_end=reset_saes_end):
+ return self.run_with_cache(
+ *model_args,
+ return_cache_object=return_cache_object,
+ remove_batch_dim=remove_batch_dim,
+ **kwargs,
+ )
+
+ def run_with_hooks_with_saes(
+ self,
+ *model_args,
+ saes: Union[HookedSAE, List[HookedSAE]] = [],
+ reset_saes_end: bool = True,
+ fwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [],
+ bwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [],
+ reset_hooks_end=True,
+ clear_contexts=False,
+ **model_kwargs,
+ ):
+ """Wrapper around 'run_with_hooks' in HookedTransformer.
+
+ Attaches the given SAEs to the model before running the model with hooks and then removes them.
+ By default, will reset all SAEs to original state after.
+
+ Args:
+ *model_args: Positional arguments for the model forward pass
+ saes: (Union[HookedSAE, List[HookedSAE]]) The SAEs to be attached for this forward pass
+ reset_saes_end: (bool) If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. (default: True)
+ fwd_hooks: (List[Tuple[Union[str, Callable], Callable]]) List of forward hooks to apply
+ bwd_hooks: (List[Tuple[Union[str, Callable], Callable]]) List of backward hooks to apply
+ reset_hooks_end: (bool) Whether to reset the hooks at the end of the forward pass (default: True)
+ clear_contexts: (bool) Whether to clear the contexts at the end of the forward pass (default: False)
+ **model_kwargs: Keyword arguments for the model forward pass
+ """
+ with self.saes(saes=saes, reset_saes_end=reset_saes_end):
+ return self.run_with_hooks(
+ *model_args,
+ fwd_hooks=fwd_hooks,
+ bwd_hooks=bwd_hooks,
+ reset_hooks_end=reset_hooks_end,
+ clear_contexts=clear_contexts,
+ **model_kwargs,
+ )
+
+ @contextmanager
+ def saes(
+ self,
+ saes: Union[HookedSAE, List[HookedSAE]] = [],
+ reset_saes_end: bool = True,
+ ):
+ """
+ A context manager for adding temporary SAEs to the model.
+ See HookedTransformer.hooks for a similar context manager for hooks.
+ By default will keep track of previously attached SAEs, and restore them when the context manager exits.
+
+ Example:
+
+ .. code-block:: python
+
+ from transformer_lens import HookedSAETransformer, HookedSAE, HookedSAEConfig
+
+ model = HookedSAETransformer.from_pretrained('gpt2-small')
+ sae_cfg = HookedSAEConfig(...)
+ sae = HookedSAE(sae_cfg)
+ with model.saes(saes=[sae]):
+ spliced_logits = model(text)
+
+
+ Args:
+ saes (Union[HookedSAE, List[HookedSAE]]): SAEs to be attached.
+ reset_saes_end (bool): If True, removes all SAEs added by this context manager when the context manager exits, returning previously attached SAEs to their original state.
+ """
+ act_names_to_reset = []
+ prev_saes = []
+ if isinstance(saes, HookedSAE):
+ saes = [saes]
+ try:
+ for sae in saes:
+ act_names_to_reset.append(sae.cfg.hook_name)
+ prev_saes.append(self.acts_to_saes.get(sae.cfg.hook_name, None))
+ self.add_sae(sae)
+ yield self
+ finally:
+ if reset_saes_end:
+ self.reset_saes(act_names_to_reset, prev_saes)
diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py
index de8a6f27e..baf32ad05 100644
--- a/transformer_lens/HookedTransformer.py
+++ b/transformer_lens/HookedTransformer.py
@@ -8,8 +8,10 @@
alteration of activations in individual components like attention heads and MLP layers, facilitating
a deeper understanding of the internal workings of transformers like GPT-2.
"""
+
import logging
-from typing import Dict, List, NamedTuple, Optional, Tuple, Union, overload
+import os
+from typing import Dict, List, NamedTuple, Optional, Tuple, Union, cast, overload
import einops
import numpy as np
@@ -18,12 +20,12 @@
import tqdm.auto as tqdm
from fancy_einsum import einsum
from jaxtyping import Float, Int
+from packaging import version
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
from typing_extensions import Literal
import transformer_lens.loading_from_pretrained as loading
import transformer_lens.utils as utils
-from transformer_lens import HookedTransformerConfig
from transformer_lens.ActivationCache import ActivationCache
from transformer_lens.components import (
Embed,
@@ -37,12 +39,20 @@
)
from transformer_lens.FactoredMatrix import FactoredMatrix
from transformer_lens.hook_points import HookedRootModule, HookPoint
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+from transformer_lens.loading_from_pretrained import NON_HF_HOSTED_MODEL_NAMES
# Note - activation cache is used with run_with_cache, past_key_value_caching is used for
# generation.
from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCache
from transformer_lens.utilities import devices
-from transformer_lens.utils import USE_DEFAULT_VALUE
+from transformer_lens.utils import (
+ USE_DEFAULT_VALUE,
+ init_kaiming_normal_,
+ init_kaiming_uniform_,
+ init_xavier_normal_,
+ init_xavier_uniform_,
+)
SingleLoss = Float[torch.Tensor, ""] # Type alias for a single element tensor
LossPerToken = Float[torch.Tensor, "batch pos-1"]
@@ -82,6 +92,8 @@ class HookedTransformer(HookedRootModule):
investigating. This can be done with :func:`transformer_lens.utils.test_prompt`.
"""
+ ln_final: nn.Module
+
def __init__(
self,
cfg: Union[HookedTransformerConfig, Dict],
@@ -112,28 +124,39 @@ def __init__(
"Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a "
"pretrained model, use HookedTransformer.from_pretrained() instead."
)
- self.cfg = cfg
+ self.cfg: HookedTransformerConfig = cfg
if tokenizer is not None:
self.set_tokenizer(tokenizer, default_padding_side=default_padding_side)
elif self.cfg.tokenizer_name is not None:
# If we have a tokenizer name, we can load it from HuggingFace
- if "llama" in self.cfg.tokenizer_name.lower():
- # llama tokenizer requires special handling
- logging.warning("LLaMA tokenizer not loaded. Please load manually.")
+ if self.cfg.tokenizer_name in NON_HF_HOSTED_MODEL_NAMES:
+ logging.warning(
+ "%s tokenizer not loaded. Please load manually.",
+ self.cfg.tokenizer_name,
+ )
else:
+ # Hugging Face defaults to use_fast to True
+ use_fast = True
+ # Phi model's fast tokenizer does not support adding a BOS token, use_fast
+ # should be False
+ if "phi" in self.cfg.tokenizer_name.lower():
+ use_fast = False
+ huggingface_token = os.environ.get("HF_TOKEN", None)
self.set_tokenizer(
AutoTokenizer.from_pretrained(
- self.cfg.tokenizer_name, add_bos_token=True
+ self.cfg.tokenizer_name,
+ add_bos_token=True,
+ trust_remote_code=self.cfg.trust_remote_code,
+ use_fast=use_fast,
+ token=huggingface_token,
),
default_padding_side=default_padding_side,
)
else:
# If no tokenizer name is provided, we assume we're training on an algorithmic task and
# will pass in tokens directly. In this case, we don't need a tokenizer.
- assert (
- self.cfg.d_vocab != -1
- ), "Must provide a tokenizer if d_vocab is not provided"
+ assert self.cfg.d_vocab != -1, "Must provide a tokenizer if d_vocab is not provided"
self.tokenizer = None
if default_padding_side != "right":
logging.warning(
@@ -151,10 +174,7 @@ def __init__(
self.hook_tokens = HookPoint() # [batch, pos]
self.blocks = nn.ModuleList(
- [
- TransformerBlock(self.cfg, block_index)
- for block_index in range(self.cfg.n_layers)
- ]
+ [TransformerBlock(self.cfg, block_index) for block_index in range(self.cfg.n_layers)]
)
if self.cfg.normalization_type == "RMS":
@@ -176,9 +196,7 @@ def __init__(
# If it's None, don't create either layer
pass
else:
- logging.warning(
- f"Invalid normalization_type passed in {self.cfg.normalization_type}"
- )
+ logging.warning("Invalid normalization_type passed in %s", self.cfg.normalization_type)
self.unembed = Unembed(self.cfg)
if self.cfg.init_weights:
@@ -229,9 +247,7 @@ def input_to_embed(
self,
input: Union[str, List[str], Int[torch.Tensor, "batch pos"]],
prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
- padding_side: Optional[
- Union[Literal["left", "right"], None]
- ] = USE_DEFAULT_VALUE,
+ padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
) -> Tuple[
Float[torch.Tensor, "batch pos d_model"], # residual
@@ -253,15 +269,13 @@ def input_to_embed(
past_kv_cache (HookedTransformerKeyValueCache, optional): If passed, we're doing caching
and attention_mask will be stored in the cache.
"""
- if type(input) == str or type(input) == list:
+ if isinstance(input, str) or isinstance(input, list):
# If text, convert to tokens (batch_size=1)
assert (
self.tokenizer is not None
), "Must provide a tokenizer if passing a string to the model"
# This is only intended to support passing in a single string
- tokens = self.to_tokens(
- input, prepend_bos=prepend_bos, padding_side=padding_side
- )
+ tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
else:
tokens = input
if len(tokens.shape) == 1:
@@ -270,18 +284,14 @@ def input_to_embed(
if tokens.device.type != self.cfg.device:
tokens = tokens.to(devices.get_device_for_block_index(0, self.cfg))
- if (
- self.tokenizer and self.tokenizer.padding_side == "left"
- ) or past_kv_cache is not None:
- # If the padding side is left or we are using caching, we need to compute the attention mask
- # for the adjustment of absolute positional embeddings and attention masking so that pad
- # tokens are not attended.
+ if (self.tokenizer and self.tokenizer.padding_side == "left") or past_kv_cache is not None:
+ # If the padding side is left or we are using caching, we need to compute the attention
+ # mask for the adjustment of absolute positional embeddings and attention masking so
+ # that pad tokens are not attended.
if prepend_bos is USE_DEFAULT_VALUE:
prepend_bos = self.cfg.default_prepend_bos
- attention_mask = utils.get_attention_mask(
- self.tokenizer, tokens, prepend_bos
- )
+ attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos)
if past_kv_cache is not None:
# past_kv_cache is not None, so we're doing caching.
@@ -309,7 +319,10 @@ def input_to_embed(
d_head_in_cache,
) = past_kv_cache[0].past_keys.shape
assert cached_batch_size == batch_size
- assert num_heads_in_cache == self.cfg.n_heads
+ if self.cfg.n_key_value_heads is None:
+ assert num_heads_in_cache == self.cfg.n_heads
+ else:
+ assert num_heads_in_cache == self.cfg.n_key_value_heads
assert d_head_in_cache == self.cfg.d_head
pos_offset = cache_ctx_length
if self.cfg.use_hook_tokens:
@@ -349,16 +362,12 @@ def forward(
self,
input,
return_type: Literal["logits"],
- loss_per_token: Optional[bool] = False,
+ loss_per_token: bool = False,
prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
- padding_side: Optional[
- Union[Literal["left", "right"], None]
- ] = USE_DEFAULT_VALUE,
+ padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
start_at_layer: Optional[int] = None,
tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
- shortformer_pos_embed: Optional[
- Float[torch.Tensor, "batch pos d_model"]
- ] = None,
+ shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
attention_mask: Optional[torch.Tensor] = None, # [batch pos]
stop_at_layer: Optional[int] = None,
past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
@@ -370,16 +379,12 @@ def forward(
self,
input,
return_type: Literal["loss"],
- loss_per_token: Optional[bool] = False,
+ loss_per_token: bool = False,
prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
- padding_side: Optional[
- Union[Literal["left", "right"], None]
- ] = USE_DEFAULT_VALUE,
+ padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
start_at_layer: Optional[int] = None,
tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
- shortformer_pos_embed: Optional[
- Float[torch.Tensor, "batch pos d_model"]
- ] = None,
+ shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
attention_mask: Optional[torch.Tensor] = None, # [batch pos]
stop_at_layer: Optional[int] = None,
past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
@@ -391,16 +396,12 @@ def forward(
self,
input,
return_type: Literal["both"],
- loss_per_token: Optional[bool] = False,
+ loss_per_token: bool = False,
prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
- padding_side: Optional[
- Union[Literal["left", "right"], None]
- ] = USE_DEFAULT_VALUE,
+ padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
start_at_layer: Optional[int] = None,
tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
- shortformer_pos_embed: Optional[
- Float[torch.Tensor, "batch pos d_model"]
- ] = None,
+ shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
attention_mask: Optional[torch.Tensor] = None, # [batch pos]
stop_at_layer: Optional[int] = None,
past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
@@ -412,16 +413,12 @@ def forward(
self,
input,
return_type: Literal[None],
- loss_per_token: Optional[bool] = False,
+ loss_per_token: bool = False,
prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
- padding_side: Optional[
- Union[Literal["left", "right"], None]
- ] = USE_DEFAULT_VALUE,
+ padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
start_at_layer: Optional[int] = None,
tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
- shortformer_pos_embed: Optional[
- Float[torch.Tensor, "batch pos d_model"]
- ] = None,
+ shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
attention_mask: Optional[torch.Tensor] = None, # [batch pos]
stop_at_layer: Optional[int] = None,
past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
@@ -437,14 +434,12 @@ def forward(
Float[torch.Tensor, "batch pos d_model"],
],
return_type: Optional[str] = "logits",
- loss_per_token: Optional[bool] = False,
+ loss_per_token: bool = False,
prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
start_at_layer: Optional[int] = None,
tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
- shortformer_pos_embed: Optional[
- Float[torch.Tensor, "batch pos d_model"]
- ] = None,
+ shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
attention_mask: Optional[torch.Tensor] = None, # [batch pos]
stop_at_layer: Optional[int] = None,
past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
@@ -511,7 +506,7 @@ def forward(
tokens that have already been through the model. Also caches attention_mask so
previous tokens are masked correctly (unless frozen). Padding should be ignored in
all cases, so it's okay to eg. pass in left padded tokens twice in a row.
- Warning: Don't accidently prepend_bos to the second half of a prompt.
+ Warning: Don't accidentally prepend_bos to the second half of a prompt.
Defaults to None (don't use caching).
"""
@@ -556,9 +551,7 @@ def forward(
residual,
# Cache contains a list of HookedTransformerKeyValueCache objects, one for each
# block
- past_kv_cache_entry=past_kv_cache[i]
- if past_kv_cache is not None
- else None,
+ past_kv_cache_entry=past_kv_cache[i] if past_kv_cache is not None else None,
shortformer_pos_embed=shortformer_pos_embed,
attention_mask=attention_mask,
) # [batch, pos, d_model]
@@ -610,7 +603,7 @@ def run_with_cache(
@overload
def run_with_cache(
- self, *model_args, return_cache_object: Literal[False] = False, **kwargs
+ self, *model_args, return_cache_object: Literal[False], **kwargs
) -> Tuple[Output, Dict[str, torch.Tensor]]:
...
@@ -635,9 +628,7 @@ def run_with_cache(
*model_args, remove_batch_dim=remove_batch_dim, **kwargs
)
if return_cache_object:
- cache = ActivationCache(
- cache_dict, self, has_batch_dim=not remove_batch_dim
- )
+ cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim)
return out, cache
else:
return out, cache_dict
@@ -670,6 +661,7 @@ def set_tokenizer(
# (https://github.com/huggingface/transformers/issues/25886).
tokenizer_with_bos = utils.get_tokenizer_with_bos(tokenizer)
self.tokenizer = tokenizer_with_bos
+ assert self.tokenizer is not None # keep mypy happy
self.tokenizer.padding_side = default_padding_side
# Some tokenizers doesn't automatically prepend the BOS token even when they are initialized
@@ -693,11 +685,9 @@ def to_tokens(
self,
input: Union[str, List[str]],
prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
- padding_side: Optional[
- Union[Literal["left", "right"], None]
- ] = USE_DEFAULT_VALUE,
- move_to_device: Optional[bool] = True,
- truncate: Optional[bool] = True,
+ padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
+ move_to_device: bool = True,
+ truncate: bool = True,
) -> Int[torch.Tensor, "batch pos"]:
"""Converts a string to a tensor of tokens.
@@ -731,18 +721,14 @@ def to_tokens(
with utils.LocallyOverridenDefaults(
self, prepend_bos=prepend_bos, padding_side=padding_side
):
- assert (
- self.tokenizer is not None
- ), "Cannot use to_tokens without a tokenizer"
+ assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer"
assert (
self.cfg.tokenizer_prepends_bos is not None
), "Set the tokenizer for the model by calling set_tokenizer"
if self.cfg.default_prepend_bos and not self.cfg.tokenizer_prepends_bos:
# We want to prepend bos but the tokenizer doesn't automatically do it, so we add it manually
- input = utils.get_input_with_manually_prepended_bos(
- self.tokenizer, input
- )
+ input = utils.get_input_with_manually_prepended_bos(self.tokenizer, input)
tokens = self.tokenizer(
input,
@@ -787,9 +773,7 @@ def to_string(
# it's set, then tokenization is no longer invertible, and some tokens
# with a bunch of whitespace get collapsed together
if len(tokens.shape) == 2:
- return self.tokenizer.batch_decode(
- tokens, clean_up_tokenization_spaces=False
- )
+ return self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False)
elif len(tokens.shape) <= 1:
return self.tokenizer.decode(tokens, clean_up_tokenization_spaces=False)
else:
@@ -806,9 +790,7 @@ def to_str_tokens(
list,
],
prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
- padding_side: Optional[
- Union[Literal["left", "right"], None]
- ] = USE_DEFAULT_VALUE,
+ padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
) -> Union[List[str], List[List[str]]]:
"""Map text, a list of text or tokens to a list of tokens as strings.
@@ -844,19 +826,22 @@ def to_str_tokens(
with utils.LocallyOverridenDefaults(
self, prepend_bos=prepend_bos, padding_side=padding_side
):
+ assert self.tokenizer is not None # keep mypy happy
+ tokens: Union[np.ndarray, torch.Tensor]
if isinstance(input, list):
return list(
map(
- lambda tokens: self.to_str_tokens(
- tokens, prepend_bos, padding_side
- ),
+ lambda tokens: self.to_str_tokens(tokens, prepend_bos, padding_side),
input,
)
) # type: ignore
elif isinstance(input, str):
- tokens = self.to_tokens(
- input, prepend_bos=prepend_bos, padding_side=padding_side
- )[0]
+ tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)[
+ 0
+ ]
+ # Gemma tokenizer expects a batch dimension
+ if "gemma" in self.tokenizer.name_or_path and tokens.ndim == 1:
+ tokens = tokens.unsqueeze(1)
elif isinstance(input, torch.Tensor):
tokens = input
tokens = tokens.squeeze() # Get rid of a trivial batch dimension
@@ -877,9 +862,7 @@ def to_str_tokens(
), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}"
else:
raise ValueError(f"Invalid input type to to_str_tokens: {type(input)}")
- str_tokens = self.tokenizer.batch_decode(
- tokens, clean_up_tokenization_spaces=False
- )
+ str_tokens = self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False)
return str_tokens
def to_single_token(self, string):
@@ -899,19 +882,15 @@ def to_single_str_token(self, int_token: int) -> str:
assert isinstance(int_token, int)
token = self.to_str_tokens(torch.tensor([int_token]))
assert len(token) == 1
- return token[0]
+ return cast(str, token[0])
def get_token_position(
self,
single_token: Union[str, int],
- input: Union[
- str, Union[Float[torch.Tensor, "pos"], Float[torch.Tensor, "1 pos"]]
- ],
+ input: Union[str, Union[Float[torch.Tensor, "pos"], Float[torch.Tensor, "1 pos"]]],
mode="first",
prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
- padding_side: Optional[
- Union[Literal["left", "right"], None]
- ] = USE_DEFAULT_VALUE,
+ padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
):
"""Get the position of a single_token in a string or sequence of tokens.
@@ -942,9 +921,7 @@ def get_token_position(
"""
if isinstance(input, str):
# If the input is a string, convert to tensor
- tokens = self.to_tokens(
- input, prepend_bos=prepend_bos, padding_side=padding_side
- )
+ tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
else:
tokens = input
@@ -961,10 +938,8 @@ def get_token_position(
elif isinstance(single_token, torch.Tensor):
single_token = single_token.item()
- indices = torch.arange(len(tokens), device=tokens.device)[
- tokens == single_token
- ]
- assert len(indices) > 0, f"The token does not occur in the prompt"
+ indices = torch.arange(len(tokens), device=tokens.device)[tokens == single_token]
+ assert len(indices) > 0, "The token does not occur in the prompt"
if mode == "first":
return indices[0].item()
elif mode == "last":
@@ -1029,10 +1004,10 @@ def tokens_to_residual_directions(
residual_direction = self.W_U[:, token]
return residual_direction
- def to(
+ def to( # type: ignore
self,
device_or_dtype: Union[torch.device, str, torch.dtype],
- print_details: Optional[bool] = True,
+ print_details: bool = True,
):
return devices.move_to_and_update_config(self, device_or_dtype, print_details)
@@ -1055,12 +1030,8 @@ def move_model_modules_to_device(self):
self.pos_embed.to(devices.get_device_for_block_index(0, self.cfg))
self.hook_pos_embed.to(devices.get_device_for_block_index(0, self.cfg))
if hasattr(self, "ln_final"):
- self.ln_final.to(
- devices.get_device_for_block_index(self.cfg.n_layers - 1, self.cfg)
- )
- self.unembed.to(
- devices.get_device_for_block_index(self.cfg.n_layers - 1, self.cfg)
- )
+ self.ln_final.to(devices.get_device_for_block_index(self.cfg.n_layers - 1, self.cfg))
+ self.unembed.to(devices.get_device_for_block_index(self.cfg.n_layers - 1, self.cfg))
for i, block in enumerate(self.blocks):
block.to(devices.get_device_for_block_index(i, self.cfg))
@@ -1068,20 +1039,20 @@ def move_model_modules_to_device(self):
def from_pretrained(
cls,
model_name: str,
- fold_ln: Optional[bool] = True,
- center_writing_weights: Optional[bool] = True,
- center_unembed: Optional[bool] = True,
- refactor_factored_attn_matrices: Optional[bool] = False,
+ fold_ln: bool = True,
+ center_writing_weights: bool = True,
+ center_unembed: bool = True,
+ refactor_factored_attn_matrices: bool = False,
checkpoint_index: Optional[int] = None,
checkpoint_value: Optional[int] = None,
hf_model: Optional[AutoModelForCausalLM] = None,
device: Optional[Union[str, torch.device]] = None,
- n_devices: Optional[int] = 1,
+ n_devices: int = 1,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
- move_to_device: Optional[bool] = True,
- fold_value_biases: Optional[bool] = True,
- default_prepend_bos: Optional[bool] = True,
- default_padding_side: Optional[Literal["left", "right"]] = "right",
+ move_to_device: bool = True,
+ fold_value_biases: bool = True,
+ default_prepend_bos: bool = True,
+ default_padding_side: Literal["left", "right"] = "right",
dtype="float32",
**from_pretrained_kwargs,
) -> "HookedTransformer":
@@ -1218,11 +1189,32 @@ def from_pretrained(
default_padding_side: Which side to pad on when tokenizing. Defaults to
"right".
"""
+
assert not (
from_pretrained_kwargs.get("load_in_8bit", False)
or from_pretrained_kwargs.get("load_in_4bit", False)
), "Quantization not supported"
+ if hf_model is not None:
+ hf_cfg = hf_model.config.to_dict()
+ qc = hf_cfg.get("quantization_config", {})
+ load_in_4bit = qc.get("load_in_4bit", False)
+ load_in_8bit = qc.get("load_in_8bit", False)
+ quant_method = qc.get("quant_method", "")
+ assert not load_in_8bit, "8-bit quantization is not supported"
+ assert not (
+ load_in_4bit and (version.parse(torch.__version__) < version.parse("2.1.1"))
+ ), "Quantization is only supported for torch versions >= 2.1.1"
+ assert not (
+ load_in_4bit and ("llama" not in model_name.lower())
+ ), "Quantization is only supported for Llama models"
+ if load_in_4bit:
+ assert (
+ qc.get("quant_method", "") == "bitsandbytes"
+ ), "Only bitsandbytes quantization is supported"
+ else:
+ hf_cfg = {}
+
if isinstance(dtype, str):
# Convert from string to a torch dtype
dtype = DTYPE_FROM_STRING[dtype]
@@ -1235,9 +1227,7 @@ def from_pretrained(
(from_pretrained_kwargs.get("torch_dtype", None) == torch.float16)
or dtype == torch.float16
) and device in ["cpu", None]:
- logging.warning(
- "float16 models may not work on CPU. Consider using a GPU or bfloat16."
- )
+ logging.warning("float16 models may not work on CPU. Consider using a GPU or bfloat16.")
# Get the model name used in HuggingFace, rather than the alias.
official_model_name = loading.get_official_model_name(model_name)
@@ -1247,6 +1237,7 @@ def from_pretrained(
# checkpoint
cfg = loading.get_pretrained_model_config(
official_model_name,
+ hf_cfg=hf_cfg,
checkpoint_index=checkpoint_index,
checkpoint_value=checkpoint_value,
fold_ln=fold_ln,
@@ -1342,10 +1333,6 @@ def from_pretrained_no_processing(
def init_weights(self):
"""Initialize weights.
- Initialize weights matrices with a normal of std=initializer_range (default=0.02). This
- roughly follows the GPT-2 paper's scheme (but with truncation, and not halving the std for
- W_pos).
-
LayerNorm weights are already initialized to 1.0, and all biases are initialized to 0.0
(including LayerNorm), so this just initializes weight matrices.
@@ -1358,28 +1345,124 @@ def init_weights(self):
This does NOT follow the PyTorch scheme, which as far as I can tell is super out of date but
no one has gotten round to updating it? https://github.com/pytorch/pytorch/issues/18182
+ The default PyTorch scheme is the following: all linear layers use uniform(-1/sqrt(fan_in),
+ 1/sqrt(fan_in)) for weights, and uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)) for biases. For
+ biases, fan_in is computed using the fan_in for the weight matrix of the linear layer. Note
+ that it *does not actually* use Kaiming initialization, despite the fact that it calls the
+ function.
+
+ However, for Transformer blocks, it instead initializes biases to zero and weights using Xavier uniform, that
+ is: uniform(-sqrt(6 / (fan_in + fan_out)), sqrt(6 / (fan_in + fan_out))) for weights.
+
PyTorch Transformers are especially bad - TransformerEncoder initializes all layers to the
- exact same weights?! https://github.com/pytorch/pytorch/issues/72253
+ exact same weights?! https://github.com/pytorch/pytorch/issues/72253.
The best paper I've found on transformer initialization is the muP paper, but haven't
integrated those ideas yet: https://arxiv.org/abs/2203.03466
+
+ We split off the initialization into separate functions because muP initialization handles
+ different parts of the model differently.
"""
if self.cfg.seed is not None:
torch.manual_seed(self.cfg.seed)
+ if self.cfg.init_mode == "gpt2":
+ self._init_weights_gpt2()
+ elif self.cfg.init_mode == "xavier_uniform":
+ self._init_weights_xavier(dist_type="uniform")
+ elif self.cfg.init_mode == "xavier_normal":
+ self._init_weights_xavier(dist_type="normal")
+ elif self.cfg.init_mode == "kaiming_uniform":
+ self._init_weights_kaiming(dist_type="uniform")
+ elif self.cfg.init_mode == "kaiming_normal":
+ self._init_weights_kaiming(dist_type="normal")
+ elif self.cfg.init_mode == "muP":
+ self._init_weights_muP(dist_type="normal") # muP uses normal initialization
+
+ def _init_weights_gpt2(self):
+ """Initialize weights with GPT-2 initialization. Biases are initialized to 0.0 and weights
+ are initialized to N(0, 0.64/d_model) if initializer_range is not set, otherwise std is initializer_range.
+ """
for name, param in self.named_parameters():
if "W_" in name:
nn.init.normal_(param, std=self.cfg.initializer_range)
+ def _init_weights_xavier(self, dist_type="normal"):
+ """
+ Initialize weights with Xavier initialization -- that is, scale the weights by sqrt(6 /
+ (fan_in + fan_out)) for a [-1, 1] uniform distribution, or sqrt(2 / (fan_in + fan_out)) for a
+ standard normal.
+
+ Note that since TransformerLens implements the matrices in the opposite orientation to what
+ torch does (e.g. it's d_in x d_out, not d_out x d_in as in torch), we need to calculate it
+ ourselves.
+ """
+ gain = self.cfg.initializer_range
+ for name, param in self.named_parameters():
+ if "W_" in name:
+ if dist_type == "uniform":
+ init_xavier_uniform_(param, gain=gain)
+ elif dist_type == "normal":
+ init_xavier_normal_(param, gain=gain)
+
+ def _init_weights_kaiming(self, dist_type="uniform"):
+ """
+ Initialize weights with Kaiming initialization -- that is, scale the weights by
+ c / sqrt(fan_in), where c = sqrt(2) if the params were immediately preceded by a relu and 1 for
+ everything else.
+
+ Note that the numbers are actually incorrect here when you're using a nonlinearity other
+ than relu, e.g. the correct c for SiLu is ~1.74, for tanh it's 5/3 ~= 1.67, and for GeLU it's ~1.57.
+ But this is unlikely to matter in practice.
+
+ I'm just using fan_mode = "fan_in" for now, but it should be trivial to add fan_out.
+
+ Again, we have to implement it ourselves because of the orientation of the matrices.
+ """
+ gain = self.cfg.initializer_range
+ for name, param in self.named_parameters():
+ if "W_" in name:
+ if dist_type == "uniform":
+ init_kaiming_uniform_(param, gain=gain, nonlinearity="relu", mode="fan_in")
+ elif dist_type == "normal":
+ init_kaiming_normal_(param, gain=gain, nonlinearity="relu", mode="fan_in")
+
+ def _init_weights_muP(self, dist_type="uniform"):
+ """
+ Initialize weights with muParameterization. This involves scaling output weights by a factor
+ of 1/fan_in, input weights and biases by 1, everything else by a factor of 1/sqrt(fan_in).
+
+ Also, you need to use muAdamW, which rescales the learning rate for output weights and
+ hidden weights by a factor of 1/fan_in.
+
+ All biases are still assumed to be initialized to 0.0, so we only need to change the
+ weights.
+ """
+ for name, param in self.named_parameters():
+ if "W_" in name:
+ fan_in, _ = utils.calc_fan_in_and_fan_out(param)
+ if "embed" in name:
+ scale = float(1)
+ elif "unembed" in name:
+ scale = 1 / fan_in
+ else:
+ scale = 1 / fan_in**0.5
+
+ if dist_type == "uniform":
+ scale *= 3**0.5
+ nn.init.uniform_(param, -scale, scale)
+ elif dist_type == "normal":
+ nn.init.normal_(param, std=scale)
+
def load_and_process_state_dict(
self,
state_dict: Dict[str, torch.Tensor],
- fold_ln: Optional[bool] = True,
- center_writing_weights: Optional[bool] = True,
- center_unembed: Optional[bool] = True,
- fold_value_biases: Optional[bool] = True,
- refactor_factored_attn_matrices: Optional[bool] = False,
+ fold_ln: bool = True,
+ center_writing_weights: bool = True,
+ center_unembed: bool = True,
+ fold_value_biases: bool = True,
+ refactor_factored_attn_matrices: bool = False,
):
"""Load & Process State Dict.
@@ -1414,16 +1497,32 @@ def load_and_process_state_dict(
"With reduced precision, it is advised to use `from_pretrained_no_processing` instead of `from_pretrained`."
)
+ if (
+ self.cfg.dtype not in [torch.float32, torch.float64]
+ and self.cfg.num_experts
+ and self.cfg.num_experts > 1
+ ):
+ logging.warning(
+ "When running MoE models, it is advised to use a higher precision data type. See docs for more info."
+ )
+
state_dict = self.fill_missing_keys(state_dict)
if fold_ln:
- if self.cfg.normalization_type not in ["LN", "LNPre"]:
+ if self.cfg.num_experts and self.cfg.num_experts > 1:
logging.warning(
- "You are not using LayerNorm, so the layer norm weights can't be folded! Skipping"
+ "You are using MoE, so the layer norm weights can't be folded! Skipping"
)
- else:
- # Note - you can run fold_layer_norm while normalization_type is LN, but this is not advised! It mostly
- # goes wrong when you're training the model.
+ elif self.cfg.normalization_type in ["LN", "LNPre"]:
state_dict = self.fold_layer_norm(state_dict)
+ elif self.cfg.normalization_type in ["RMS", "RMSPre"]:
+ state_dict = self.fold_layer_norm(
+ state_dict, fold_biases=False, center_weights=False
+ )
+ else:
+ logging.warning(
+ "You are not using LayerNorm or RMSNorm, so the layer norm weights can't be folded! Skipping"
+ )
+
if center_writing_weights:
if self.cfg.normalization_type not in ["LN", "LNPre"]:
logging.warning(
@@ -1435,19 +1534,28 @@ def load_and_process_state_dict(
)
else:
state_dict = self.center_writing_weights(state_dict)
+
if center_unembed:
state_dict = self.center_unembed(state_dict)
if fold_value_biases:
state_dict = self.fold_value_biases(state_dict)
if refactor_factored_attn_matrices:
state_dict = self.refactor_factored_attn_matrices(state_dict)
- self.load_state_dict(state_dict)
+
+ if self.cfg.load_in_4bit:
+ # with quantization, parameters should be assigned
+ # so that quantization settings are not lost
+ self.load_state_dict(state_dict, assign=True, strict=False)
+ else:
+ self.load_state_dict(state_dict, strict=False)
def fill_missing_keys(self, state_dict):
return loading.fill_missing_keys(self, state_dict)
- def fold_layer_norm(self, state_dict: Dict[str, torch.Tensor]):
- """Fold Layer Norm.
+ def fold_layer_norm(
+ self, state_dict: Dict[str, torch.Tensor], fold_biases=True, center_weights=True
+ ):
+ """Fold Layer Norm. Can also be used to fold RMS Norm, when fold_biases and center_weights are set to False.
Takes in a state dict from a pretrained model, formatted to be consistent with
HookedTransformer but with LayerNorm weights and biases. Folds these into the neighbouring
@@ -1455,132 +1563,155 @@ def fold_layer_norm(self, state_dict: Dict[str, torch.Tensor]):
Args:
state_dict (Dict[str, torch.Tensor]): State dict of pretrained model.
+ fold_biases (bool): Enables folding of LN biases. Should be disabled when RMS Norm is used.
+ center_weights (bool): Enables the centering of weights after folding in LN. Should be disabled when RMS Norm is used.
"""
+
+ # Models that use Grouped Query Attention (Only Mistral at the time of writing) prefix their K/V weights and
+ # biases with an underscore in order to distinguish them, but folding the LN into them still works the same,
+ # so we just add the underscore if GQA is used (i.e. if `cfg.n_key_value_heads` is specified).
+ gqa = "" if self.cfg.n_key_value_heads is None else "_"
+
for l in range(self.cfg.n_layers):
# Fold ln1 into attention - it's important to fold biases first, since biases depend on
# weights but not vice versa The various indexing is just to broadcast ln.b and ln.w
# along every axis other than d_model. Each weight matrix right multiplies. To fold in
# the bias, we use the W_ matrix to map it to the hidden space of the layer, so we need
# to sum along axis -2, which is the residual stream space axis.
- state_dict[f"blocks.{l}.attn.b_Q"] = state_dict[f"blocks.{l}.attn.b_Q"] + (
- state_dict[f"blocks.{l}.attn.W_Q"]
- * state_dict[f"blocks.{l}.ln1.b"][None, :, None]
- ).sum(-2)
- state_dict[f"blocks.{l}.attn.b_K"] = state_dict[f"blocks.{l}.attn.b_K"] + (
- state_dict[f"blocks.{l}.attn.W_K"]
- * state_dict[f"blocks.{l}.ln1.b"][None, :, None]
- ).sum(-2)
- state_dict[f"blocks.{l}.attn.b_V"] = state_dict[f"blocks.{l}.attn.b_V"] + (
- state_dict[f"blocks.{l}.attn.W_V"]
- * state_dict[f"blocks.{l}.ln1.b"][None, :, None]
- ).sum(-2)
+ if fold_biases:
+ state_dict[f"blocks.{l}.attn.b_Q"] = state_dict[f"blocks.{l}.attn.b_Q"] + (
+ state_dict[f"blocks.{l}.attn.W_Q"]
+ * state_dict[f"blocks.{l}.ln1.b"][None, :, None]
+ ).sum(-2)
+ state_dict[f"blocks.{l}.attn.{gqa}b_K"] = state_dict[
+ f"blocks.{l}.attn.{gqa}b_K"
+ ] + (
+ state_dict[f"blocks.{l}.attn.{gqa}W_K"]
+ * state_dict[f"blocks.{l}.ln1.b"][None, :, None]
+ ).sum(
+ -2
+ )
+ state_dict[f"blocks.{l}.attn.{gqa}b_V"] = state_dict[
+ f"blocks.{l}.attn.{gqa}b_V"
+ ] + (
+ state_dict[f"blocks.{l}.attn.{gqa}W_V"]
+ * state_dict[f"blocks.{l}.ln1.b"][None, :, None]
+ ).sum(
+ -2
+ )
+ del state_dict[f"blocks.{l}.ln1.b"]
state_dict[f"blocks.{l}.attn.W_Q"] = (
- state_dict[f"blocks.{l}.attn.W_Q"]
- * state_dict[f"blocks.{l}.ln1.w"][None, :, None]
+ state_dict[f"blocks.{l}.attn.W_Q"] * state_dict[f"blocks.{l}.ln1.w"][None, :, None]
)
- state_dict[f"blocks.{l}.attn.W_K"] = (
- state_dict[f"blocks.{l}.attn.W_K"]
+ state_dict[f"blocks.{l}.attn.{gqa}W_K"] = (
+ state_dict[f"blocks.{l}.attn.{gqa}W_K"]
* state_dict[f"blocks.{l}.ln1.w"][None, :, None]
)
- state_dict[f"blocks.{l}.attn.W_V"] = (
- state_dict[f"blocks.{l}.attn.W_V"]
+ state_dict[f"blocks.{l}.attn.{gqa}W_V"] = (
+ state_dict[f"blocks.{l}.attn.{gqa}W_V"]
* state_dict[f"blocks.{l}.ln1.w"][None, :, None]
)
+ del state_dict[f"blocks.{l}.ln1.w"]
# Finally, we center the weights reading from the residual stream. The output of the
# first part of the LayerNorm is mean 0 and standard deviation 1, so the mean of any
# input vector of the matrix doesn't matter and can be set to zero. Equivalently, the
# output of LayerNormPre is orthogonal to the vector of all 1s (because dotting with
# that gets the sum), so we can remove the component of the matrix parallel to this.
- state_dict[f"blocks.{l}.attn.W_Q"] -= einops.reduce(
- state_dict[f"blocks.{l}.attn.W_Q"],
- "head_index d_model d_head -> head_index 1 d_head",
- "mean",
- )
- state_dict[f"blocks.{l}.attn.W_K"] -= einops.reduce(
- state_dict[f"blocks.{l}.attn.W_K"],
- "head_index d_model d_head -> head_index 1 d_head",
- "mean",
- )
- state_dict[f"blocks.{l}.attn.W_V"] -= einops.reduce(
- state_dict[f"blocks.{l}.attn.W_V"],
- "head_index d_model d_head -> head_index 1 d_head",
- "mean",
- )
-
- del (
- state_dict[f"blocks.{l}.ln1.w"],
- state_dict[f"blocks.{l}.ln1.b"],
- )
+ if center_weights:
+ state_dict[f"blocks.{l}.attn.W_Q"] -= einops.reduce(
+ state_dict[f"blocks.{l}.attn.W_Q"],
+ "head_index d_model d_head -> head_index 1 d_head",
+ "mean",
+ )
+ state_dict[f"blocks.{l}.attn.{gqa}W_K"] -= einops.reduce(
+ state_dict[f"blocks.{l}.attn.{gqa}W_K"],
+ "head_index d_model d_head -> head_index 1 d_head",
+ "mean",
+ )
+ state_dict[f"blocks.{l}.attn.{gqa}W_V"] -= einops.reduce(
+ state_dict[f"blocks.{l}.attn.{gqa}W_V"],
+ "head_index d_model d_head -> head_index 1 d_head",
+ "mean",
+ )
# Fold ln2 into MLP
if not self.cfg.attn_only:
- state_dict[f"blocks.{l}.mlp.b_in"] = state_dict[
- f"blocks.{l}.mlp.b_in"
- ] + (
- state_dict[f"blocks.{l}.mlp.W_in"]
- * state_dict[f"blocks.{l}.ln2.b"][:, None]
- ).sum(
- -2
- )
+ if fold_biases:
+ state_dict[f"blocks.{l}.mlp.b_in"] = state_dict[f"blocks.{l}.mlp.b_in"] + (
+ state_dict[f"blocks.{l}.mlp.W_in"]
+ * state_dict[f"blocks.{l}.ln2.b"][:, None]
+ ).sum(-2)
+ del state_dict[f"blocks.{l}.ln2.b"]
+
state_dict[f"blocks.{l}.mlp.W_in"] = (
- state_dict[f"blocks.{l}.mlp.W_in"]
- * state_dict[f"blocks.{l}.ln2.w"][:, None]
+ state_dict[f"blocks.{l}.mlp.W_in"] * state_dict[f"blocks.{l}.ln2.w"][:, None]
)
- # Center the weights that read in from the LayerNormPre
- state_dict[f"blocks.{l}.mlp.W_in"] -= einops.reduce(
- state_dict[f"blocks.{l}.mlp.W_in"],
- "d_model d_mlp -> 1 d_mlp",
- "mean",
- )
+ if self.cfg.gated_mlp:
+ state_dict[f"blocks.{l}.mlp.W_gate"] = (
+ state_dict[f"blocks.{l}.mlp.W_gate"]
+ * state_dict[f"blocks.{l}.ln2.w"][:, None]
+ )
- del state_dict[f"blocks.{l}.ln2.w"], state_dict[f"blocks.{l}.ln2.b"]
+ del state_dict[f"blocks.{l}.ln2.w"]
- if self.cfg.act_fn.startswith("solu"):
- # Fold ln3 into activation
- state_dict[f"blocks.{l}.mlp.b_out"] = state_dict[
- f"blocks.{l}.mlp.b_out"
- ] + (
- state_dict[f"blocks.{l}.mlp.W_out"]
- * state_dict[f"blocks.{l}.mlp.ln.b"][:, None]
- ).sum(
- -2
+ if center_weights:
+ # Center the weights that read in from the LayerNormPre
+ state_dict[f"blocks.{l}.mlp.W_in"] -= einops.reduce(
+ state_dict[f"blocks.{l}.mlp.W_in"],
+ "d_model d_mlp -> 1 d_mlp",
+ "mean",
)
+
+ if self.cfg.act_fn is not None and self.cfg.act_fn.startswith("solu"):
+ # Fold ln3 into activation
+ if fold_biases:
+ state_dict[f"blocks.{l}.mlp.b_out"] = state_dict[
+ f"blocks.{l}.mlp.b_out"
+ ] + (
+ state_dict[f"blocks.{l}.mlp.W_out"]
+ * state_dict[f"blocks.{l}.mlp.ln.b"][:, None]
+ ).sum(
+ -2
+ )
+
+ del state_dict[f"blocks.{l}.mlp.ln.b"]
+
state_dict[f"blocks.{l}.mlp.W_out"] = (
state_dict[f"blocks.{l}.mlp.W_out"]
* state_dict[f"blocks.{l}.mlp.ln.w"][:, None]
)
- # Center the weights that read in from the LayerNormPre
- state_dict[f"blocks.{l}.mlp.W_out"] -= einops.reduce(
- state_dict[f"blocks.{l}.mlp.W_out"],
- "d_mlp d_model -> 1 d_model",
- "mean",
- )
- del (
- state_dict[f"blocks.{l}.mlp.ln.w"],
- state_dict[f"blocks.{l}.mlp.ln.b"],
- )
+ if center_weights:
+ # Center the weights that read in from the LayerNormPre
+ state_dict[f"blocks.{l}.mlp.W_out"] -= einops.reduce(
+ state_dict[f"blocks.{l}.mlp.W_out"],
+ "d_mlp d_model -> 1 d_model",
+ "mean",
+ )
+
+ del state_dict[f"blocks.{l}.mlp.ln.w"]
+
# Fold ln_final into Unembed
- if not self.cfg.final_rms:
+ if not self.cfg.final_rms and fold_biases:
# Dumb bug from my old SoLU training code, some models have RMSNorm instead of LayerNorm
# pre unembed.
state_dict[f"unembed.b_U"] = state_dict[f"unembed.b_U"] + (
state_dict[f"unembed.W_U"] * state_dict[f"ln_final.b"][:, None]
).sum(dim=-2)
del state_dict[f"ln_final.b"]
- state_dict[f"unembed.W_U"] = (
- state_dict[f"unembed.W_U"] * state_dict[f"ln_final.w"][:, None]
- )
-
- # Center the weights that read in from the LayerNormPre
- state_dict[f"unembed.W_U"] -= einops.reduce(
- state_dict[f"unembed.W_U"], "d_model d_vocab -> 1 d_vocab", "mean"
- )
+ state_dict[f"unembed.W_U"] = state_dict[f"unembed.W_U"] * state_dict[f"ln_final.w"][:, None]
del state_dict[f"ln_final.w"]
+
+ if center_weights:
+ # Center the weights that read in from the LayerNormPre
+ state_dict[f"unembed.W_U"] -= einops.reduce(
+ state_dict[f"unembed.W_U"], "d_model d_vocab -> 1 d_vocab", "mean"
+ )
+
return state_dict
def center_writing_weights(self, state_dict: Dict[str, torch.Tensor]):
@@ -1590,30 +1721,28 @@ def center_writing_weights(self, state_dict: Dict[str, torch.Tensor]):
W_out. This is done by subtracting the mean of the weights from the weights themselves. This
is done in-place. See fold_layer_norm for more details.
"""
- state_dict["embed.W_E"] = state_dict["embed.W_E"] - state_dict[
- "embed.W_E"
- ].mean(-1, keepdim=True)
+ state_dict["embed.W_E"] = state_dict["embed.W_E"] - state_dict["embed.W_E"].mean(
+ -1, keepdim=True
+ )
if self.cfg.positional_embedding_type != "rotary":
state_dict["pos_embed.W_pos"] = state_dict["pos_embed.W_pos"] - state_dict[
"pos_embed.W_pos"
].mean(-1, keepdim=True)
for l in range(self.cfg.n_layers):
- state_dict[f"blocks.{l}.attn.W_O"] = state_dict[
+ state_dict[f"blocks.{l}.attn.W_O"] = state_dict[f"blocks.{l}.attn.W_O"] - state_dict[
f"blocks.{l}.attn.W_O"
- ] - state_dict[f"blocks.{l}.attn.W_O"].mean(
+ ].mean(
-1, keepdim=True
) # W_O is [head_index, d_model, d_head]
state_dict[f"blocks.{l}.attn.b_O"] = (
- state_dict[f"blocks.{l}.attn.b_O"]
- - state_dict[f"blocks.{l}.attn.b_O"].mean()
+ state_dict[f"blocks.{l}.attn.b_O"] - state_dict[f"blocks.{l}.attn.b_O"].mean()
) # b_O is [d_model]
if not self.cfg.attn_only:
state_dict[f"blocks.{l}.mlp.W_out"] = state_dict[
f"blocks.{l}.mlp.W_out"
] - state_dict[f"blocks.{l}.mlp.W_out"].mean(-1, keepdim=True)
state_dict[f"blocks.{l}.mlp.b_out"] = (
- state_dict[f"blocks.{l}.mlp.b_out"]
- - state_dict[f"blocks.{l}.mlp.b_out"].mean()
+ state_dict[f"blocks.{l}.mlp.b_out"] - state_dict[f"blocks.{l}.mlp.b_out"].mean()
)
return state_dict
@@ -1626,12 +1755,10 @@ def center_unembed(self, state_dict: Dict[str, torch.Tensor]):
how components contribute to the logits, we'll be less misled by components that just add
something to every logit.
"""
- state_dict["unembed.W_U"] = state_dict["unembed.W_U"] - state_dict[
- "unembed.W_U"
- ].mean(-1, keepdim=True)
- state_dict["unembed.b_U"] = (
- state_dict["unembed.b_U"] - state_dict["unembed.b_U"].mean()
+ state_dict["unembed.W_U"] = state_dict["unembed.W_U"] - state_dict["unembed.W_U"].mean(
+ -1, keepdim=True
)
+ state_dict["unembed.b_U"] = state_dict["unembed.b_U"] - state_dict["unembed.b_U"].mean()
return state_dict
def fold_value_biases(self, state_dict: Dict[str, torch.Tensor]):
@@ -1647,16 +1774,26 @@ def fold_value_biases(self, state_dict: Dict[str, torch.Tensor]):
"""
for layer in range(self.cfg.n_layers):
# shape [head_index, d_head]
- b_V = state_dict[f"blocks.{layer}.attn.b_V"]
+ if self.cfg.n_key_value_heads is None:
+ b_V = state_dict[f"blocks.{layer}.attn.b_V"]
+ else:
+ b_V = state_dict[f"blocks.{layer}.attn._b_V"]
+ b_V = torch.repeat_interleave(
+ b_V, dim=0, repeats=self.cfg.n_heads // self.cfg.n_key_value_heads
+ )
# [head_index, d_head, d_model]
W_O = state_dict[f"blocks.{layer}.attn.W_O"]
# [d_model]
b_O_original = state_dict[f"blocks.{layer}.attn.b_O"]
-
folded_b_O = b_O_original + (b_V[:, :, None] * W_O).sum([0, 1])
state_dict[f"blocks.{layer}.attn.b_O"] = folded_b_O
- state_dict[f"blocks.{layer}.attn.b_V"] = torch.zeros_like(b_V)
+ if self.cfg.n_key_value_heads is None:
+ state_dict[f"blocks.{layer}.attn.b_V"] = torch.zeros_like(b_V)
+ else:
+ state_dict[f"blocks.{layer}.attn._b_V"] = torch.zeros_like(
+ state_dict[f"blocks.{layer}.attn._b_V"]
+ )
return state_dict
def refactor_factored_attn_matrices(self, state_dict: Dict[str, torch.Tensor]):
@@ -1789,7 +1926,11 @@ def process_weights_(
version of the same model.
"""
state_dict = self.state_dict()
- if fold_ln and self.cfg.normalization_type == "LN":
+ if fold_ln and self.cfg.num_experts and self.cfg.num_experts > 1:
+ # If we're using MoE, we don't fold the layer norm weights, so we don't need to do any preprocessing
+ # A warning is already issued in `load_and_process_state_dict`
+ pass
+ elif fold_ln and self.cfg.normalization_type == "LN":
# If we're folding the LN into the weights, we need to replace all the layernorm layers
# with LayerNormPres, which do not have learnable parameters. This is somewhat hacky,
# but it's the easiest way to do it.
@@ -1798,8 +1939,17 @@ def process_weights_(
for layer in self.blocks:
layer.ln1 = LayerNormPre(self.cfg)
layer.ln2 = LayerNormPre(self.cfg)
- if self.cfg.act_fn.endswith("_ln"):
+ if self.cfg.act_fn is not None and self.cfg.act_fn.endswith("_ln"):
layer.mlp.ln = LayerNormPre(self.cfg)
+ elif fold_ln and self.cfg.normalization_type == "RMS":
+ # We do the same for RMSNorm if used
+ self.cfg.normalization_type = "RMSPre"
+ self.ln_final = RMSNormPre(self.cfg)
+ for layer in self.blocks:
+ layer.ln1 = RMSNormPre(self.cfg)
+ layer.ln2 = RMSNormPre(self.cfg)
+ if self.cfg.act_fn is not None and self.cfg.act_fn.endswith("_ln"):
+ layer.mlp.ln = RMSNormPre(self.cfg)
self.load_and_process_state_dict(
state_dict,
@@ -1885,9 +2035,7 @@ def generate(
assert (
self.tokenizer is not None
), "Must provide a tokenizer if passing a string to the model"
- tokens = self.to_tokens(
- input, prepend_bos=prepend_bos, padding_side=padding_side
- )
+ tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
else:
tokens = input
@@ -1908,12 +2056,12 @@ def generate(
else:
past_kv_cache = None
- stop_tokens = []
+ stop_tokens: List[int] = []
eos_token_for_padding = 0
+ assert self.tokenizer is not None
if stop_at_eos:
tokenizer_has_eos_token = (
- self.tokenizer is not None
- and self.tokenizer.eos_token_id is not None
+ self.tokenizer is not None and self.tokenizer.eos_token_id is not None
)
if eos_token_id is None:
assert (
@@ -1929,15 +2077,11 @@ def generate(
# eos_token_id is a Sequence (e.g. list or tuple)
stop_tokens = eos_token_id
eos_token_for_padding = (
- self.tokenizer.eos_token_id
- if tokenizer_has_eos_token
- else eos_token_id[0]
+ self.tokenizer.eos_token_id if tokenizer_has_eos_token else eos_token_id[0]
)
# An array to track which sequences in the batch have finished.
- finished_sequences = torch.zeros(
- batch_size, dtype=torch.bool, device=self.cfg.device
- )
+ finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device)
# Currently nothing in HookedTransformer changes with eval, but this is here in case
# that changes in the future.
@@ -1995,7 +2139,10 @@ def generate(
# instead.
sampled_tokens[finished_sequences] = eos_token_for_padding
finished_sequences.logical_or_(
- torch.isin(sampled_tokens, torch.tensor(stop_tokens).to(device))
+ torch.isin(
+ sampled_tokens.to(self.cfg.device),
+ torch.tensor(stop_tokens).to(self.cfg.device),
+ )
)
tokens = torch.cat([tokens, sampled_tokens.unsqueeze(-1)], dim=-1)
@@ -2079,7 +2226,7 @@ def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]:
return torch.stack([block.mlp.W_in for block in self.blocks], dim=0)
@property
- def W_gate(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]:
+ def W_gate(self) -> Union[Float[torch.Tensor, "n_layers d_model d_mlp"], None]:
"""Stack the MLP gate weights across all layers.
Only works for models with gated MLPs.
@@ -2161,9 +2308,7 @@ def accumulated_bias(
if include_mlp_biases:
accumulated_bias += self.blocks[i].mlp.b_out
if mlp_input:
- assert (
- layer < self.cfg.n_layers
- ), "Cannot include attn_bias from beyond the final layer"
+ assert layer < self.cfg.n_layers, "Cannot include attn_bias from beyond the final layer"
accumulated_bias += self.blocks[layer].attn.b_O
return accumulated_bias
@@ -2197,19 +2342,14 @@ def all_composition_scores(
# layer than the left head.
mask = (
torch.arange(self.cfg.n_layers, device=self.cfg.device)[:, None, None, None]
- < torch.arange(self.cfg.n_layers, device=self.cfg.device)[
- None, None, :, None
- ]
+ < torch.arange(self.cfg.n_layers, device=self.cfg.device)[None, None, :, None]
)
scores = torch.where(mask, scores, torch.zeros_like(scores))
return scores
def all_head_labels(self):
- return [
- f"L{l}H{h}"
- for l in range(self.cfg.n_layers)
- for h in range(self.cfg.n_heads)
- ]
+ """Returns a list of all head names in the model."""
+ return [f"L{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)]
def load_sample_training_dataset(self, **kwargs):
"""Load Sample Training Dataset.
@@ -2254,7 +2394,7 @@ def load_sample_training_dataset(self, **kwargs):
def sample_datapoint(
self,
- tokenize: Optional[bool] = False,
+ tokenize: bool = False,
prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
) -> Union[str, Float[torch.Tensor, "1 pos"]]:
@@ -2281,6 +2421,7 @@ def sample_datapoint(
"""
if self.dataset is None:
self.load_sample_training_dataset()
+ assert self.dataset is not None # keep mypy happy
sample_dataset_size = len(self.dataset)
index = np.random.randint(0, sample_dataset_size)
if not tokenize:
diff --git a/transformer_lens/HookedTransformerConfig.py b/transformer_lens/HookedTransformerConfig.py
index 2730c3bd4..7a38c22c5 100644
--- a/transformer_lens/HookedTransformerConfig.py
+++ b/transformer_lens/HookedTransformerConfig.py
@@ -3,6 +3,7 @@
Module with a dataclass for storing the configuration of a
:class:`transformer_lens.HookedTransformer` model.
"""
+
from __future__ import annotations
import logging
@@ -78,9 +79,8 @@ class HookedTransformerConfig:
local attention
weight_init_mode (str): the initialization mode to use for the
weights. Only relevant for custom models, ignored for pre-trained.
- Currently the only supported mode is 'gpt2', where biases are
- initialized to 0 and weights are standard normals of range
- initializer_range.
+ We now support 'gpt2', 'xavier_uniform', 'xavier_normal', 'kaiming_uniform',
+ 'kaiming_normal'. MuP support to come. Defaults to 'gpt2'.
normalization_type (str, *optional*): the type of normalization to use.
Options are None (no normalization), 'LN' (use LayerNorm, including weights
& biases) and 'LNPre' (use LayerNorm, but no weights & biases).
@@ -95,10 +95,12 @@ class HookedTransformerConfig:
attn_only (bool): Whether to only use attention layers, no feedforward
layers. Defaults to False
seed (int, *optional*): The seed to use for the model.
- Used to set sources of randomness (Python, PyTorch and
- NumPy) and to initialize weights. Defaults to None. We recommend setting a seed, so your experiments are reproducible.
+ Used to set sources of randomness (Python, PyTorch and NumPy) and to initialize weights.
+ Defaults to None. We recommend setting a seed, so your experiments are reproducible.
initializer_range (float): The standard deviation of the normal used to
- initialise the weights, initialized to 0.8 / sqrt(d_model) .
+ initialise the weights, initialized to 0.8 / sqrt(d_model). If weight_init_mode is
+ 'xavier_uniform' or 'xavier_normal', this value is instead treated as the `gain` parameter for the weight
+ initialisation (a constant factor to scale the weights by). Defaults to -1.0, which means not set.
init_weights (bool): Whether to initialize the weights. Defaults to
True. If False, does not initialize weights.
scale_attn_by_inverse_layer_idx (bool): Whether to scale the attention
@@ -147,8 +149,16 @@ class HookedTransformerConfig:
tokenizer_prepends_bos (bool, *optional*): This flag is set by set_tokenizer. It is set to True only
when the tokenizer automatically prepends the BOS token if initialized with add_bos_token=True.
We need this information to dynamically control bos prepending.
+ load_in_4bit(bool): If this flag is set, then it's assumed that parameters are 4-bit quantized
+ with bitsandbytes. Currently only supported for Llama.
+ n_key_value_heads (int, *optional*): The number of groups of heads that use the same key and value matrix.
+ Only for models that use Grouped Query Attention.
post_embedding_ln (bool): Whether to apply layer normalization after embedding the tokens. Defaults
to False.
+ num_experts (int, *optional*): The number of experts to use in the MoE layer. If set, experts_per_token
+ must also be set. Set to None if not using MoE.
+ experts_per_token (int, *optional*): The number of experts to use for each pass in the MoE layer. If set,
+ num_experts must also be set. Set to None if not using MoE.
"""
n_layers: int
@@ -196,7 +206,14 @@ class HookedTransformerConfig:
default_prepend_bos: bool = True
dtype: torch.dtype = torch.float32
tokenizer_prepends_bos: Optional[bool] = None
+ n_key_value_heads: Optional[int] = None
post_embedding_ln: bool = False
+ rotary_base: int = 10000
+ trust_remote_code: bool = False
+ rotary_adjacent_pairs: bool = False
+ load_in_4bit: bool = False
+ num_experts: Optional[int] = None
+ experts_per_token: Optional[int] = None
def __post_init__(self):
if self.n_heads == -1:
@@ -204,47 +221,62 @@ def __post_init__(self):
if not self.d_model % (self.d_head) == 0:
logging.warning(
- f"d_model {self.d_model} is not divisible by d_head {self.d_head}. n_heads was inferred to be {self.n_heads}, rounding down the ratio."
+ "d_model %d is not divisible by d_head %d. "
+ "n_heads was inferred to be %d, rounding down the ratio.",
+ self.d_model,
+ self.d_head,
+ self.n_heads,
)
if self.seed is not None:
self.set_seed_everywhere(self.seed)
if self.use_local_attn:
- assert (
- self.window_size is not None
- ), "window_size must be specified for local attention"
- assert (
- self.attn_types is not None
- ), "attn_types must be specified for local attention"
+ assert self.window_size is not None, "window_size must be specified for local attention"
+ assert self.attn_types is not None, "attn_types must be specified for local attention"
if not self.attn_only:
if self.d_mlp is None:
# For some reason everyone hard codes in this hyper-parameter!
- self.d_mlp = self.d_model * 4
- assert (
- self.act_fn is not None
- ), "act_fn must be specified for non-attn-only models"
+ self.d_mlp: int = self.d_model * 4
+ assert self.act_fn is not None, "act_fn must be specified for non-attn-only models"
assert (
self.act_fn in SUPPORTED_ACTIVATIONS
), f"act_fn={self.act_fn} must be one of {SUPPORTED_ACTIVATIONS}"
- if self.initializer_range < 0:
+ if self.initializer_range < 0 and self.init_mode == "gpt2":
# Roughly copy the GPT-2 value, but proportional to sqrt(1/d_model)
self.initializer_range = 0.8 / np.sqrt(self.d_model)
+ if self.initializer_range < 0 and self.init_mode != "gpt2":
+ # This is the gain parameter for the weight initialisation
+ self.initializer_range = 1.0
if self.d_vocab_out == -1:
# d_vocab_out defaults to d_vocab, unless there's an algorithmic task
- # If d_vocab is not set, it'll be inferred from tokenizer_name or from a tokenizer explicitly passed to HookedTransformer initialisation.
+ # If d_vocab is not set, it'll be inferred from tokenizer_name or from a tokenizer
+ # explicitly passed to HookedTransformer initialisation.
self.d_vocab_out = self.d_vocab
if self.positional_embedding_type == "rotary" and self.rotary_dim is None:
self.rotary_dim = self.d_head
+ if self.num_experts is not None:
+ assert (
+ self.experts_per_token is not None
+ ), "experts_per_token must be set if num_experts is set"
+ if self.experts_per_token is not None:
+ assert (
+ self.num_experts is not None
+ ), "num_experts must be set if experts_per_token is set"
+
# The number of parameters in attention layers (ignoring biases and layer norm). 4 because W_Q, W_K, W_V and W_O
- self.n_params = self.n_layers * (
- (self.d_model * self.d_head * self.n_heads * 4)
- )
+ self.n_params = self.n_layers * ((self.d_model * self.d_head * self.n_heads * 4))
if not self.attn_only:
+ assert self.d_mlp is not None # mypy
# Number of parameters in MLP layers (ignoring biases and layer norm). 2 because W_in and W_out
- self.n_params += self.n_layers * self.d_model * self.d_mlp * 2
+ mlp_params_per_layer = self.d_model * self.d_mlp * (2 + self.gated_mlp)
+
+ if self.num_experts:
+ # If we are using MoE, we multiply by num_experts, and add the expert gate parameters (d_model * num_experts)
+ mlp_params_per_layer = (mlp_params_per_layer + self.d_model) * self.num_experts
+ self.n_params += self.n_layers * mlp_params_per_layer
if self.device is None:
self.device = utils.get_device()
diff --git a/transformer_lens/SVDInterpreter.py b/transformer_lens/SVDInterpreter.py
index caaecd448..cf0354d61 100644
--- a/transformer_lens/SVDInterpreter.py
+++ b/transformer_lens/SVDInterpreter.py
@@ -3,6 +3,7 @@
Module for getting the singular vectors of the OV, w_in, and w_out matrices of a
:class:`transformer_lens.HookedTransformer`.
"""
+
from typing import Optional, Union
import fancy_einsum as einsum
@@ -10,7 +11,8 @@
from typeguard import typechecked
from typing_extensions import Literal
-from transformer_lens import FactoredMatrix, HookedTransformer
+from transformer_lens.FactoredMatrix import FactoredMatrix
+from transformer_lens.HookedTransformer import HookedTransformer
OUTPUT_EMBEDDING = "unembed.W_U"
VECTOR_TYPES = ["OV", "w_in", "w_out"]
@@ -79,7 +81,9 @@ def plot_matrix(matrix, tokens, k=10, filter="topk"):
"w_out",
], f"Head index optional only for w_in and w_out, got {vector_type}"
+ matrix: Union[FactoredMatrix, torch.Tensor]
if vector_type == "OV":
+ assert head_index is not None # keep mypy happy
matrix = self._get_OV_matrix(layer_index, head_index)
V = matrix.Vh.T
@@ -92,13 +96,9 @@ def plot_matrix(matrix, tokens, k=10, filter="topk"):
_, _, V = torch.linalg.svd(matrix)
else:
- raise ValueError(
- f"Vector type must be in {VECTOR_TYPES}, instead got {vector_type}"
- )
+ raise ValueError(f"Vector type must be in {VECTOR_TYPES}, instead got {vector_type}")
- return self._get_singular_vectors_from_matrix(
- V, self.params[OUTPUT_EMBEDDING], num_vectors
- )
+ return self._get_singular_vectors_from_matrix(V, self.params[OUTPUT_EMBEDDING], num_vectors)
def _get_singular_vectors_from_matrix(
self,
@@ -108,12 +108,12 @@ def _get_singular_vectors_from_matrix(
) -> torch.Tensor:
"""Returns the top num_vectors singular vectors from a matrix."""
- vectors = []
+ vectors_list = []
for i in range(num_vectors):
- activations = V[i, :].float() @ embedding
- vectors.append(activations)
+ activations = V[i, :].float() @ embedding # type: ignore
+ vectors_list.append(activations)
- vectors = torch.stack(vectors, dim=1).unsqueeze(1)
+ vectors = torch.stack(vectors_list, dim=1).unsqueeze(1)
assert vectors.shape == (
self.cfg.d_vocab,
1,
@@ -131,10 +131,8 @@ def _get_OV_matrix(self, layer_index: int, head_index: int) -> FactoredMatrix:
0 <= head_index < self.cfg.n_heads
), f"Head index must be between 0 and {self.cfg.n_heads-1} but got {head_index}"
- W_V, W_O = (
- self.params[f"blocks.{layer_index}.attn.W_V"],
- self.params[f"blocks.{layer_index}.attn.W_O"],
- )
+ W_V: torch.Tensor = self.params[f"blocks.{layer_index}.attn.W_V"]
+ W_O: torch.Tensor = self.params[f"blocks.{layer_index}.attn.W_O"]
W_V, W_O = W_V[head_index, :, :], W_O[head_index, :, :]
return FactoredMatrix(W_V, W_O)
diff --git a/transformer_lens/__init__.py b/transformer_lens/__init__.py
index e2fb1484b..9ab2acea7 100644
--- a/transformer_lens/__init__.py
+++ b/transformer_lens/__init__.py
@@ -10,6 +10,9 @@
from .FactoredMatrix import FactoredMatrix
from .ActivationCache import ActivationCache
from .HookedTransformer import HookedTransformer
+from .HookedSAEConfig import HookedSAEConfig
+from .HookedSAE import HookedSAE
+from .HookedSAETransformer import HookedSAETransformer
from .SVDInterpreter import SVDInterpreter
from .HookedEncoder import HookedEncoder
from . import head_detector
diff --git a/transformer_lens/components.py b/transformer_lens/components.py
deleted file mode 100644
index 9c8d663cd..000000000
--- a/transformer_lens/components.py
+++ /dev/null
@@ -1,1306 +0,0 @@
-"""Hooked Transformer Components.
-
-This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`)
-needed to create many different types of generative language models. They are used by
-:class:`transformer_lens.HookedTransformer`.
-"""
-import logging
-from typing import Dict, Optional, Tuple, Union
-
-import einops
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from fancy_einsum import einsum
-from jaxtyping import Float, Int
-
-from transformer_lens.FactoredMatrix import FactoredMatrix
-from transformer_lens.hook_points import HookPoint
-from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
-from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry
-from transformer_lens.utils import gelu_fast, gelu_new, get_offset_position_ids, solu
-
-
-# Embed & Unembed
-class Embed(nn.Module):
- def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
- super().__init__()
- if isinstance(cfg, Dict):
- cfg = HookedTransformerConfig.from_dict(cfg)
- self.cfg = cfg
- self.W_E: Float[torch.Tensor, "d_vocab d_model"] = nn.Parameter(
- torch.empty(self.cfg.d_vocab, self.cfg.d_model, dtype=cfg.dtype)
- )
- # Some models (e.g. Bloom) need post embedding layer norm
- if cfg.post_embedding_ln:
- self.ln = LayerNorm(cfg)
-
- def forward(
- self, tokens: Int[torch.Tensor, "batch pos"]
- ) -> Float[torch.Tensor, "batch pos d_model"]:
- # If A has shape [a, b] and B has shape [c, d], then A[:, B] has shape [a, c, d]
- # B acts as a tensor of indices into the second dimension (so >=0 and Float[torch.Tensor, "batch pos d_vocab_out"]:
- return (
- einsum(
- "batch pos d_model, d_model vocab -> batch pos vocab",
- residual,
- self.W_U,
- )
- + self.b_U
- )
-
-
-# Positional Embeddings
-class PosEmbed(nn.Module):
- def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
- super().__init__()
- if isinstance(cfg, Dict):
- cfg = HookedTransformerConfig.from_dict(cfg)
- self.cfg = cfg
- self.W_pos = nn.Parameter(
- torch.empty(self.cfg.n_ctx, self.cfg.d_model, dtype=cfg.dtype)
- )
-
- def forward(
- self,
- tokens: Int[torch.Tensor, "batch pos"],
- past_kv_pos_offset: int = 0,
- attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
- ) -> Float[torch.Tensor, "batch pos d_model"]:
- """
- Forward pass for positional embeddings.
-
- Args:
- tokens (Int[torch.Tensor, "batch pos"]): Input tokens.
- past_kv_pos_offset (int, optional): The length of tokens in the past_kv_cache. Defaults to 0.
- attention_mask (Int[torch.Tensor, "batch pos"], optional): The attention mask for padded tokens.
- Defaults to None.
-
- Returns:
- Float[torch.Tensor, "batch pos d_model"]: Absolute position embeddings.
- """
- tokens_length = tokens.size(-1)
-
- if attention_mask is None:
- pos_embed = self.W_pos[
- past_kv_pos_offset : tokens_length + past_kv_pos_offset, :
- ] # [pos, d_model]
- batch_pos_embed = einops.repeat(
- pos_embed, "pos d_model -> batch pos d_model", batch=tokens.size(0)
- )
-
- else:
- # Separated from the no padding case for computational efficiency
- # (this code is a bit slower than the code above)
-
- offset_position_ids = get_offset_position_ids(
- past_kv_pos_offset, attention_mask
- )
- pos_embed = self.W_pos[offset_position_ids] # [batch, pos, d_model]
-
- # Set the position embeddings to 0 for pad tokens (this is an arbitrary choice)
- padding_mask = ~attention_mask.bool() # [batch, tokens_length]
- offset_padding_mask = padding_mask[
- :, past_kv_pos_offset : tokens_length + past_kv_pos_offset
- ].unsqueeze(
- -1
- ) # [batch, pos, 1]
- batch_pos_embed = torch.where(offset_padding_mask, 0, pos_embed)
-
- return batch_pos_embed.clone()
-
-
-class TokenTypeEmbed(nn.Module):
- """
- The token-type embed is a binary ids indicating whether a token belongs to sequence A or B. For example, for two sentences: "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be [0, 0, ..., 0, 1, ..., 1, 1]. `0` represents tokens from Sentence A, `1` from Sentence B. If not provided, BERT assumes a single sequence input. Typically, shape is (batch_size, sequence_length).
-
- See the BERT paper for more information: https://arxiv.org/pdf/1810.04805.pdf
- """
-
- def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
- super().__init__()
- if isinstance(cfg, Dict):
- cfg = HookedTransformerConfig.from_dict(cfg)
- self.cfg = cfg
- self.W_token_type = nn.Parameter(
- torch.empty(2, self.cfg.d_model, dtype=cfg.dtype)
- )
-
- def forward(self, token_type_ids: Int[torch.Tensor, "batch pos"]):
- return self.W_token_type[token_type_ids, :]
-
-
-class BertEmbed(nn.Module):
- """
- Custom embedding layer for a BERT-like model. This module computes the sum of the token, positional and token-type embeddings and takes the layer norm of the result.
- """
-
- def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
- super().__init__()
- if isinstance(cfg, Dict):
- cfg = HookedTransformerConfig.from_dict(cfg)
- self.cfg = cfg
- self.embed = Embed(cfg)
- self.pos_embed = PosEmbed(cfg)
- self.token_type_embed = TokenTypeEmbed(cfg)
- self.ln = LayerNorm(cfg)
-
- self.hook_embed = HookPoint()
- self.hook_pos_embed = HookPoint()
- self.hook_token_type_embed = HookPoint()
-
- def forward(
- self,
- input_ids: Int[torch.Tensor, "batch pos"],
- token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
- ):
- base_index_id = torch.arange(input_ids.shape[1], device=input_ids.device)
- index_ids = einops.repeat(
- base_index_id, "pos -> batch pos", batch=input_ids.shape[0]
- )
- if token_type_ids is None:
- token_type_ids = torch.zeros_like(input_ids)
-
- word_embeddings_out = self.hook_embed(self.embed(input_ids))
- position_embeddings_out = self.hook_pos_embed(self.pos_embed(index_ids))
- token_type_embeddings_out = self.hook_token_type_embed(
- self.token_type_embed(token_type_ids)
- )
-
- embeddings_out = (
- word_embeddings_out + position_embeddings_out + token_type_embeddings_out
- )
- layer_norm_out = self.ln(embeddings_out)
- return layer_norm_out
-
-
-class BertMLMHead(nn.Module):
- """
- Transforms BERT embeddings into logits. The purpose of this module is to predict masked tokens in a sentence.
- """
-
- def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
- super().__init__()
- if isinstance(cfg, Dict):
- cfg = HookedTransformerConfig.from_dict(cfg)
- self.cfg = cfg
- self.W = nn.Parameter(torch.empty(cfg.d_model, cfg.d_model, dtype=cfg.dtype))
- self.b = nn.Parameter(torch.zeros(cfg.d_model, dtype=cfg.dtype))
- self.act_fn = nn.GELU()
- self.ln = LayerNorm(cfg)
-
- def forward(self, resid: Float[torch.Tensor, "batch pos d_model"]) -> torch.Tensor:
- resid = (
- einsum(
- "batch pos d_model_in, d_model_out d_model_in -> batch pos d_model_out",
- resid,
- self.W,
- )
- + self.b
- )
- resid = self.act_fn(resid)
- resid = self.ln(resid)
- return resid
-
-
-# LayerNormPre
-# I fold the LayerNorm weights and biases into later weights and biases.
-# This is just the 'center and normalise' part of LayerNorm
-# Centering is equivalent to just deleting one direction of residual space,
-# and is equivalent to centering the weight matrices of everything writing to the residual stream
-# Normalising is a funkier non-linear operation, that projects the residual stream onto the unit hypersphere
-class LayerNormPre(nn.Module):
- def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
- """LayerNormPre - the 'center and normalise' part of LayerNorm. Length is
- normally d_model, but is d_mlp for softmax. Not needed as a parameter. This
- should only be used in inference mode after folding in LayerNorm weights"""
- super().__init__()
- if isinstance(cfg, Dict):
- cfg = HookedTransformerConfig.from_dict(cfg)
- self.cfg = cfg
- self.eps = self.cfg.eps
-
- # Adds a hook point for the normalisation scale factor
- self.hook_scale = HookPoint() # [batch, pos]
- # Hook Normalized captures LN output - here it's a vector with std 1 and mean 0
- self.hook_normalized = HookPoint() # [batch, pos, length]
-
- def forward(
- self,
- x: Union[
- Float[torch.Tensor, "batch pos d_model"],
- Float[torch.Tensor, "batch pos head_index d_model"],
- ],
- ) -> Union[
- Float[torch.Tensor, "batch pos d_model"],
- Float[torch.Tensor, "batch pos head_index d_model"],
- ]:
- if self.cfg.dtype not in [torch.float32, torch.float64]:
- x = x.to(torch.float32)
-
- x = x - x.mean(axis=-1, keepdim=True) # [batch, pos, length]
- scale: Union[
- Float[torch.Tensor, "batch pos 1"],
- Float[torch.Tensor, "batch pos head_index 1"],
- ] = self.hook_scale((x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt())
- return self.hook_normalized(x / scale).to(self.cfg.dtype)
-
-
-class LayerNorm(nn.Module):
- def __init__(
- self, cfg: Union[Dict, HookedTransformerConfig], length: Optional[int] = None
- ):
- """
- LayerNorm with optional length parameter
-
- length (Optional[int]): If the dimension of the LayerNorm. If not provided, assumed to be d_model
- """
- super().__init__()
- if isinstance(cfg, Dict):
- cfg = HookedTransformerConfig.from_dict(cfg)
- self.cfg = cfg
- self.eps = self.cfg.eps
- if length is None:
- self.length = self.cfg.d_model
- else:
- self.length = length
-
- self.w = nn.Parameter(torch.ones(self.length, dtype=cfg.dtype))
- self.b = nn.Parameter(torch.zeros(self.length, dtype=cfg.dtype))
-
- # Adds a hook point for the normalisation scale factor
- self.hook_scale = HookPoint() # [batch, pos, 1]
- # Hook_normalized is on the LN output
- self.hook_normalized = HookPoint() # [batch, pos, length]
-
- def forward(
- self,
- x: Union[
- Float[torch.Tensor, "batch pos d_model"],
- Float[torch.Tensor, "batch pos head_index d_model"],
- ],
- ) -> Union[
- Float[torch.Tensor, "batch pos d_model"],
- Float[torch.Tensor, "batch pos head_index d_model"],
- ]:
- if self.cfg.dtype not in [torch.float32, torch.float64]:
- x = x.to(torch.float32)
-
- x = x - x.mean(axis=-1, keepdim=True) # [batch, pos, length]
- scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(
- (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
- )
- x = x / scale # [batch, pos, length]
- return self.hook_normalized(x * self.w + self.b).to(self.cfg.dtype)
-
-
-class RMSNormPre(nn.Module):
- def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
- """RMSNormPre - LayerNormPre without the centering and bias (RMS = Root Mean Square)"""
- super().__init__()
- if isinstance(cfg, Dict):
- cfg = HookedTransformerConfig.from_dict(cfg)
- self.cfg = cfg
- self.eps = self.cfg.eps
-
- # Adds a hook point for the normalisation scale factor
- self.hook_scale = HookPoint() # [batch, pos]
- self.hook_normalized = HookPoint() # [batch, pos, length]
-
- def forward(
- self, x: Float[torch.Tensor, "batch pos length"]
- ) -> Float[torch.Tensor, "batch pos length"]:
- if self.cfg.dtype not in [torch.float32, torch.float64]:
- x = x.to(torch.float32)
-
- scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(
- (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
- )
- return self.hook_normalized(x / scale).to(
- self.cfg.dtype
- ) # [batch, pos, length]
-
-
-class RMSNorm(nn.Module):
- def __init__(
- self, cfg: Union[Dict, HookedTransformerConfig], length: Optional[int] = None
- ):
- """
- RMSNorm - LayerNorm without the centering and bias (RMS = Root Mean Square)
-
- length (Optional[int]): If the dimension of the RMSNorm. If not provided, assumed to be d_model
- """
- super().__init__()
- if isinstance(cfg, Dict):
- cfg = HookedTransformerConfig.from_dict(cfg)
- self.cfg = cfg
- self.eps = self.cfg.eps
- if length is None:
- self.length = self.cfg.d_model
- else:
- self.length = length
-
- self.w = nn.Parameter(torch.ones(self.length, dtype=cfg.dtype))
-
- # Adds a hook point for the normalisation scale factor
- self.hook_scale = HookPoint() # [batch, pos, 1]
- self.hook_normalized = HookPoint() # [batch, pos, length]
-
- def forward(
- self, x: Float[torch.Tensor, "batch pos length"]
- ) -> Float[torch.Tensor, "batch pos length"]:
- if self.cfg.dtype not in [torch.float32, torch.float64]:
- x = x.to(torch.float32)
-
- scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(
- (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
- )
- x = self.hook_normalized(x / scale).to(self.cfg.dtype) # [batch, pos, length]
- return x * self.w
-
-
-# Attention
-class Attention(nn.Module):
- def __init__(
- self,
- cfg: Union[Dict, HookedTransformerConfig],
- attn_type: str = "global",
- layer_id: Optional[int] = None,
- ):
- """Attention Block - params have shape [head_index, d_model, d_head] (or [head_index, d_head, d_model] for W_O) and multiply on the right. attn_scores refers to query key dot product immediately before attention softmax
-
- Convention: All attention pattern-style matrices have shape [batch, head_index, query_pos, key_pos]
-
- Args:
- cfg (Union[Dict, HookedTransformerConfig]): Config
- attn_type (str, optional): "global" or "local", used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens (here, 256). Not used by any other model at the moment. Defaults to "global".
- layer_id (int, optional): The index of the current layer. Used by the Mistal models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None.
- """
- super().__init__()
- if isinstance(cfg, Dict):
- cfg = HookedTransformerConfig.from_dict(cfg)
- self.cfg = cfg
- self.W_Q = nn.Parameter(
- torch.empty(
- self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype
- )
- )
- self.W_K = nn.Parameter(
- torch.empty(
- self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype
- )
- )
- self.W_V = nn.Parameter(
- torch.empty(
- self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype
- )
- )
- self.W_O = nn.Parameter(
- torch.empty(
- self.cfg.n_heads, self.cfg.d_head, self.cfg.d_model, dtype=cfg.dtype
- )
- )
- self.b_Q = nn.Parameter(
- torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype)
- )
- self.b_K = nn.Parameter(
- torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype)
- )
- self.b_V = nn.Parameter(
- torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype)
- )
- self.b_O = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=cfg.dtype))
-
- self.attn_type = attn_type
- # Create a max_ctx x max_ctx mask, with True iff that query position
- # can attend to that key position (query is first axis, key is second axis)
- causal_mask = torch.tril(torch.ones((self.cfg.n_ctx, self.cfg.n_ctx)).bool())
- if self.attn_type == "global":
- # For global attention, this is a lower triangular matrix - key <= query
- self.register_buffer("mask", causal_mask)
- elif self.attn_type == "local":
- # For local, this is banded, query - window_size < key <= query
- assert isinstance(self.cfg.window_size, int)
- self.register_buffer(
- "mask", torch.triu(causal_mask, 1 - self.cfg.window_size)
- )
- else:
- raise ValueError(f"Invalid attention type: {self.attn_type}")
-
- self.register_buffer("IGNORE", torch.tensor(-torch.inf))
-
- self.layer_id = layer_id
-
- # attn_scale is a constant that we divide the attention scores by pre-softmax. I'm not entirely sure why it matters, but it's probably a mix of softmax not being scale invariant and numerical stability?
- if self.cfg.use_attn_scale:
- self.attn_scale = np.sqrt(self.cfg.d_head)
- else:
- self.attn_scale = 1.0
- if self.cfg.scale_attn_by_inverse_layer_idx:
- self.attn_scale *= self.layer_id + 1
-
- self.hook_k = HookPoint() # [batch, pos, head_index, d_head]
- self.hook_q = HookPoint() # [batch, pos, head_index, d_head]
- self.hook_v = HookPoint() # [batch, pos, head_index, d_head]
- self.hook_z = HookPoint() # [batch, pos, head_index, d_head]
- self.hook_attn_scores = HookPoint() # [batch, head_index, query_pos, key_pos]
- self.hook_pattern = HookPoint() # [batch, head_index, query_pos, key_pos]
- self.hook_result = HookPoint() # [batch, pos, head_index, d_model]
-
- # See HookedTransformerConfig for more details.
- if self.cfg.positional_embedding_type == "shortformer":
- # This tracks the input to the keys and queries, which is resid_pre + pos_embeds
- self.hook_attn_input = HookPoint() # [batch, pos, d_model]
- elif self.cfg.positional_embedding_type == "rotary":
- # Applies a rotation to each two-element chunk of keys and queries pre dot producting to bake in relative position. See HookedTransformerConfig for details
- self.hook_rot_k = HookPoint()
- self.hook_rot_q = HookPoint()
- sin, cos = self.calculate_sin_cos_rotary(
- self.cfg.rotary_dim, self.cfg.n_ctx, dtype=self.cfg.dtype
- )
- self.register_buffer("rotary_sin", sin)
- self.register_buffer("rotary_cos", cos)
- elif self.cfg.positional_embedding_type == "alibi":
- # ALiBi bias wil be constructed on the first forward pass.
- # Note: While computationally efficient, initializing an bias with max n_ctx (16, 1024, 1024) of float32 will occupy ~256MiB of contiguous GPU memory, which may not be optimal for memory usage.
- self.alibi = None
-
- @property
- def OV(self) -> FactoredMatrix:
- """
- OV-Circuit, as defined in A Mathematical Framework. Because there's no non-linearity between the value vector and the output of the layer, the output is purely determined by the matrix W_OV = W_V @ W_O, and not W_V or W_O individually. (Mathematically, for a single head, output == pattern @ residual @ W_V @ W_O, see the glossary for more)
-
- Done in the order W_V, W_O because the paper uses left-multiplying weight matrices, and TransformerLens uses right-multiplying, sorry!
-
- Returns a FactoredMatrix, with left matrix W_V [head_index, d_model, d_head] and right matrix W_O [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model]. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the OV circuit of a head k, attn.OV[k] works.
- """
- return FactoredMatrix(self.W_V, self.W_O)
-
- @property
- def QK(self) -> FactoredMatrix:
- """
- QK-Circuit, as defined in A Mathematical Framework. Because there's no non-linearity in the key-query dot product, the output is purely determined by the matrix W_QK = W_Q.T @ W_K, and not W_Q or W_K individually. (Mathematically, for a single head, pattern = destination_residual.T @ W_Q.T @ W_K @ source-residual, see the glossary for more).
-
- Done in the order Q on the left, K on the right, because the pattern has dimensions [destination_pos, source_pos]
-
- Returns a FactoredMatrix, with left matrix W_Q [head_index, d_model, d_head] and right matrix W_K.T [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model] matrix. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the QK circuit of a head k, attn.QK[k] works.
- """
- W_K_transpose = einops.rearrange(
- self.W_K, "head_index d_model d_head -> head_index d_head d_model"
- )
- return FactoredMatrix(self.W_Q, W_K_transpose)
-
    def forward(
        self,
        query_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
        key_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
        value_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
        past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None,
        additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 pos"]] = None,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """Attention forward pass: project q/k/v, score, mask, softmax, mix values, project out.

        Inputs are [batch, pos, d_model], or [batch, pos, head_index, d_model] when
        cfg.use_split_qkv_input or cfg.use_attn_in gives each head its own input copy.

        Args:
            query_input: Residual stream input used to compute the queries.
            key_input: Residual stream input used to compute the keys.
            value_input: Residual stream input used to compute the values.
            past_kv_cache_entry: Optional entry of past keys and values for this layer,
                only relevant if generating text; new keys/values are appended to it
                (which updates the cache in place). Defaults to None.
            additive_attention_mask: Optional mask added to the attention scores.
                Defaults to None.
            attention_mask: The attention mask for padded tokens. Defaults to None.

        Returns:
            The attention layer output, shape [batch, pos, d_model].
        """

        if self.cfg.use_split_qkv_input or self.cfg.use_attn_in:
            qkv_einops_string = "batch pos head_index d_model"
        else:
            qkv_einops_string = "batch pos d_model"
        # Per-head linear projections of the (possibly per-head) inputs.
        q = self.hook_q(
            einsum(
                f"{qkv_einops_string}, head_index d_model d_head \
            -> batch pos head_index d_head",
                query_input,
                self.W_Q,
            )
            + self.b_Q
        )  # [batch, pos, head_index, d_head]
        k = self.hook_k(
            einsum(
                f"{qkv_einops_string}, head_index d_model d_head \
            -> batch pos head_index d_head",
                key_input,
                self.W_K,
            )
            + self.b_K
        )  # [batch, pos, head_index, d_head]
        v = self.hook_v(
            einsum(
                f"{qkv_einops_string}, head_index d_model d_head \
            -> batch pos head_index d_head",
                value_input,
                self.W_V,
            )
            + self.b_V
        )  # [batch, pos, head_index, d_head]

        if past_kv_cache_entry is not None:
            # Appends the new keys and values to the cached values, and automatically updates the cache
            kv_cache_pos_offset = past_kv_cache_entry.past_keys.size(1)
            k, v = past_kv_cache_entry.append(k, v)
        else:
            # Not using a cache
            kv_cache_pos_offset = 0

        if self.cfg.positional_embedding_type == "rotary":
            # Queries start after the cached positions, so they get the cache offset.
            q = self.hook_rot_q(
                self.apply_rotary(q, kv_cache_pos_offset, attention_mask)
            )
            k = self.hook_rot_k(
                self.apply_rotary(k, 0, attention_mask)
            )  # keys are cached so no offset

        if self.cfg.dtype not in [torch.float32, torch.float64]:
            # If using 16 bits, increase the precision to avoid numerical instabilities
            q = q.to(torch.float32)
            k = k.to(torch.float32)

        # Scaled dot-product scores between every query and key position.
        attn_scores = (
            einsum(
                "batch query_pos head_index d_head, \
                batch key_pos head_index d_head \
                -> batch head_index query_pos key_pos",
                q,
                k,
            )
            / self.attn_scale
        )  # [batch, head_index, query_pos, key_pos]

        if self.cfg.positional_embedding_type == "alibi":
            query_ctx = attn_scores.size(-2)
            # The key context length is the number of positions in the past - this includes all positions in the cache
            key_ctx = attn_scores.size(-1)

            # only recompute when necessary to increase efficiency.
            if self.alibi is None or key_ctx > self.alibi.size(-1):
                self.alibi = Attention.create_alibi_bias(
                    self.cfg.n_heads, key_ctx, self.cfg.device
                )

            attn_scores += self.alibi[
                :, :query_ctx, :key_ctx
            ]  # [batch, head_index, query_pos, key_pos]

        if self.cfg.attention_dir == "causal":
            # If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask.
            attn_scores = self.apply_causal_mask(
                attn_scores, kv_cache_pos_offset, attention_mask
            )  # [batch, head_index, query_pos, key_pos]
        if additive_attention_mask is not None:
            attn_scores += additive_attention_mask

        attn_scores = self.hook_attn_scores(attn_scores)
        pattern = F.softmax(attn_scores, dim=-1)
        # Rows that are fully masked softmax to NaN; replace them with zeros.
        pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern)
        pattern = self.hook_pattern(pattern)  # [batch, head_index, query_pos, key_pos]
        # Cast back to the model dtype in case q/k were upcast to float32 above.
        pattern = pattern.to(self.cfg.dtype)
        # Mix the value vectors with the attention pattern.
        z = self.hook_z(
            einsum(
                "batch key_pos head_index d_head, \
                batch head_index query_pos key_pos -> \
                batch query_pos head_index d_head",
                v,
                pattern,
            )
        )  # [batch, pos, head_index, d_head]
        if not self.cfg.use_attn_result:
            out = (
                (
                    einsum(
                        "batch pos head_index d_head, \
                        head_index d_head d_model -> \
                        batch pos d_model",
                        z,
                        self.W_O,
                    )
                )
                + self.b_O
            )  # [batch, pos, d_model]
        else:
            # Explicitly calculate the attention result so it can be accessed by a hook
            # This is off by default because it can easily eat through your GPU memory.
            result = self.hook_result(
                einsum(
                    "batch pos head_index d_head, \
                    head_index d_head d_model -> \
                    batch pos head_index d_model",
                    z,
                    self.W_O,
                )
            )  # [batch, pos, head_index, d_model]
            out = (
                einops.reduce(
                    result, "batch position index model->batch position model", "sum"
                )
                + self.b_O
            )  # [batch, pos, d_model]
        return out
-
    def apply_causal_mask(
        self,
        attn_scores: Float[
            torch.Tensor, "batch head_index pos pos_plus_past_kv_pos_offset"
        ],
        past_kv_pos_offset: int = 0,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    ):
        """Mask attention scores so queries cannot attend to future positions.

        Masked entries are replaced with self.IGNORE so they vanish after softmax.
        If attention_mask is given, padding positions are additionally masked out.

        Args:
            attn_scores: Raw attention scores [batch, head_index, query_pos, key_pos].
            past_kv_pos_offset: Number of key positions held in the KV cache (0 if not caching).
            attention_mask: Optional padding mask over key positions.

        Returns:
            The masked attention scores, same shape as attn_scores.
        """
        # The query context length is the number of positions we take queries from - if not using a past_kv_cache this is just the context length (for the current prompt), but if we're caching it can be different.
        query_ctx_length = attn_scores.size(-2)
        # The key context length is the number of positions in the past - this includes all positions in the cache
        # If not caching, query_ctx_length == key_ctx_length
        key_ctx_length = attn_scores.size(-1)

        assert (
            query_ctx_length + past_kv_pos_offset == key_ctx_length
        ), f"query_ctx_length {query_ctx_length} + past_kv_pos_offset {past_kv_pos_offset} != key_ctx_length {key_ctx_length} - you likely have a bug."

        # Index back to front to ensure local attention works
        final_mask = self.mask[
            None, None, -query_ctx_length:, -key_ctx_length:
        ]  # [1, 1, pos, pos]
        if attention_mask is not None:
            # Apply a causal mask to the attention scores considering the padding
            einsum_str = "batch head pos offset_pos, batch offset_pos -> batch head pos offset_pos"
            final_mask = einops.einsum(final_mask, attention_mask, einsum_str).bool()

        return torch.where(final_mask, attn_scores, self.IGNORE)
-
- def calculate_sin_cos_rotary(
- self,
- rotary_dim: int,
- n_ctx: int,
- base: int = 10000,
- dtype: torch.dtype = torch.float32,
- ) -> Tuple[
- Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"]
- ]:
- """
- Calculate the sine and cosine waves to use in a rotary embedding. See https://blog.eleuther.ai/rotary-embeddings/ for details
-
- Note: For some inexplicable reason, in GPT-J each ADJACENT pair of elements in k and q are rotated, in GPT-NeoX the pair of elements at k and k+n//2 are rotated (ie folding the full length in half, and then looking at pairs accordingly). I have absolutely no clue why, it should be completely equivalent.
- To resolve this, I've coded it to default to the GPT-J mode, but to explicitly check whether it's GPT-NeoX and then do the GPT-NeoX thing if it is.
- """
- high_precision = torch.float32 if dtype != torch.float64 else torch.float64
- pos = torch.arange(n_ctx, dtype=high_precision)
- dim = torch.arange(rotary_dim // 2, dtype=high_precision)
-
- # A set of frequencies evenly spaced in log space
- freq = base ** (dim / (rotary_dim / 2))
- if self.cfg.original_architecture in ["GPTNeoXForCausalLM", "LlamaForCausalLM"]:
- freq = einops.repeat(freq, "d -> (2 d)")
- else:
- freq = einops.repeat(freq, "d -> (d 2)")
- # Create a n_ctx x rotary_dim tensor, where each column is an arithmetic sequence of angles in that frequency
- angles = pos[:, None] / freq[None, :]
- return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype)
-
- def rotate_every_two(
- self, x: Float[torch.Tensor, "... rotary_dim"]
- ) -> Float[torch.Tensor, "... rotary_dim"]:
- """
- Rotary helper function, splits x into blocks of size 2 along the final axis and maps [x0, x1] to [-x1, x0]
-
- The final axis of x must have even length.
-
- GPT-NeoX and GPT-J do rotary subtly differently, see calculate_sin_cos_rotary for details.
- """
- rot_x = x.clone()
- if self.cfg.original_architecture in ["GPTNeoXForCausalLM", "LlamaForCausalLM"]:
- n = x.size(-1) // 2
- rot_x[..., :n] = -x[..., n:]
- rot_x[..., n:] = x[..., :n]
- else:
- rot_x[..., ::2] = -x[..., 1::2]
- rot_x[..., 1::2] = x[..., ::2]
-
- return rot_x
-
    def apply_rotary(
        self,
        x: Float[torch.Tensor, "batch pos head_index d_head"],
        past_kv_pos_offset=0,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos head_index d_head"]:
        """Apply rotary positional embedding to the first cfg.rotary_dim head dimensions of x.

        Dimensions beyond rotary_dim pass through unchanged (eg, if rotary_dim=64 and
        d_head=256, only the first 1/4 of dimensions are rotated). past_kv_pos_offset
        shifts the position indices when earlier positions live in a KV cache; if
        attention_mask is given, per-token position ids are derived from it instead so
        that padding is accounted for.
        """
        # Only apply rotary to first rotary_dim dimensions (eg, if rotary_dim=64 and d_head=256, only apply to first 1/4 of dimensions)
        x_pos = x.size(1)
        x_rot = x[..., : self.cfg.rotary_dim]
        x_pass = x[..., self.cfg.rotary_dim :]
        # Pair-rotated copy of x_rot: each pair [x0, x1] becomes [-x1, x0].
        x_flip = self.rotate_every_two(x_rot)

        if attention_mask is None:
            # Positions run offset .. offset + seq_len; broadcast the precomputed
            # sin/cos tables over the batch and head dimensions.
            rotary_cos = self.rotary_cos[
                None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, :
            ]
            rotary_sin = self.rotary_sin[
                None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, :
            ]
            x_rotated = x_rot * rotary_cos + x_flip * rotary_sin
        else:
            # With padding, look up each token's angle by its offset position id
            # (computed from the attention mask by get_offset_position_ids).
            offset_position_ids = get_offset_position_ids(
                past_kv_pos_offset, attention_mask
            )
            mask_rotary_cos = self.rotary_cos[offset_position_ids, None, :]
            mask_rotary_sin = self.rotary_sin[offset_position_ids, None, :]
            x_rotated = x_rot * mask_rotary_cos + x_flip * mask_rotary_sin

        return torch.cat([x_rotated, x_pass], dim=-1)
-
- @staticmethod
- def create_alibi_slope(
- n_ctx: int, device: torch.device = None
- ) -> Float[torch.Tensor, "query key"]:
- """Create an ALiBi Slope Matrix.
-
- Create the slope matrix used in ALiBi, before it is multiplied by the head-specific scalar.
-
- See :meth:`create_alibi_bias` for the full ALiBi bias calculation.
-
- Examples:
-
- >>> Attention.create_alibi_slope(3)
- tensor([[ 0., 0., 0.],
- [-1., 0., 0.],
- [-2., -1., 0.]])
-
- >>> Attention.create_alibi_slope(4)
- tensor([[ 0., 0., 0., 0.],
- [-1., 0., 0., 0.],
- [-2., -1., 0., 0.],
- [-3., -2., -1., 0.]])
-
- Args:
- n_ctx: The maximum number of tokens in a prompt.
-
- Returns:
- A tensor of shape (n_ctx, n_ctx), where the upper triangle is zero and the lower
- triangle is decreasing by a constant slope of 1 (towards the bottom left corner).
- """
- # set rows as [[0,1,2...]]
- rows = torch.arange(n_ctx, device=device).unsqueeze(0)
-
- # Set cols as [[0],[1],[2]...]
- cols = torch.arange(n_ctx, device=device).unsqueeze(1)
-
- # Use broadcasting to create the desired lower triangular part of the matrix
- slope_matrix = rows - cols
-
- # Use the clamp method to set all positive values (upper right triangle) to
- return slope_matrix.clamp(max=0).to(torch.float32)
-
- @staticmethod
- def create_alibi_multipliers(
- n_heads: int, device: torch.device = None
- ) -> Float[torch.Tensor, "head_idx"]:
- """Create the ALiBi Scalar Multipliers for each Head.
-
- For n heads, the set of multipliers (m) is the geometric sequence that starts at 2^(-8/n), and
- uses that same value as its ratio. For example, with 8 heads the values would be [1/(2^1),
- 1/(2^2), ... , 1/(2^8)]. With 16 heads the values would be [1/(2^0.5), 1/(2^1), ... , 1/(2^8)].
-
- See :meth:`create_alibi_bias` for the full ALiBi bias calculation.
-
- Examples:
-
- >>> Attention.create_alibi_multipliers(8)
- tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])
-
- >>> Attention.create_alibi_multipliers(16)
- tensor([0.7071, 0.5000, 0.3536, 0.2500, 0.1768, 0.1250, 0.0884, 0.0625, 0.0442, 0.0312,
- 0.0221, 0.0156, 0.0110, 0.0078, 0.0055, 0.0039])
-
- Args:
- n_heads: The number of heads in a layer.
- device: The device to create the tensor on.
-
- Returns:
- A tensor of shape (n_heads,) containing the scalar multiplier for each head.
- """
- # Calculate the starting value
- start = 2 ** (-8 / n_heads)
-
- # Generate the indices [0, 1, ..., n_heads-1]
- indices = torch.arange(n_heads, device=device)
-
- # Compute the multipliers, with the starting value being the same as the ratio
- multipliers = start * (start**indices)
-
- return multipliers
-
- @staticmethod
- def create_alibi_bias(
- n_heads: int, n_ctx: int, device: torch.device = None
- ) -> Float[torch.Tensor, "head_idx query key"]:
- """Create the ALiBi Bias for all Heads.
-
- Calculate the ALiBi bias (https://arxiv.org/pdf/2108.12409.pdf) for all heads in a layer.
-
- The broad idea behind ALiBi is to remove the positional encoding from the original transformer
- model, and instead apply a bias to each attention score. This bias is proportional to the
- distance between the query and key (i.e. it encourage paying less attention to more distant
- tokens), and is added to the attention scores before the softmax. It is used in models such as
- Bloom.
-
- Examples:
-
- >>> Attention.create_alibi_bias(2, 4, torch.device('cpu'))
- tensor([[[ 0.0000, 0.0000, 0.0000, 0.0000],
- [-0.0625, 0.0000, 0.0000, 0.0000],
- [-0.1250, -0.0625, 0.0000, 0.0000],
- [-0.1875, -0.1250, -0.0625, 0.0000]],
- [[ 0.0000, 0.0000, 0.0000, 0.0000],
- [-0.0039, 0.0000, 0.0000, 0.0000],
- [-0.0078, -0.0039, 0.0000, 0.0000],
- [-0.0117, -0.0078, -0.0039, 0.0000]]])
-
- Args:
- n_heads: The number of heads in a layer.
- n_ctx: The maximum number of tokens in a prompt.
- device: The device to create the tensor on.
-
- Returns:
- The ALiBi bias that should be added to the attention scores before the softmax.
- """
- # Create the slope matrix
- slope: Float[torch.Tensor, "query key"] = Attention.create_alibi_slope(
- n_ctx, device
- )
-
- # Create the scalar multiplier for each head.
- multipliers: Float[
- torch.Tensor, "head_idx"
- ] = Attention.create_alibi_multipliers(n_heads, device)
-
- # The ALiBi bias is then m * slope_matrix
- alibi_bias = torch.einsum("ij,k->kij", slope, multipliers)
-
- return alibi_bias
-
-
-# MLP Layers
class MLP(nn.Module):
    """Standard transformer MLP: x @ W_in + b_in -> activation -> @ W_out + b_out.

    For the "solu_ln" activation, a LayerNorm is additionally applied between the
    activation and the output projection, with a hook (hook_mid) in between.
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        """Build parameters, hooks and the activation function from the config.

        Raises:
            ValueError: If cfg.act_fn is not a recognised activation name.
        """
        super().__init__()
        if isinstance(cfg, Dict):
            cfg = HookedTransformerConfig.from_dict(cfg)
        self.cfg = cfg

        d_model, d_mlp = self.cfg.d_model, self.cfg.d_mlp
        self.W_in = nn.Parameter(torch.empty(d_model, d_mlp, dtype=cfg.dtype))
        self.b_in = nn.Parameter(torch.zeros(d_mlp, dtype=cfg.dtype))
        self.W_out = nn.Parameter(torch.empty(d_mlp, d_model, dtype=cfg.dtype))
        self.b_out = nn.Parameter(torch.zeros(d_model, dtype=cfg.dtype))

        self.hook_pre = HookPoint()  # [batch, pos, d_mlp]
        self.hook_post = HookPoint()  # [batch, pos, d_mlp]

        activations = {
            "relu": F.relu,
            "gelu": F.gelu,
            "silu": F.silu,
            "gelu_new": gelu_new,
            "gelu_fast": gelu_fast,
            "solu_ln": solu,
        }
        if self.cfg.act_fn not in activations:
            raise ValueError(f"Invalid activation function name: {self.cfg.act_fn}")
        self.act_fn = activations[self.cfg.act_fn]

        if self.cfg.act_fn == "solu_ln":
            # SoLU inserts a LayerNorm between the activation and W_out, with a hook
            # in between.
            self.hook_mid = HookPoint()  # [batch, pos, d_mlp]
            if self.cfg.normalization_type == "LN":
                self.ln = LayerNorm(self.cfg, self.cfg.d_mlp)
            else:
                self.ln = LayerNormPre(self.cfg)

    def forward(
        self, x: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """Apply the MLP to the residual stream input x."""
        # Technically this and the output einsum could be a single matmul, but two
        # named einsums are more readable.
        pre_act = self.hook_pre(
            einsum("batch pos d_model, d_model d_mlp -> batch pos d_mlp", x, self.W_in)
            + self.b_in
        )  # [batch, pos, d_mlp]
        if self.cfg.act_fn.endswith("_ln"):
            # SoLU path: hook between the activation and the LayerNorm.
            mid_act = self.hook_mid(self.act_fn(pre_act))  # [batch, pos, d_mlp]
            post_act = self.hook_post(self.ln(mid_act))
        else:
            post_act = self.hook_post(self.act_fn(pre_act))  # [batch, pos, d_mlp]
        return (
            einsum(
                "batch pos d_mlp, d_mlp d_model -> batch pos d_model",
                post_act,
                self.W_out,
            )
            + self.b_out
        )
-
-
-# TODO
-# not sure whether to fold this into MLP or not
class GatedMLP(nn.Module):
    """Gated transformer MLP (as used by e.g. Llama-family models).

    The equation of a gated MLP:
        pre = x @ W_gate
        pre_linear = x @ W_in
        post = act_fn(pre) * pre_linear + b_in
        mlp_out = post @ W_out + b_out

    In one equation: mlp_out = (act_fn(x @ W_gate) * (x @ W_in) + b_in) @ W_out + b_out
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        """Build parameters, hooks and the activation function from the config.

        Raises:
            ValueError: If cfg.act_fn is not a recognised activation name.
        """
        super().__init__()
        if isinstance(cfg, Dict):
            cfg = HookedTransformerConfig.from_dict(cfg)
        self.cfg = cfg

        d_model, d_mlp = self.cfg.d_model, self.cfg.d_mlp
        self.W_in = nn.Parameter(torch.empty(d_model, d_mlp, dtype=cfg.dtype))
        self.W_gate = nn.Parameter(torch.empty(d_model, d_mlp, dtype=cfg.dtype))
        self.b_in = nn.Parameter(torch.zeros(d_mlp, dtype=cfg.dtype))
        self.W_out = nn.Parameter(torch.empty(d_mlp, d_model, dtype=cfg.dtype))
        self.b_out = nn.Parameter(torch.zeros(d_model, dtype=cfg.dtype))

        # hook on gate output but before act_fn
        self.hook_pre = HookPoint()  # [batch, pos, d_mlp]
        # hook on the linear component of the input
        self.hook_pre_linear = HookPoint()  # [batch, pos, d_mlp]
        # hook on act_fn(gate_output) * W_in(x) + b_in
        self.hook_post = HookPoint()  # [batch, pos, d_mlp]

        activations = {
            "relu": F.relu,
            "gelu": F.gelu,
            "silu": F.silu,
            "gelu_new": gelu_new,
            "gelu_fast": gelu_fast,
            "solu_ln": solu,
        }
        if self.cfg.act_fn not in activations:
            raise ValueError(f"Invalid activation function name: {self.cfg.act_fn}")
        self.act_fn = activations[self.cfg.act_fn]

        if self.cfg.act_fn == "solu_ln":
            # SoLU inserts a LayerNorm between the activation and W_out, with a hook
            # in between.
            self.hook_mid = HookPoint()  # [batch, pos, d_mlp]
            if self.cfg.normalization_type == "LN":
                self.ln = LayerNorm(self.cfg, self.cfg.d_mlp)
            else:
                self.ln = LayerNormPre(self.cfg)

    def forward(
        self, x: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """Apply the gated MLP to the residual stream input x."""
        # Gate projection (no bias here; b_in is added after the elementwise product).
        pre_act = self.hook_pre(
            einsum(
                "batch pos d_model, d_model d_mlp -> batch pos d_mlp", x, self.W_gate
            )
        )  # [batch, pos, d_mlp]
        if self.cfg.act_fn.endswith("_ln"):
            # SoLU path: hook between the activation and the LayerNorm.
            mid_act = self.hook_mid(self.act_fn(pre_act))  # [batch, pos, d_mlp]
            post_act = self.hook_post(self.ln(mid_act))
        else:
            # Linear projection of the same input, gated by the activated gate output.
            pre_linear = self.hook_pre_linear(
                einsum(
                    "batch pos d_model, d_model d_mlp -> batch pos d_mlp", x, self.W_in
                )
            )
            post_act = self.hook_post(
                (self.act_fn(pre_act) * pre_linear) + self.b_in
            )  # [batch, pos, d_mlp]
        return (
            einsum(
                "batch pos d_mlp, d_mlp d_model -> batch pos d_model",
                post_act,
                self.W_out,
            )
            + self.b_out
        )
-
-
-# Transformer Block
class TransformerBlock(nn.Module):
    """A single transformer block: (norm -> attention) and (norm -> MLP), each wrapped in a residual connection.

    The normalization variant, attention type (global/local) and MLP variant
    (standard/gated) are all selected from the config. Supports attn-only models
    (no MLP) and parallel attention+MLP (GPT-J style, cfg.parallel_attn_mlp).
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig], block_index):
        # block_index: the index of this block within the model; used to select the
        # per-layer attention type when cfg.use_local_attn is set, and passed to
        # Attention for layer-dependent behaviour.
        super().__init__()
        if isinstance(cfg, Dict):
            cfg = HookedTransformerConfig.from_dict(cfg)
        self.cfg = cfg
        # Choose the normalization layers; ln2 only exists when there is an MLP.
        if self.cfg.normalization_type == "LN":
            self.ln1 = LayerNorm(cfg)
            if not self.cfg.attn_only:
                self.ln2 = LayerNorm(cfg)
        elif self.cfg.normalization_type == "LNPre":
            # We've folded in LayerNorm weights, so just need the center + scale parts
            self.ln1 = LayerNormPre(cfg)
            if not self.cfg.attn_only:
                self.ln2 = LayerNormPre(cfg)
        elif self.cfg.normalization_type == "RMS":
            self.ln1 = RMSNorm(cfg)
            if not self.cfg.attn_only:
                self.ln2 = RMSNorm(cfg)
        elif self.cfg.normalization_type == "RMSPre":
            self.ln1 = RMSNormPre(cfg)
            if not self.cfg.attn_only:
                self.ln2 = RMSNormPre(cfg)
        elif self.cfg.normalization_type is None:
            self.ln1 = nn.Identity()
            if not self.cfg.attn_only:
                self.ln2 = nn.Identity()
        else:
            # Unknown normalization type: warn rather than fail, leaving ln1/ln2 unset.
            logging.warning(
                f"Invalid normalization_type passed in {self.cfg.normalization_type}"
            )

        if not self.cfg.use_local_attn:
            self.attn = Attention(cfg, "global", block_index)
        else:
            # Local attention models specify a per-layer attention type list.
            assert self.cfg.attn_types is not None
            attn_type = self.cfg.attn_types[block_index]
            self.attn = Attention(cfg, attn_type, block_index)
        if not self.cfg.attn_only:
            if self.cfg.gated_mlp:
                self.mlp = GatedMLP(cfg)
            else:
                self.mlp = MLP(cfg)

        # Hooks on the (optionally per-head) inputs to the attention sublayer.
        self.hook_attn_in = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_q_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_k_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_v_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_mlp_in = HookPoint()  # [batch, pos, d_model]

        self.hook_attn_out = HookPoint()  # [batch, pos, d_model]
        self.hook_mlp_out = HookPoint()  # [batch, pos, d_model]

        self.hook_resid_pre = HookPoint()  # [batch, pos, d_model]
        # resid_mid only exists in the sequential attn-then-MLP layout.
        if not self.cfg.attn_only and not self.cfg.parallel_attn_mlp:
            self.hook_resid_mid = HookPoint()  # [batch, pos, d_model]
        self.hook_resid_post = HookPoint()  # [batch, pos, d_model]

    def forward(
        self,
        resid_pre: Float[torch.Tensor, "batch pos d_model"],
        shortformer_pos_embed: Optional[
            Float[torch.Tensor, "batch pos d_model"]
        ] = None,
        past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """A single Transformer block.

        Args:
            resid_pre (torch.Tensor): The residual stream - shape [batch, pos, d_model]
            shortformer_pos_embed (torch.Tensor, optional): Only used for
                positional_embeddings_type == "shortformer". The positional embeddings,
                added to the query and key inputs (but not values). See
                HookedTransformerConfig for details. Defaults to None.
            past_kv_cache_entry (HookedTransformerKeyValueCacheEntry, optional): A cache
                of previous keys and values, used only when generating text.
                Defaults to None.
            attention_mask (torch.Tensor, optional): The attention mask for padded
                tokens. Defaults to None.

        Returns:
            The updated residual stream, shape [batch, pos, d_model].
        """
        resid_pre = self.hook_resid_pre(resid_pre)  # [batch, pos, d_model]

        def add_head_dimension(
            tensor: Float[torch.Tensor, "batch pos d_model"],
            clone_tensor=True,
            # `einops.repeat` uses a view in torch, so we generally clone the tensor to avoid using shared storage for each head entry
        ):
            repeated_tensor = einops.repeat(
                tensor,
                "batch pos d_model -> batch pos n_heads d_model",
                n_heads=self.cfg.n_heads,
            )
            if clone_tensor:
                return repeated_tensor.clone()
            else:
                return repeated_tensor

        if self.cfg.use_attn_in or self.cfg.use_split_qkv_input:
            # We're adding a head dimension
            attn_in = add_head_dimension(resid_pre, clone_tensor=False)
            if shortformer_pos_embed is not None:
                shortformer_pos_embed = add_head_dimension(shortformer_pos_embed)
        else:
            attn_in = resid_pre

        if self.cfg.use_attn_in:
            attn_in = self.hook_attn_in(attn_in.clone())

        if self.cfg.use_split_qkv_input:
            # Clone each input so a hook can edit one without affecting the others.
            query_input = self.hook_q_input(attn_in.clone())
            key_input = self.hook_k_input(attn_in.clone())
            value_input = self.hook_v_input(attn_in.clone())
        else:
            query_input = attn_in
            key_input = attn_in
            value_input = attn_in

        attn_out = self.hook_attn_out(
            # hook the residual stream states that are used to calculate the
            # queries, keys and values, independently.
            # Then take the layer norm of these inputs, and pass these to the attention module.
            self.attn(
                query_input=self.ln1(query_input)
                + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
                key_input=self.ln1(key_input)
                + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
                value_input=self.ln1(value_input),
                past_kv_cache_entry=past_kv_cache_entry,
                attention_mask=attention_mask,
            )
        )  # [batch, pos, d_model]
        if not self.cfg.attn_only and not self.cfg.parallel_attn_mlp:
            # Standard sequential layout: resid_pre -> attn -> resid_mid -> MLP -> resid_post.
            resid_mid = self.hook_resid_mid(
                resid_pre + attn_out
            )  # [batch, pos, d_model]
            mlp_in = (
                resid_mid
                if not self.cfg.use_hook_mlp_in
                else self.hook_mlp_in(resid_mid.clone())
            )
            normalized_resid_mid = self.ln2(mlp_in)
            mlp_out = self.hook_mlp_out(
                self.mlp(normalized_resid_mid)
            )  # [batch, pos, d_model]
            resid_post = self.hook_resid_post(
                resid_mid + mlp_out
            )  # [batch, pos, d_model]
        elif self.cfg.parallel_attn_mlp:
            # Dumb thing done by GPT-J, both MLP and Attn read from resid_pre and write to resid_post, no resid_mid used.
            # In GPT-J, LN1 and LN2 are tied, in GPT-NeoX they aren't.
            normalized_resid_pre_2 = self.ln2(
                resid_pre
                if not self.cfg.use_hook_mlp_in
                else self.hook_mlp_in(resid_pre.clone())
            )
            mlp_out = self.hook_mlp_out(
                self.mlp(normalized_resid_pre_2)
            )  # [batch, pos, d_model]
            resid_post = self.hook_resid_post(
                resid_pre + attn_out + mlp_out
            )  # [batch, pos, d_model]
        else:
            # Attention-only model: no MLP sublayer.
            resid_post = self.hook_resid_post(
                resid_pre + attn_out
            )  # [batch, pos, d_model]
        return resid_post
-
-
class BertBlock(nn.Module):
    """
    BERT Block. Similar to the TransformerBlock, except that the LayerNorms are applied
    after the attention and MLP, rather than before (post-norm architecture).
    """

    def __init__(self, cfg: HookedTransformerConfig):
        super().__init__()
        self.cfg = cfg

        self.attn = Attention(cfg)
        self.ln1 = LayerNorm(cfg)
        self.mlp = MLP(cfg)
        self.ln2 = LayerNorm(cfg)

        # Per-head q/k/v input hooks (used when cfg.use_split_qkv_input is set).
        self.hook_q_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_k_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_v_input = HookPoint()  # [batch, pos, n_heads, d_model]

        self.hook_attn_out = HookPoint()  # [batch, pos, d_model]
        self.hook_mlp_in = HookPoint()  # [batch, pos, d_model]
        self.hook_mlp_out = HookPoint()  # [batch, pos, d_model]
        self.hook_resid_pre = HookPoint()  # [batch, pos, d_model]
        self.hook_resid_mid = HookPoint()  # [batch, pos, d_model]
        self.hook_resid_post = HookPoint()  # [batch, pos, d_model]
        self.hook_normalized_resid_post = HookPoint()  # [batch, pos, d_model]

    def forward(
        self,
        resid_pre: Float[torch.Tensor, "batch pos d_model"],
        additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 pos"]] = None,
    ):
        """Run the block: attention + residual, LN, MLP + residual, LN (post-norm).

        Args:
            resid_pre: The residual stream input, shape [batch, pos, d_model].
            additive_attention_mask: Optional mask added to the attention scores
                (used to mask padding). Defaults to None.

        Returns:
            The normalized block output, shape [batch, pos, d_model].
        """
        resid_pre = self.hook_resid_pre(resid_pre)

        query_input = resid_pre
        key_input = resid_pre
        value_input = resid_pre

        if self.cfg.use_split_qkv_input:
            # Give each head its own copy of the input so hooks can edit them per-head.
            def add_head_dimension(tensor):
                return einops.repeat(
                    tensor,
                    "batch pos d_model -> batch pos n_heads d_model",
                    n_heads=self.cfg.n_heads,
                ).clone()

            query_input = self.hook_q_input(add_head_dimension(query_input))
            key_input = self.hook_k_input(add_head_dimension(key_input))
            value_input = self.hook_v_input(add_head_dimension(value_input))

        attn_out = self.hook_attn_out(
            self.attn(
                query_input,
                key_input,
                value_input,
                additive_attention_mask=additive_attention_mask,
            )
        )
        resid_mid = self.hook_resid_mid(resid_pre + attn_out)

        mlp_in = (
            resid_mid
            if not self.cfg.use_hook_mlp_in
            else self.hook_mlp_in(resid_mid.clone())
        )
        # Post-norm: normalize after the attention residual, before the MLP.
        normalized_resid_mid = self.ln1(mlp_in)
        mlp_out = self.hook_mlp_out(self.mlp(normalized_resid_mid))
        # NOTE: the MLP residual is taken around the *normalized* mid stream — this is
        # the standard BERT post-LN wiring, not a bug.
        resid_post = self.hook_resid_post(normalized_resid_mid + mlp_out)
        normalized_resid_post = self.hook_normalized_resid_post(self.ln2(resid_post))

        return normalized_resid_post
diff --git a/transformer_lens/components/__init__.py b/transformer_lens/components/__init__.py
new file mode 100644
index 000000000..47677426a
--- /dev/null
+++ b/transformer_lens/components/__init__.py
@@ -0,0 +1,29 @@
+"""Hooked Transformer Components.
+
+This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`)
+needed to create many different types of generative language models. They are used by
+:class:`transformer_lens.HookedTransformer`.
+"""
+# Independent classes
+from .abstract_attention import AbstractAttention
+from .layer_norm import LayerNorm
+from .layer_norm_pre import LayerNormPre
+from .pos_embed import PosEmbed
+from .rms_norm import RMSNorm
+from .rms_norm_pre import RMSNormPre
+from .token_typed_embed import TokenTypeEmbed
+from .unembed import Unembed
+
+# Only dependent on independent modules
+from .attention import Attention
+from .bert_mlm_head import BertMLMHead
+from .embed import Embed
+from .gated_mlp import GatedMLP
+from .grouped_query_attention import GroupedQueryAttention
+from .mlp import MLP
+
+# Interdependent modules
+from .bert_block import BertBlock
+from .bert_embed import BertEmbed
+from .moe import MoE
+from .transformer_block import TransformerBlock
diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py
new file mode 100644
index 000000000..c6fc0e5e8
--- /dev/null
+++ b/transformer_lens/components/abstract_attention.py
@@ -0,0 +1,654 @@
+from abc import ABC
+from typing import Dict, Optional, Tuple, Union
+
+import einops
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from better_abc import abstract_attribute
+from fancy_einsum import einsum
+from jaxtyping import Float, Int
+from transformers.utils import is_bitsandbytes_available
+
+from transformer_lens.FactoredMatrix import FactoredMatrix
+from transformer_lens.hook_points import HookPoint
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry
+from transformer_lens.utils import get_offset_position_ids
+
+if is_bitsandbytes_available():
+ import bitsandbytes as bnb
+ from bitsandbytes.nn.modules import Params4bit
+
+
+class AbstractAttention(ABC, nn.Module):
+ alibi: Union[torch.Tensor, None]
+
+ def __init__(
+ self,
+ cfg: Union[Dict, HookedTransformerConfig],
+ attn_type: str = "global",
+ layer_id: Optional[int] = None,
+ ):
+ """Abstract Base Class of Attention Blocks, featuring common functionality of both Attention and GroupedQueryAttention blocks.
+
+ Query and Output projections are defined in this class as they are the same for regular and grouped query attention.
+ Attributes related to Key and Value projections are abstract as their implementations may differ. For example, in GroupedQueryAttention there are fewer key and value heads than query heads.
+ To enforce implementation of W_K, W_V, b_K, and b_V by child classes, the better_abc.abstract_attribute class is used. See here for details: https://stackoverflow.com/questions/23831510/abstract-attribute-not-property.
+
+ Args:
+ cfg (Union[Dict, HookedTransformerConfig]): Config
+ attn_type (str, optional): "global" or "local", used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens (here, 256). Not used by any other model at the moment. Defaults to "global".
+ layer_id (int, optional): The index of the current layer. Used by the Mistral models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None.
+ """
+ super().__init__()
+ if isinstance(cfg, Dict):
+ cfg = HookedTransformerConfig.from_dict(cfg)
+ self.cfg = cfg
+
+ if self.cfg.load_in_4bit:
+ nq = int((cfg.d_model * cfg.d_model) / 2)
+ self.W_Q = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
+ self.W_O = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
+ else:
+ self.W_Q = nn.Parameter(
+ torch.empty(self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype)
+ )
+ self.W_O = nn.Parameter(
+ torch.empty(self.cfg.n_heads, self.cfg.d_head, self.cfg.d_model, dtype=cfg.dtype)
+ )
+ self.W_K = abstract_attribute()
+ self.W_V = abstract_attribute()
+
+ self.b_Q = nn.Parameter(torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype))
+ self.b_K: nn.Parameter = abstract_attribute()
+ self.b_V: nn.Parameter = abstract_attribute()
+ self.b_O = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=cfg.dtype))
+
+ self.attn_type = attn_type
+ # Create a max_ctx x max_ctx mask, with True iff that query position
+ # can attend to that key position (query is first axis, key is second axis)
+ causal_mask = torch.tril(torch.ones((self.cfg.n_ctx, self.cfg.n_ctx)).bool())
+ if self.attn_type == "global":
+ # For global attention, this is a lower triangular matrix - key <= query
+ self.register_buffer("mask", causal_mask)
+ elif self.attn_type == "local":
+ # For local, this is banded, query - window_size < key <= query
+ assert isinstance(self.cfg.window_size, int)
+ self.register_buffer("mask", torch.triu(causal_mask, 1 - self.cfg.window_size))
+ else:
+ raise ValueError(f"Invalid attention type: {self.attn_type}")
+
+ self.register_buffer("IGNORE", torch.tensor(-torch.inf))
+
+ self.layer_id = layer_id
+
+ # attn_scale is a constant that we divide the attention scores by pre-softmax. I'm not entirely sure why it matters, but it's probably a mix of softmax not being scale invariant and numerical stability?
+ if self.cfg.use_attn_scale:
+ self.attn_scale = np.sqrt(self.cfg.d_head)
+ else:
+ self.attn_scale = 1.0
+ if self.cfg.scale_attn_by_inverse_layer_idx:
+ assert self.layer_id is not None # keep mypy happy
+ self.attn_scale *= self.layer_id + 1
+
+ self.hook_k = HookPoint() # [batch, pos, head_index, d_head]
+ self.hook_q = HookPoint() # [batch, pos, head_index, d_head]
+ self.hook_v = HookPoint() # [batch, pos, head_index, d_head]
+ self.hook_z = HookPoint() # [batch, pos, head_index, d_head]
+ self.hook_attn_scores = HookPoint() # [batch, head_index, query_pos, key_pos]
+ self.hook_pattern = HookPoint() # [batch, head_index, query_pos, key_pos]
+ self.hook_result = HookPoint() # [batch, pos, head_index, d_model]
+
+ # See HookedTransformerConfig for more details.
+ if self.cfg.positional_embedding_type == "shortformer":
+ # This tracks the input to the keys and queries, which is resid_pre + pos_embeds
+ self.hook_attn_input = HookPoint() # [batch, pos, d_model]
+ elif self.cfg.positional_embedding_type == "rotary":
+ # Applies a rotation to each two-element chunk of keys and queries pre dot producting to bake in relative position. See HookedTransformerConfig for details
+ self.hook_rot_k = HookPoint()
+ self.hook_rot_q = HookPoint()
+ assert self.cfg.rotary_dim is not None # keep mypy happy
+ sin, cos = self.calculate_sin_cos_rotary(
+ self.cfg.rotary_dim,
+ self.cfg.n_ctx,
+ base=self.cfg.rotary_base,
+ dtype=self.cfg.dtype,
+ )
+ self.register_buffer("rotary_sin", sin)
+ self.register_buffer("rotary_cos", cos)
+ elif self.cfg.positional_embedding_type == "alibi":
+ # ALiBi bias will be constructed on the first forward pass.
+ # Note: While computationally efficient, initializing a bias with max n_ctx (16, 1024, 1024) of float32 will occupy ~256MiB of contiguous GPU memory, which may not be optimal for memory usage.
+ self.alibi = None
+
+ @property
+ def OV(self) -> FactoredMatrix:
+ """
+ OV-Circuit, as defined in A Mathematical Framework. Because there's no non-linearity between the value vector and the output of the layer, the output is purely determined by the matrix W_OV = W_V @ W_O, and not W_V or W_O individually. (Mathematically, for a single head, output == pattern @ residual @ W_V @ W_O, see the glossary for more)
+
+ Done in the order W_V, W_O because the paper uses left-multiplying weight matrices, and TransformerLens uses right-multiplying, sorry!
+
+ Returns a FactoredMatrix, with left matrix W_V [head_index, d_model, d_head] and right matrix W_O [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model]. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the OV circuit of a head k, attn.OV[k] works.
+ """
+ return FactoredMatrix(self.W_V, self.W_O)
+
+ @property
+ def QK(self) -> FactoredMatrix:
+ """
+ QK-Circuit, as defined in A Mathematical Framework. Because there's no non-linearity in the key-query dot product, the output is purely determined by the matrix W_QK = W_Q.T @ W_K, and not W_Q or W_K individually. (Mathematically, for a single head, pattern = destination_residual.T @ W_Q.T @ W_K @ source-residual, see the glossary for more).
+
+ Done in the order Q on the left, K on the right, because the pattern has dimensions [destination_pos, source_pos]
+
+ Returns a FactoredMatrix, with left matrix W_Q [head_index, d_model, d_head] and right matrix W_K.T [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model] matrix. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the QK circuit of a head k, attn.QK[k] works.
+ """
+ W_K_transpose = einops.rearrange(
+ self.W_K, "head_index d_model d_head -> head_index d_head d_model"
+ )
+ return FactoredMatrix(self.W_Q, W_K_transpose)
+
+ def forward(
+ self,
+ query_input: Union[
+ Float[torch.Tensor, "batch pos d_model"],
+ Float[torch.Tensor, "batch pos head_index d_model"],
+ ],
+ key_input: Union[
+ Float[torch.Tensor, "batch pos d_model"],
+ Float[torch.Tensor, "batch pos head_index d_model"],
+ Float[torch.Tensor, "batch pos kv_head_index d_model"],
+ ],
+ value_input: Union[
+ Float[torch.Tensor, "batch pos d_model"],
+ Float[torch.Tensor, "batch pos head_index d_model"],
+ Float[torch.Tensor, "batch pos kv_head_index d_model"],
+ ],
+ past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None,
+ additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 pos"]] = None,
+ attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
+ ) -> Float[torch.Tensor, "batch pos d_model"]:
+ """
+ shortformer_pos_embed is only used if self.cfg.positional_embedding_type == "shortformer", else defaults to None and is irrelevant. See HookedTransformerConfig for more details
+ past_kv_cache_entry is an optional entry of past keys and values for this layer, only relevant if generating text. Defaults to None
+ additive_attention_mask is an optional mask to add to the attention weights. Defaults to None.
+ attention_mask is the attention mask for padded tokens. Defaults to None.
+ """
+
+ q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input)
+
+ if past_kv_cache_entry is not None:
+ # Appends the new keys and values to the cached values, and automatically updates the cache
+ kv_cache_pos_offset = past_kv_cache_entry.past_keys.size(1)
+ k, v = past_kv_cache_entry.append(k, v)
+ else:
+ # Not using a cache
+ kv_cache_pos_offset = 0
+
+ if self.cfg.positional_embedding_type == "rotary":
+ q = self.hook_rot_q(self.apply_rotary(q, kv_cache_pos_offset, attention_mask))
+ k = self.hook_rot_k(
+ self.apply_rotary(k, 0, attention_mask)
+ ) # keys are cached so no offset
+
+ if self.cfg.dtype not in [torch.float32, torch.float64]:
+ # If using 16 bits, increase the precision to avoid numerical instabilities
+ q = q.to(torch.float32)
+ k = k.to(torch.float32)
+
+ attn_scores = self.calculate_attention_scores(
+ q, k
+ ) # [batch, head_index, query_pos, key_pos]
+
+ if self.cfg.positional_embedding_type == "alibi":
+ query_ctx = attn_scores.size(-2)
+ # The key context length is the number of positions in the past - this includes all positions in the cache
+ key_ctx = attn_scores.size(-1)
+
+ # only recompute when necessary to increase efficiency.
+ if self.alibi is None or key_ctx > self.alibi.size(-1):
+ self.alibi = AbstractAttention.create_alibi_bias(
+ self.cfg.n_heads, key_ctx, self.cfg.device
+ )
+
+ attn_scores += self.alibi[
+ :, :query_ctx, :key_ctx
+ ] # [batch, head_index, query_pos, key_pos]
+
+ if self.cfg.attention_dir == "causal":
+ # If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask.
+ attn_scores = self.apply_causal_mask(
+ attn_scores, kv_cache_pos_offset, attention_mask
+ ) # [batch, head_index, query_pos, key_pos]
+ if additive_attention_mask is not None:
+ attn_scores += additive_attention_mask
+
+ attn_scores = self.hook_attn_scores(attn_scores)
+ pattern = F.softmax(attn_scores, dim=-1)
+ pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern)
+ pattern = self.hook_pattern(pattern) # [batch, head_index, query_pos, key_pos]
+ pattern = pattern.to(self.cfg.dtype)
+ pattern = pattern.to(v.device)
+ z = self.calculate_z_scores(v, pattern) # [batch, pos, head_index, d_head]
+ if not self.cfg.use_attn_result:
+ if self.cfg.load_in_4bit:
+ # call bitsandbytes method to dequantize and multiply
+ out = bnb.matmul_4bit(
+ z.reshape(z.shape[0], z.shape[1], self.cfg.d_model),
+ self.W_O.t(),
+ # bias=self.W_O.t(),
+ bias=None,
+ quant_state=self.W_O.quant_state,
+ )
+ +self.b_O
+ else:
+ out = (
+ (
+ einsum(
+ "batch pos head_index d_head, \
+ head_index d_head d_model -> \
+ batch pos d_model",
+ z,
+ self.W_O,
+ )
+ )
+ + self.b_O
+ ) # [batch, pos, d_model]
+ else:
+ # Explicitly calculate the attention result so it can be accessed by a hook
+ # This is off by default because it can easily eat through your GPU memory.
+ if self.cfg.load_in_4bit:
+ result = self.hook_result(
+ bnb.matmul_4bit(
+ z.reshape(z.shape[0], z.shape[1], self.cfg.d_model),
+ self.W_O.t(),
+ bias=None,
+ quant_state=self.W_O.quant_state,
+ )
+ )
+ else:
+ result = self.hook_result(
+ einsum(
+ "batch pos head_index d_head, \
+ head_index d_head d_model -> \
+ batch pos head_index d_model",
+ z,
+ self.W_O,
+ )
+ ) # [batch, pos, head_index, d_model]
+ out = (
+ einops.reduce(result, "batch position index model->batch position model", "sum")
+ + self.b_O
+ ) # [batch, pos, d_model]
+ return out
+
+ def calculate_qkv_matrices(
+ self,
+ query_input: Union[
+ Float[torch.Tensor, "batch pos d_model"],
+ Float[torch.Tensor, "batch pos head_index d_model"],
+ ],
+ key_input: Union[
+ Float[torch.Tensor, "batch pos d_model"],
+ Float[torch.Tensor, "batch pos head_index d_model"],
+ ],
+ value_input: Union[
+ Float[torch.Tensor, "batch pos d_model"],
+ Float[torch.Tensor, "batch pos head_index d_model"],
+ ],
+ ) -> Tuple[
+ Float[torch.Tensor, "batch pos head_index d_head"],
+ Float[torch.Tensor, "batch pos head_index d_head"],
+ Float[torch.Tensor, "batch pos head_index d_head"],
+ ]:
+ if self.cfg.use_split_qkv_input or self.cfg.use_attn_in:
+ qkv_einops_string = "batch pos head_index d_model"
+ else:
+ qkv_einops_string = "batch pos d_model"
+
+ if self.cfg.load_in_4bit:
+ q = self.hook_q(
+ # call bitsandbytes method to dequantize and multiply
+ bnb.matmul_4bit(
+ query_input,
+ self.W_Q.t(),
+ bias=None,
+ quant_state=self.W_Q.quant_state,
+ ).reshape(
+ query_input.shape[0],
+ query_input.shape[1],
+ self.cfg.n_heads,
+ self.cfg.d_head,
+ )
+ + self.b_Q
+ )
+ else:
+ q = self.hook_q(
+ einsum(
+ f"{qkv_einops_string}, head_index d_model d_head \
+ -> batch pos head_index d_head",
+ query_input,
+ self.W_Q,
+ )
+ + self.b_Q
+ ) # [batch, pos, head_index, d_head]
+ if self.cfg.load_in_4bit:
+ assert isinstance(self.W_K, Params4bit)
+ k = self.hook_k(
+ # call bitsandbytes method to dequantize and multiply
+ bnb.matmul_4bit(
+ key_input, self.W_K.t(), bias=None, quant_state=self.W_K.quant_state
+ ).reshape(
+ key_input.shape[0],
+ key_input.shape[1],
+ self.cfg.n_heads,
+ self.cfg.d_head,
+ )
+ + self.b_K
+ )
+ else:
+ k = self.hook_k(
+ einsum(
+ f"{qkv_einops_string}, head_index d_model d_head \
+ -> batch pos head_index d_head",
+ key_input,
+ self.W_K,
+ )
+ + self.b_K
+ ) # [batch, pos, head_index, d_head]
+
+ if self.cfg.load_in_4bit:
+ assert isinstance(self.W_V, Params4bit)
+ v = self.hook_v(
+ # call bitsandbytes method to dequantize and multiply
+ bnb.matmul_4bit(
+ value_input,
+ self.W_V.t(),
+ bias=None,
+ quant_state=self.W_V.quant_state,
+ ).reshape(
+ value_input.shape[0],
+ value_input.shape[1],
+ self.cfg.n_heads,
+ self.cfg.d_head,
+ )
+ + self.b_V
+ )
+ else:
+ v = self.hook_v(
+ einsum(
+ f"{qkv_einops_string}, head_index d_model d_head \
+ -> batch pos head_index d_head",
+ value_input,
+ self.W_V,
+ )
+ + self.b_V
+ ) # [batch, pos, head_index, d_head]
+ return q, k, v
+
+ def calculate_attention_scores(
+ self,
+ q: Float[torch.Tensor, "batch query_pos head_index d_head"],
+ k: Float[torch.Tensor, "batch key_pos head_index d_head"],
+ ) -> Float[torch.Tensor, "batch head_index query_pos key_pos"]:
+ attn_scores = (
+ einsum(
+ "batch query_pos head_index d_head, \
+ batch key_pos head_index d_head \
+ -> batch head_index query_pos key_pos",
+ q,
+ k,
+ )
+ / self.attn_scale
+ )
+ return attn_scores
+
+ def calculate_z_scores(
+ self,
+ v: Float[torch.Tensor, "batch key_pos head_index d_head"],
+ pattern: Float[torch.Tensor, "batch head_index query_pos key_pos"],
+ ) -> Float[torch.Tensor, "batch query_pos head_index d_head"]:
+ z = self.hook_z(
+ einsum(
+ "batch key_pos head_index d_head, \
+ batch head_index query_pos key_pos -> \
+ batch query_pos head_index d_head",
+ v,
+ pattern,
+ )
+ )
+ return z
+
+ def apply_causal_mask(
+ self,
+ attn_scores: Float[torch.Tensor, "batch head_index pos pos_plus_past_kv_pos_offset"],
+ past_kv_pos_offset: int = 0,
+ attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
+ ):
+ # The query context length is the number of positions we take queries from - if not using a past_kv_cache this is just the context length (for the current prompt), but if we're caching it can be different.
+ query_ctx_length = attn_scores.size(-2)
+ # The key context length is the number of positions in the past - this includes all positions in the cache
+ # If not caching, query_ctx_length == key_ctx_length
+ key_ctx_length = attn_scores.size(-1)
+
+ assert (
+ query_ctx_length + past_kv_pos_offset == key_ctx_length
+ ), f"query_ctx_length {query_ctx_length} + past_kv_pos_offset {past_kv_pos_offset} != key_ctx_length {key_ctx_length} - you likely have a bug."
+
+ # Index back to front to ensure local attention works
+ final_mask = self.mask[None, None, -query_ctx_length:, -key_ctx_length:] # [1, 1, pos, pos]
+ if attention_mask is not None:
+ # Apply a causal mask to the attention scores considering the padding
+ einsum_str = "batch head pos offset_pos, batch offset_pos -> batch head pos offset_pos"
+ final_mask = final_mask.to(attention_mask.device)
+ final_mask = einops.einsum(final_mask, attention_mask, einsum_str).bool()
+
+ attn_scores = attn_scores.to(final_mask.device)
+ return torch.where(final_mask, attn_scores, self.IGNORE)
+
+ def calculate_sin_cos_rotary(
+ self,
+ rotary_dim: int,
+ n_ctx: int,
+ base: int = 10000,
+ dtype: torch.dtype = torch.float32,
+ ) -> Tuple[Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"]]:
+ """
+ Calculate the sine and cosine waves to use in a rotary embedding. See https://blog.eleuther.ai/rotary-embeddings/ for details
+
+ Note: For some inexplicable reason, in GPT-J each ADJACENT pair of elements in k and q are rotated, in GPT-NeoX the pair of elements at k and k+n//2 are rotated (ie folding the full length in half, and then looking at pairs accordingly). I have absolutely no clue why, it should be completely equivalent.
+ To resolve this, I've coded it to default to the GPT-J mode, but to explicitly check whether it's GPT-NeoX and then do the GPT-NeoX thing if it is.
+ """
+ high_precision = torch.float32 if dtype != torch.float64 else torch.float64
+ pos = torch.arange(n_ctx, dtype=high_precision)
+ dim = torch.arange(rotary_dim // 2, dtype=high_precision)
+
+ # A set of frequencies evenly spaced in log space
+ freq = base ** (dim / (rotary_dim / 2))
+ if self.cfg.rotary_adjacent_pairs:
+ freq = einops.repeat(freq, "d -> (d 2)")
+ else:
+ freq = einops.repeat(freq, "d -> (2 d)")
+ # Create a n_ctx x rotary_dim tensor, where each column is an arithmetic sequence of angles in that frequency
+ angles = pos[:, None] / freq[None, :]
+ return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype)
+
+ def rotate_every_two(
+ self, x: Float[torch.Tensor, "... rotary_dim"]
+ ) -> Float[torch.Tensor, "... rotary_dim"]:
+ """
+ Rotary helper function, splits x into blocks of size 2 along the final axis and maps [x0, x1] to [-x1, x0]
+
+ The final axis of x must have even length.
+
+ GPT-NeoX and GPT-J do rotary subtly differently, see calculate_sin_cos_rotary for details.
+ """
+ rot_x = x.clone()
+ if self.cfg.rotary_adjacent_pairs:
+ rot_x[..., ::2] = -x[..., 1::2]
+ rot_x[..., 1::2] = x[..., ::2]
+ else:
+ n = x.size(-1) // 2
+ rot_x[..., :n] = -x[..., n:]
+ rot_x[..., n:] = x[..., :n]
+
+ return rot_x
+
+ def apply_rotary(
+ self,
+ x: Float[torch.Tensor, "batch pos head_index d_head"],
+ past_kv_pos_offset=0,
+ attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
+ ) -> Float[torch.Tensor, "batch pos head_index d_head"]:
+ # Only apply rotary to first rotary_dim dimensions (eg, if rotary_dim=64 and d_head=256, only apply to first 1/4 of dimensions)
+ x_pos = x.size(1)
+ x_rot = x[..., : self.cfg.rotary_dim]
+ x_pass = x[..., self.cfg.rotary_dim :]
+ x_flip = self.rotate_every_two(x_rot)
+
+ if attention_mask is None:
+ rotary_cos = self.rotary_cos[
+ None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, :
+ ]
+ rotary_sin = self.rotary_sin[
+ None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, :
+ ]
+ x_rotated = x_rot * rotary_cos + x_flip * rotary_sin
+ else:
+ offset_position_ids = get_offset_position_ids(past_kv_pos_offset, attention_mask)
+ offset_position_ids = offset_position_ids.to(self.rotary_cos.device)
+ mask_rotary_cos = self.rotary_cos[offset_position_ids, None, :]
+ mask_rotary_sin = self.rotary_sin[offset_position_ids, None, :]
+ x_rotated = x_rot * mask_rotary_cos + x_flip * mask_rotary_sin
+
+ return torch.cat([x_rotated, x_pass], dim=-1)
+
+ @staticmethod
+ def create_alibi_slope(
+ n_ctx: int, device: Optional[Union[str, torch.device]] = None
+ ) -> Float[torch.Tensor, "query key"]:
+ """Create an ALiBi Slope Matrix.
+
+ Create the slope matrix used in ALiBi, before it is multiplied by the head-specific scalar.
+
+ See :meth:`create_alibi_bias` for the full ALiBi bias calculation.
+
+ Examples:
+
+ >>> AbstractAttention.create_alibi_slope(3)
+ tensor([[ 0., 0., 0.],
+ [-1., 0., 0.],
+ [-2., -1., 0.]])
+
+ >>> AbstractAttention.create_alibi_slope(4)
+ tensor([[ 0., 0., 0., 0.],
+ [-1., 0., 0., 0.],
+ [-2., -1., 0., 0.],
+ [-3., -2., -1., 0.]])
+
+ Args:
+ n_ctx: The maximum number of tokens in a prompt.
+
+ Returns:
+ A tensor of shape (n_ctx, n_ctx), where the upper triangle is zero and the lower
+ triangle is decreasing by a constant slope of 1 (towards the bottom left corner).
+ """
+ # set rows as [[0,1,2...]]
+ rows = torch.arange(n_ctx, device=device).unsqueeze(0)
+
+ # Set cols as [[0],[1],[2]...]
+ cols = torch.arange(n_ctx, device=device).unsqueeze(1)
+
+ # Use broadcasting to create the desired lower triangular part of the matrix
+ slope_matrix = rows - cols
+
+ # Use the clamp method to set all positive values (upper right triangle) to zero
+ return slope_matrix.clamp(max=0).to(torch.float32)
+
+ @staticmethod
+ def create_alibi_multipliers(
+ n_heads: int, device: Optional[Union[str, torch.device]] = None
+ ) -> Float[torch.Tensor, "head_idx"]:
+ """Create the ALiBi Scalar Multipliers for each Head.
+
+ For n heads, the set of multipliers (m) is the geometric sequence that starts at 2^(-8/n), and
+ uses that same value as its ratio. For example, with 8 heads the values would be [1/(2^1),
+ 1/(2^2), ... , 1/(2^8)]. With 16 heads the values would be [1/(2^0.5), 1/(2^1), ... , 1/(2^8)].
+
+ See :meth:`create_alibi_bias` for the full ALiBi bias calculation.
+
+ Examples:
+
+ >>> AbstractAttention.create_alibi_multipliers(8)
+ tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])
+
+ >>> AbstractAttention.create_alibi_multipliers(16)
+ tensor([0.7071, 0.5000, 0.3536, 0.2500, 0.1768, 0.1250, 0.0884, 0.0625, 0.0442, 0.0312,
+ 0.0221, 0.0156, 0.0110, 0.0078, 0.0055, 0.0039])
+
+ Args:
+ n_heads: The number of heads in a layer.
+ device: The device to create the tensor on.
+
+ Returns:
+ A tensor of shape (n_heads,) containing the scalar multiplier for each head.
+ """
+ # Calculate the starting value
+ start = 2 ** (-8 / n_heads)
+
+ # Generate the indices [0, 1, ..., n_heads-1]
+ indices = torch.arange(n_heads, device=device)
+
+ # Compute the multipliers, with the starting value being the same as the ratio
+ multipliers = start * (start**indices)
+
+ return multipliers
+
+ @staticmethod
+ def create_alibi_bias(
+ n_heads: int, n_ctx: int, device: Optional[Union[torch.device, str]] = None
+ ) -> Float[torch.Tensor, "head_idx query key"]:
+ """Create the ALiBi Bias for all Heads.
+
+ Calculate the ALiBi bias (https://arxiv.org/pdf/2108.12409.pdf) for all heads in a layer.
+
+ The broad idea behind ALiBi is to remove the positional encoding from the original transformer
+ model, and instead apply a bias to each attention score. This bias is proportional to the
+ distance between the query and key (i.e. it encourages paying less attention to more distant
+ tokens), and is added to the attention scores before the softmax. It is used in models such as
+ Bloom.
+
+ Examples:
+
+ >>> AbstractAttention.create_alibi_bias(2, 4, torch.device('cpu'))
+ tensor([[[ 0.0000, 0.0000, 0.0000, 0.0000],
+ [-0.0625, 0.0000, 0.0000, 0.0000],
+ [-0.1250, -0.0625, 0.0000, 0.0000],
+ [-0.1875, -0.1250, -0.0625, 0.0000]],
+ [[ 0.0000, 0.0000, 0.0000, 0.0000],
+ [-0.0039, 0.0000, 0.0000, 0.0000],
+ [-0.0078, -0.0039, 0.0000, 0.0000],
+ [-0.0117, -0.0078, -0.0039, 0.0000]]])
+
+ Args:
+ n_heads: The number of heads in a layer.
+ n_ctx: The maximum number of tokens in a prompt.
+ device: The device to create the tensor on.
+
+ Returns:
+ The ALiBi bias that should be added to the attention scores before the softmax.
+ """
+ # Create the slope matrix
+ slope: Float[torch.Tensor, "query key"] = AbstractAttention.create_alibi_slope(
+ n_ctx, device
+ )
+
+ # Create the scalar multiplier for each head.
+ multipliers: Float[torch.Tensor, "head_idx"] = AbstractAttention.create_alibi_multipliers(
+ n_heads, device
+ )
+
+ # The ALiBi bias is then m * slope_matrix
+ alibi_bias = torch.einsum("ij,k->kij", slope, multipliers)
+
+ return alibi_bias
diff --git a/transformer_lens/components/attention.py b/transformer_lens/components/attention.py
new file mode 100644
index 000000000..c463361c5
--- /dev/null
+++ b/transformer_lens/components/attention.py
@@ -0,0 +1,53 @@
+"""Hooked Transformer Attention Component.
+
+This module contains all the component :class:`Attention`.
+"""
+from typing import Dict, Optional, Union
+
+import torch
+import torch.nn as nn
+from transformers.utils import is_bitsandbytes_available
+
+from transformer_lens.components import AbstractAttention
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+
+if is_bitsandbytes_available():
+ from bitsandbytes.nn.modules import Params4bit
+
+
+# Attention
+class Attention(AbstractAttention):
+ def __init__(
+ self,
+ cfg: Union[Dict, HookedTransformerConfig],
+ attn_type: str = "global",
+ layer_id: Optional[int] = None,
+ ):
+ """Attention Block - params have shape [head_index, d_model, d_head] (or [head_index, d_head, d_model] for W_O) and multiply on the right. attn_scores refers to query key dot product immediately before attention softmax
+
+ Convention: All attention pattern-style matrices have shape [batch, head_index, query_pos, key_pos]
+
+ Args:
+ cfg (Union[Dict, HookedTransformerConfig]): Config
+ attn_type (str, optional): "global" or "local", used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens (here, 256). Not used by any other model at the moment. Defaults to "global".
+ layer_id (int, optional): The index of the current layer. Used by the Mistral models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None.
+ """
+ super().__init__(cfg, attn_type, layer_id)
+ if isinstance(cfg, Dict):
+ cfg = HookedTransformerConfig.from_dict(cfg)
+ self.cfg = cfg
+
+ if cfg.load_in_4bit:
+ # 4-bit quantization convention
+ nq = int((cfg.d_model * cfg.d_model) / 2)
+ self.W_K = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
+ self.W_V = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
+ else:
+ self.W_K = nn.Parameter(
+ torch.empty(self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype)
+ )
+ self.W_V = nn.Parameter(
+ torch.empty(self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype)
+ )
+ self.b_K = nn.Parameter(torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype))
+ self.b_V = nn.Parameter(torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype))
diff --git a/transformer_lens/components/bert_block.py b/transformer_lens/components/bert_block.py
new file mode 100644
index 000000000..3740d914b
--- /dev/null
+++ b/transformer_lens/components/bert_block.py
@@ -0,0 +1,76 @@
+"""Hooked Transformer Bert Block Component.
+
+This module contains all the component :class:`BertBlock`.
+"""
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from jaxtyping import Float
+
+from transformer_lens.components import MLP, Attention, LayerNorm
+from transformer_lens.hook_points import HookPoint
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+from transformer_lens.utils import repeat_along_head_dimension
+
+
+class BertBlock(nn.Module):
+ """
+ BERT Block. Similar to the TransformerBlock, except that the LayerNorms are applied after the attention and MLP, rather than before.
+ """
+
+ def __init__(self, cfg: HookedTransformerConfig):
+ super().__init__()
+ self.cfg = cfg
+
+ self.attn = Attention(cfg)
+ self.ln1 = LayerNorm(cfg)
+ self.mlp = MLP(cfg)
+ self.ln2 = LayerNorm(cfg)
+
+ self.hook_q_input = HookPoint() # [batch, pos, n_heads, d_model]
+ self.hook_k_input = HookPoint() # [batch, pos, n_heads, d_model]
+ self.hook_v_input = HookPoint() # [batch, pos, n_heads, d_model]
+
+ self.hook_attn_out = HookPoint() # [batch, pos, d_model]
+ self.hook_mlp_in = HookPoint() # [batch, pos, d_model]
+ self.hook_mlp_out = HookPoint() # [batch, pos, d_model]
+ self.hook_resid_pre = HookPoint() # [batch, pos, d_model]
+ self.hook_resid_mid = HookPoint() # [batch, pos, d_model]
+ self.hook_resid_post = HookPoint() # [batch, pos, d_model]
+ self.hook_normalized_resid_post = HookPoint() # [batch, pos, d_model]
+
+ def forward(
+ self,
+ resid_pre: Float[torch.Tensor, "batch pos d_model"],
+ additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 pos"]] = None,
+ ):
+ resid_pre = self.hook_resid_pre(resid_pre)
+
+ query_input = resid_pre
+ key_input = resid_pre
+ value_input = resid_pre
+
+ if self.cfg.use_split_qkv_input:
+ n_heads = self.cfg.n_heads
+ query_input = self.hook_q_input(repeat_along_head_dimension(query_input, n_heads))
+ key_input = self.hook_k_input(repeat_along_head_dimension(key_input, n_heads))
+ value_input = self.hook_v_input(repeat_along_head_dimension(value_input, n_heads))
+
+ attn_out = self.hook_attn_out(
+ self.attn(
+ query_input,
+ key_input,
+ value_input,
+ additive_attention_mask=additive_attention_mask,
+ )
+ )
+ resid_mid = self.hook_resid_mid(resid_pre + attn_out)
+
+ mlp_in = resid_mid if not self.cfg.use_hook_mlp_in else self.hook_mlp_in(resid_mid.clone())
+ normalized_resid_mid = self.ln1(mlp_in)
+ mlp_out = self.hook_mlp_out(self.mlp(normalized_resid_mid))
+ resid_post = self.hook_resid_post(normalized_resid_mid + mlp_out)
+ normalized_resid_post = self.hook_normalized_resid_post(self.ln2(resid_post))
+
+ return normalized_resid_post
diff --git a/transformer_lens/components/bert_embed.py b/transformer_lens/components/bert_embed.py
new file mode 100644
index 000000000..9495b9798
--- /dev/null
+++ b/transformer_lens/components/bert_embed.py
@@ -0,0 +1,54 @@
+"""Hooked Transformer Bert Embed Component.
+
+This module contains all the component :class:`BertEmbed`.
+"""
+from typing import Dict, Optional, Union
+
+import einops
+import torch
+import torch.nn as nn
+from jaxtyping import Int
+
+from transformer_lens.components import Embed, LayerNorm, PosEmbed, TokenTypeEmbed
+from transformer_lens.hook_points import HookPoint
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+
+
+class BertEmbed(nn.Module):
+ """
+ Custom embedding layer for a BERT-like model. This module computes the sum of the token, positional and token-type embeddings and takes the layer norm of the result.
+ """
+
+ def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
+ super().__init__()
+ if isinstance(cfg, Dict):
+ cfg = HookedTransformerConfig.from_dict(cfg)
+ self.cfg = cfg
+ self.embed = Embed(cfg)
+ self.pos_embed = PosEmbed(cfg)
+ self.token_type_embed = TokenTypeEmbed(cfg)
+ self.ln = LayerNorm(cfg)
+
+ self.hook_embed = HookPoint()
+ self.hook_pos_embed = HookPoint()
+ self.hook_token_type_embed = HookPoint()
+
+ def forward(
+ self,
+ input_ids: Int[torch.Tensor, "batch pos"],
+ token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
+ ):
+ base_index_id = torch.arange(input_ids.shape[1], device=input_ids.device)
+ index_ids = einops.repeat(base_index_id, "pos -> batch pos", batch=input_ids.shape[0])
+ if token_type_ids is None:
+ token_type_ids = torch.zeros_like(input_ids)
+
+ word_embeddings_out = self.hook_embed(self.embed(input_ids))
+ position_embeddings_out = self.hook_pos_embed(self.pos_embed(index_ids))
+ token_type_embeddings_out = self.hook_token_type_embed(
+ self.token_type_embed(token_type_ids)
+ )
+
+ embeddings_out = word_embeddings_out + position_embeddings_out + token_type_embeddings_out
+ layer_norm_out = self.ln(embeddings_out)
+ return layer_norm_out
diff --git a/transformer_lens/components/bert_mlm_head.py b/transformer_lens/components/bert_mlm_head.py
new file mode 100644
index 000000000..878abec45
--- /dev/null
+++ b/transformer_lens/components/bert_mlm_head.py
@@ -0,0 +1,42 @@
+"""Hooked Transformer Bert MLM Head Component.
+
+This module contains all the component :class:`BertMLMHead`.
+"""
+from typing import Dict, Union
+
+import torch
+import torch.nn as nn
+from fancy_einsum import einsum
+from jaxtyping import Float
+
+from transformer_lens.components import LayerNorm
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+
+
class BertMLMHead(nn.Module):
    """
    Transforms BERT embeddings into logits. The purpose of this module is to predict
    masked tokens in a sentence: dense projection -> GELU -> LayerNorm.
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__()
        if isinstance(cfg, Dict):
            cfg = HookedTransformerConfig.from_dict(cfg)
        self.cfg = cfg
        # W is stored as [d_model_out, d_model_in]; initialized elsewhere (torch.empty).
        self.W = nn.Parameter(torch.empty(cfg.d_model, cfg.d_model, dtype=cfg.dtype))
        self.b = nn.Parameter(torch.zeros(cfg.d_model, dtype=cfg.dtype))
        # GELU is used unconditionally here (not taken from cfg.act_fn).
        self.act_fn = nn.GELU()
        self.ln = LayerNorm(cfg)

    def forward(self, resid: Float[torch.Tensor, "batch pos d_model"]) -> torch.Tensor:
        """Apply the dense + activation + LayerNorm transform to the residual stream.

        Args:
            resid: Residual stream, [batch, pos, d_model].

        Returns:
            Transformed activations, [batch, pos, d_model].
        """
        # Dense projection: out[b, p, o] = sum_i resid[b, p, i] * W[o, i] + b[o].
        # torch.einsum replaces the fancy_einsum dependency; the contraction is identical.
        resid = torch.einsum("bpi,oi->bpo", resid, self.W) + self.b
        resid = self.act_fn(resid)
        resid = self.ln(resid)
        return resid
diff --git a/transformer_lens/components/embed.py b/transformer_lens/components/embed.py
new file mode 100644
index 000000000..da4c4fff1
--- /dev/null
+++ b/transformer_lens/components/embed.py
@@ -0,0 +1,36 @@
+"""Hooked Transformer Embed Component.
+
+This module contains the :class:`Embed` component.
+"""
+from typing import Dict, Union
+
+import torch
+import torch.nn as nn
+from jaxtyping import Float, Int
+
+from transformer_lens.components import LayerNorm
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+
+
+# Embed & Unembed
+class Embed(nn.Module):
+ def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
+ super().__init__()
+ if isinstance(cfg, Dict):
+ cfg = HookedTransformerConfig.from_dict(cfg)
+ self.cfg = cfg
+ self.W_E: Float[torch.Tensor, "d_vocab d_model"] = nn.Parameter(
+ torch.empty(self.cfg.d_vocab, self.cfg.d_model, dtype=cfg.dtype)
+ )
+ # Some models (e.g. Bloom) need post embedding layer norm
+ if cfg.post_embedding_ln:
+ self.ln = LayerNorm(cfg)
+
+ def forward(
+ self, tokens: Int[torch.Tensor, "batch pos"]
+ ) -> Float[torch.Tensor, "batch pos d_model"]:
+ # If A has shape [a, b] and B has shape [c, d], then A[:, B] has shape [a, c, d]
+ # B acts as a tensor of indices into the second dimension (so >=0 and Float[torch.Tensor, "batch pos d_model"]:
+ # Technically, all these einsums could be done with a single matmul, but this is more readable.
+ if self.cfg.load_in_4bit:
+ pre_act = self.hook_pre(
+ bnb.matmul_4bit(x, self.W_gate.t(), bias=None, quant_state=self.W_gate.quant_state)
+ )
+ else:
+ pre_act = self.hook_pre(
+ einsum(
+ "batch pos d_model, d_model d_mlp -> batch pos d_mlp",
+ x,
+ self.W_gate,
+ )
+ ) # [batch, pos, d_mlp]
+
+ if self.cfg.act_fn is not None and not self.cfg.act_fn.endswith("_ln"):
+ if self.cfg.load_in_4bit:
+ pre_linear = self.hook_pre_linear(
+ bnb.matmul_4bit(x, self.W_in.t(), bias=None, quant_state=self.W_in.quant_state)
+ )
+ else:
+ pre_linear = self.hook_pre_linear(
+ einsum(
+ "batch pos d_model, d_model d_mlp -> batch pos d_mlp",
+ x,
+ self.W_in,
+ )
+ )
+
+ post_act = self.hook_post(
+ (self.act_fn(pre_act) * pre_linear) + self.b_in
+ ) # [batch, pos, d_mlp]
+ else:
+ mid_act = self.hook_mid(self.act_fn(pre_act)) # [batch, pos, d_mlp]
+ post_act = self.hook_post(self.ln(mid_act))
+
+ if self.cfg.load_in_4bit:
+ return bnb.matmul_4bit(
+ post_act, self.W_out.t(), bias=None, quant_state=self.W_out.quant_state
+ )
+ else:
+ return (
+ einsum(
+ "batch pos d_mlp, d_mlp d_model -> batch pos d_model",
+ post_act,
+ self.W_out,
+ )
+ + self.b_out
+ )
diff --git a/transformer_lens/components/grouped_query_attention.py b/transformer_lens/components/grouped_query_attention.py
new file mode 100644
index 000000000..6c94e00a2
--- /dev/null
+++ b/transformer_lens/components/grouped_query_attention.py
@@ -0,0 +1,190 @@
+from typing import Dict, Tuple, Union
+
+import torch
+import torch.nn as nn
+from fancy_einsum import einsum
+from jaxtyping import Float
+
+from transformer_lens.components import AbstractAttention
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+
+
class GroupedQueryAttention(AbstractAttention):
    def __init__(
        self,
        cfg: Union[Dict, HookedTransformerConfig],
        attn_type: str = "global",
        layer_id: Union[int, None] = None,
    ):
        """Grouped Query Attention Block - see https://arxiv.org/abs/2305.13245 for details.
        Similar to regular attention, W_Q, W_K, and W_V all have shape [head_index, d_model, d_head] and W_O has shape [head_index, d_head, d_model].
        However, under the hood the key and value weights _W_K and _W_V are stored with shape [n_key_value_heads, d_model, d_head] and are expanded when the corresponding properties' getter is called.
        Similarly, during a forward pass, initially K and V are kept in shapes [batch, pos, n_key_value_heads, d_head] and will only be expanded to shapes [batch, pos, n_heads, d_head]
        using torch.repeat_interleave when the attention pattern and z-scores are calculated.

        Args:
            cfg (Union[Dict, HookedTransformerConfig]): Config
            attn_type (str, optional): "global" or "local", used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens (here, 256). Not used by any other model at the moment. Defaults to "global".
            layer_id (int, optional): The index of the current layer. Used by the Mistral models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None.
        """
        if isinstance(cfg, Dict):
            cfg = HookedTransformerConfig.from_dict(cfg)
        assert cfg.n_key_value_heads is not None
        super().__init__(cfg, attn_type, layer_id)
        # Number of query heads that share each key/value head.
        self.repeat_kv_heads = cfg.n_heads // cfg.n_key_value_heads
        # Unexpanded K/V parameters: one slice per key-value head (not per query head).
        self._W_K = nn.Parameter(
            torch.empty(
                cfg.n_key_value_heads,
                self.cfg.d_model,
                self.cfg.d_head,
                dtype=cfg.dtype,
            )
        )
        self._W_V = nn.Parameter(
            torch.empty(
                cfg.n_key_value_heads,
                self.cfg.d_model,
                self.cfg.d_head,
                dtype=cfg.dtype,
            )
        )
        self._b_K = nn.Parameter(
            torch.zeros(cfg.n_key_value_heads, self.cfg.d_head, dtype=cfg.dtype)
        )
        self._b_V = nn.Parameter(
            torch.zeros(cfg.n_key_value_heads, self.cfg.d_head, dtype=cfg.dtype)
        )

    # The properties below expose the expanded ([n_heads, ...]) view of the underlying
    # [n_key_value_heads, ...] parameters, so external code can treat this module like
    # regular multi-head attention.
    @property
    def W_K(self):
        # Expanded view: [n_heads, d_model, d_head].
        return torch.repeat_interleave(self._W_K, dim=0, repeats=self.repeat_kv_heads)

    @W_K.setter
    def W_K(self, value):
        self._W_K = value

    @property
    def W_V(self):
        # Expanded view: [n_heads, d_model, d_head].
        return torch.repeat_interleave(self._W_V, dim=0, repeats=self.repeat_kv_heads)

    @W_V.setter
    def W_V(self, value):
        self._W_V = value

    @property
    def b_K(self):
        # Expanded view: [n_heads, d_head].
        return torch.repeat_interleave(self._b_K, dim=0, repeats=self.repeat_kv_heads)

    @b_K.setter
    def b_K(self, value):
        self._b_K = value

    @property
    def b_V(self):
        # Expanded view: [n_heads, d_head].
        return torch.repeat_interleave(self._b_V, dim=0, repeats=self.repeat_kv_heads)

    @b_V.setter
    def b_V(self, value):
        self._b_V = value

    def calculate_qkv_matrices(
        self,
        query_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
        key_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos kv_head_index d_model"],
        ],
        value_input: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos kv_head_index d_model"],
        ],
    ) -> Tuple[
        Float[torch.Tensor, "batch pos head_index d_head"],
        Float[torch.Tensor, "batch pos kv_head_index d_head"],
        Float[torch.Tensor, "batch pos kv_head_index d_head"],
    ]:
        """Calculate the Q, K, and V matrices for grouped query attention.
        This function uses the unexpanded weights _W_K and _W_V to calculate K and V.

        Args:
            query_input (Union[Float[torch.Tensor, "batch pos d_model"], Float[torch.Tensor, "batch pos head_index d_model"]]): The input tensor for the query projection.
            key_input (Union[Float[torch.Tensor, "batch pos d_model"], Float[torch.Tensor, "batch pos kv_head_index d_model"]]): The input tensor for the key projection. Note that it has as many head dimensions as the GQA block has key-value heads.
            value_input (Union[Float[torch.Tensor, "batch pos d_model"], Float[torch.Tensor, "batch pos kv_head_index d_model"]]): The input tensor for the value projection. Note that it has as many head dimensions as the GQA block has key-value heads.

        Returns:
            Tuple[Float[torch.Tensor, "batch pos head_index d_head"], Float[torch.Tensor, "batch pos kv_head_index d_head"], Float[torch.Tensor, "batch pos kv_head_index d_head"]]:
            A tuple containing the Q, K, and V matrices with the specified shapes.
        """
        # With split qkv inputs, each head (or key-value head) gets its own copy of the
        # residual stream, so the einsum string needs an explicit head dimension.
        if self.cfg.use_split_qkv_input or self.cfg.use_attn_in:
            kv_einops_string = "batch pos kv_head_index d_model"
            q_einops_string = "batch pos head_index d_model"
        else:
            kv_einops_string = q_einops_string = "batch pos d_model"

        # Q uses the full (expanded) W_Q; K and V use the unexpanded per-kv-head weights.
        q = self.hook_q(
            einsum(
                f"{q_einops_string}, head_index d_model d_head \
                -> batch pos head_index d_head",
                query_input,
                self.W_Q,
            )
            + self.b_Q
        )  # [batch, pos, head_index, d_head]
        k = self.hook_k(
            einsum(
                f"{kv_einops_string}, kv_head_index d_model d_head \
                -> batch pos kv_head_index d_head",
                key_input,
                self._W_K,
            )
            + self._b_K
        )  # [batch, pos, kv_head_index, d_head]
        v = self.hook_v(
            einsum(
                f"{kv_einops_string}, kv_head_index d_model d_head \
                -> batch pos kv_head_index d_head",
                value_input,
                self._W_V,
            )
            + self._b_V
        )  # [batch, pos, kv_head_index, d_head]
        return q, k, v

    def calculate_attention_scores(
        self,
        q: Float[torch.Tensor, "batch query_pos head_index d_head"],
        k: Float[torch.Tensor, "batch key_pos kv_head_index d_head"],
    ) -> Float[torch.Tensor, "batch head_index query_pos key_pos"]:
        """Calculate attention scores from Q and the unexpanded K matrix.
        K will be expanded from [batch, pos, n_key_value_head, d_head] to [batch, pos, n_query_heads, d_head] using torch.repeat_interleave.

        Args:
            q (Float[torch.Tensor, "batch query_pos head_index d_head"]): The Q tensor.
            k (Float[torch.Tensor, "batch key_pos kv_head_index d_head"]): The K tensor.

        Returns:
            Float[torch.Tensor, "batch head_index query_pos key_pos"]: The attention scores.
        """
        k = torch.repeat_interleave(k, dim=2, repeats=self.repeat_kv_heads)
        return super().calculate_attention_scores(q, k)

    def calculate_z_scores(
        self,
        v: Float[torch.Tensor, "batch key_pos kv_head_index d_head"],
        pattern: Float[torch.Tensor, "batch head_index query_pos key_pos"],
    ) -> Float[torch.Tensor, "batch query_pos head_index d_head"]:
        """Calculate z scores from the attention pattern and the unexpanded V matrix.
        V will be expanded from [batch, pos, n_key_value_head, d_head] to [batch, pos, n_query_heads, d_head] using torch.repeat_interleave.

        Args:
            v (Float[torch.Tensor, "batch key_pos kv_head_index d_head"]): The V tensor.
            pattern (Float[torch.Tensor, "batch head_index query_pos key_pos"]): The attention pattern.

        Returns:
            Float[torch.Tensor, "batch query_pos head_index d_head"]: The z scores.
        """
        v = torch.repeat_interleave(v, dim=2, repeats=self.repeat_kv_heads)
        return super().calculate_z_scores(v, pattern)
diff --git a/transformer_lens/components/layer_norm.py b/transformer_lens/components/layer_norm.py
new file mode 100644
index 000000000..855f2a688
--- /dev/null
+++ b/transformer_lens/components/layer_norm.py
@@ -0,0 +1,58 @@
+"""Hooked Transformer Layer Norm Component.
+
+This module contains all the component :class:`LayerNorm`.
+"""
+from typing import Dict, Optional, Union
+
+import torch
+import torch.nn as nn
+from jaxtyping import Float
+
+from transformer_lens.hook_points import HookPoint
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+
+
class LayerNorm(nn.Module):
    def __init__(self, cfg: Union[Dict, HookedTransformerConfig], length: Optional[int] = None):
        """
        Standard LayerNorm with learnable scale ``w`` and bias ``b``.

        Args:
            cfg: Model config (dict or HookedTransformerConfig).
            length: Dimension to normalize over. Defaults to ``cfg.d_model``.
        """
        super().__init__()
        if isinstance(cfg, Dict):
            cfg = HookedTransformerConfig.from_dict(cfg)
        self.cfg = cfg
        self.eps = self.cfg.eps
        self.length = self.cfg.d_model if length is None else length

        self.w = nn.Parameter(torch.ones(self.length, dtype=cfg.dtype))
        self.b = nn.Parameter(torch.zeros(self.length, dtype=cfg.dtype))

        # Hook on the normalisation scale factor.
        self.hook_scale = HookPoint()  # [batch, pos, 1]
        # Hook on the LayerNorm output (after scale-and-shift).
        self.hook_normalized = HookPoint()  # [batch, pos, length]

    def forward(
        self,
        x: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
    ) -> Union[
        Float[torch.Tensor, "batch pos d_model"],
        Float[torch.Tensor, "batch pos head_index d_model"],
    ]:
        # Low-precision inputs are normalized in float32 for numerical stability,
        # then cast back to cfg.dtype at the end.
        if self.cfg.dtype not in [torch.float32, torch.float64]:
            x = x.to(torch.float32)

        centered = x - x.mean(-1, keepdim=True)  # [batch, pos, length]
        scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(
            (centered.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
        )
        normalized = centered / scale  # [batch, pos, length]
        return self.hook_normalized(normalized * self.w + self.b).to(self.cfg.dtype)
diff --git a/transformer_lens/components/layer_norm_pre.py b/transformer_lens/components/layer_norm_pre.py
new file mode 100644
index 000000000..36ea5af8b
--- /dev/null
+++ b/transformer_lens/components/layer_norm_pre.py
@@ -0,0 +1,55 @@
+"""Hooked Transformer Layer Norm Pre Component.
+
+This module contains all the component :class:`LayerNormPre`.
+"""
+from typing import Dict, Union
+
+import torch
+import torch.nn as nn
+from jaxtyping import Float
+
+from transformer_lens.hook_points import HookPoint
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+
+
+# LayerNormPre
+# I fold the LayerNorm weights and biases into later weights and biases.
+# This is just the 'center and normalise' part of LayerNorm
+# Centering is equivalent to just deleting one direction of residual space,
+# and is equivalent to centering the weight matrices of everything writing to the residual stream
+# Normalising is a funkier non-linear operation, that projects the residual stream onto the unit hypersphere
class LayerNormPre(nn.Module):
    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        """LayerNormPre - the 'center and normalise' part of LayerNorm. Length is
        normally d_model, but is d_mlp for softmax. Not needed as a parameter. This
        should only be used in inference mode after folding in LayerNorm weights"""
        super().__init__()
        if isinstance(cfg, Dict):
            cfg = HookedTransformerConfig.from_dict(cfg)
        self.cfg = cfg
        self.eps = self.cfg.eps

        # Hook on the normalisation scale factor.
        self.hook_scale = HookPoint()  # [batch, pos]
        # Hook on the output — a vector with std 1 and mean 0.
        self.hook_normalized = HookPoint()  # [batch, pos, length]

    def forward(
        self,
        x: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
    ) -> Union[
        Float[torch.Tensor, "batch pos d_model"],
        Float[torch.Tensor, "batch pos head_index d_model"],
    ]:
        # Normalize in float32 for low-precision inputs; cast back at the end.
        if self.cfg.dtype not in [torch.float32, torch.float64]:
            x = x.to(torch.float32)

        centered = x - x.mean(-1, keepdim=True)  # [batch, pos, length]
        scale: Union[
            Float[torch.Tensor, "batch pos 1"],
            Float[torch.Tensor, "batch pos head_index 1"],
        ] = self.hook_scale((centered.pow(2).mean(-1, keepdim=True) + self.eps).sqrt())
        return self.hook_normalized(centered / scale).to(self.cfg.dtype)
diff --git a/transformer_lens/components/mlp.py b/transformer_lens/components/mlp.py
new file mode 100644
index 000000000..c9a18fc3f
--- /dev/null
+++ b/transformer_lens/components/mlp.py
@@ -0,0 +1,79 @@
+"""Hooked Transformer MLP Component.
+
+This module contains all the component :class:`MLP`.
+"""
+from typing import Callable, Dict, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from fancy_einsum import einsum
+from jaxtyping import Float
+
+from transformer_lens.components import LayerNorm, LayerNormPre
+from transformer_lens.hook_points import HookPoint
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+from transformer_lens.utils import gelu_fast, gelu_new, solu
+
+
+# MLP Layers
class MLP(nn.Module):
    """Standard two-layer transformer MLP: W_in -> activation (-> LN for *_ln) -> W_out."""

    act_fn: Callable[..., torch.Tensor]
    ln: nn.Module

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__()
        if isinstance(cfg, Dict):
            cfg = HookedTransformerConfig.from_dict(cfg)
        self.cfg = cfg
        assert self.cfg.d_mlp is not None  # TODO: should this not be optional?
        self.W_in = nn.Parameter(torch.empty(self.cfg.d_model, self.cfg.d_mlp, dtype=cfg.dtype))
        self.b_in = nn.Parameter(torch.zeros(self.cfg.d_mlp, dtype=cfg.dtype))
        self.W_out = nn.Parameter(torch.empty(self.cfg.d_mlp, self.cfg.d_model, dtype=cfg.dtype))
        self.b_out = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=cfg.dtype))

        self.hook_pre = HookPoint()  # [batch, pos, d_mlp]
        self.hook_post = HookPoint()  # [batch, pos, d_mlp]

        # Dispatch the activation function from the config name.
        if self.cfg.act_fn == "relu":
            self.act_fn = F.relu
        elif self.cfg.act_fn == "gelu":
            self.act_fn = F.gelu
        elif self.cfg.act_fn == "silu":
            self.act_fn = F.silu
        elif self.cfg.act_fn == "gelu_new":
            self.act_fn = gelu_new
        elif self.cfg.act_fn == "gelu_fast":
            self.act_fn = gelu_fast
        elif self.cfg.act_fn == "solu_ln":
            self.act_fn = solu
            # Hook taken between activation and layer norm
            self.hook_mid = HookPoint()  # [batch, pos, d_mlp]
            if self.cfg.normalization_type == "LN":
                self.ln = LayerNorm(self.cfg, self.cfg.d_mlp)
            else:
                self.ln = LayerNormPre(self.cfg)

        else:
            raise ValueError(f"Invalid activation function name: {self.cfg.act_fn}")

    def forward(
        self, x: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """Project in, apply the activation (plus LN for *_ln activations), project out.

        Args:
            x: Residual stream input, [batch, pos, d_model].

        Returns:
            MLP output, [batch, pos, d_model].
        """
        # torch.einsum replaces the fancy_einsum dependency; contractions are identical.
        # pre_act = x @ W_in + b_in  -> [batch, pos, d_mlp]
        pre_act = self.hook_pre(torch.einsum("bpd,dm->bpm", x, self.W_in) + self.b_in)
        if self.cfg.act_fn is not None and not self.cfg.act_fn.endswith("_ln"):
            post_act = self.hook_post(self.act_fn(pre_act))  # [batch, pos, d_mlp]
        else:
            # SoLU-style activations apply a LayerNorm between activation and W_out.
            mid_act = self.hook_mid(self.act_fn(pre_act))  # [batch, pos, d_mlp]
            post_act = self.hook_post(self.ln(mid_act))
        # out = post_act @ W_out + b_out  -> [batch, pos, d_model]
        return torch.einsum("bpm,md->bpd", post_act, self.W_out) + self.b_out
diff --git a/transformer_lens/components/moe.py b/transformer_lens/components/moe.py
new file mode 100644
index 000000000..01f0298c7
--- /dev/null
+++ b/transformer_lens/components/moe.py
@@ -0,0 +1,62 @@
+from typing import Dict, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from fancy_einsum import einsum
+from jaxtyping import Float
+
+from transformer_lens.components import MLP, GatedMLP
+from transformer_lens.hook_points import HookPoint
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+
+
class MoE(nn.Module):
    """Mixture-of-Experts layer: routes each token to its top-k experts and returns the
    gate-weighted sum of their outputs."""

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__()
        if isinstance(cfg, Dict):
            cfg = HookedTransformerConfig.from_dict(cfg)
        self.cfg = cfg

        # Ensure that num_experts and experts_per_token are specified and non-zero
        assert cfg.num_experts is not None, "num_experts must be specified for MoE layer"
        assert cfg.experts_per_token, "experts_per_token must be specified for MoE layer"
        self.experts_per_token: int = cfg.experts_per_token
        assert (
            cfg.experts_per_token <= cfg.num_experts
        ), "experts_per_token must be less than or equal to num_experts"

        self.experts = nn.ModuleList(
            [GatedMLP(cfg) if cfg.gated_mlp else MLP(cfg) for _ in range(cfg.num_experts)]
        )
        self.W_gate = nn.Parameter(torch.empty(cfg.d_model, cfg.num_experts, dtype=cfg.dtype))

        # Hook on the weights of selected experts [batch pos experts_per_token]
        self.hook_expert_weights = HookPoint()
        # Hook on the indices of selected experts [batch pos experts_per_token]
        self.hook_expert_indices = HookPoint()

    def forward(
        self, x: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """Route each token through its top-k experts and combine with softmax gate weights.

        Args:
            x: Residual stream input, [batch, pos, d_model].

        Returns:
            Gate-weighted mixture of expert outputs, [batch, pos, d_model].
        """
        # Router logits: [batch, pos, num_experts].
        gate_logits = torch.einsum("bpd,de->bpe", x, self.W_gate)

        # Choose the top k(=experts_per_token) experts per token.
        # Both tensors are [batch, pos, experts_per_token].
        weights, expert_indices = torch.topk(gate_logits, self.experts_per_token)
        weights = self.hook_expert_weights(F.softmax(weights, dim=-1))
        expert_indices = self.hook_expert_indices(expert_indices)

        results = torch.zeros_like(x)
        for i, expert_mlp in enumerate(self.experts):
            # (batch, pos) coordinates routed to this expert, plus the top-k slot each
            # selection occupies (needed to look up its gate weight).
            batch, pos, expert = torch.where(expert_indices == i)
            if batch.numel() == 0:
                continue
            # Run the expert on the selected sequences, then keep only the routed
            # position of each selection.
            expert_out = expert_mlp(x[batch])  # [n_selected, pos, d_model]
            selected = expert_out[torch.arange(batch.shape[0], device=x.device), pos]
            contribution = weights[batch, pos, expert].unsqueeze(-1) * selected
            # BUG FIX: the original `results[batch] += ...` (a) dropped contributions for
            # duplicate batch indices (advanced-indexing assignment keeps only the last
            # write) and (b) added the expert's output at *all* positions scaled by a
            # single position's weight. index_put_ with accumulate=True adds each
            # token's weighted expert output exactly once, at its own position.
            results.index_put_((batch, pos), contribution, accumulate=True)

        return results
diff --git a/transformer_lens/components/pos_embed.py b/transformer_lens/components/pos_embed.py
new file mode 100644
index 000000000..000b0d5c1
--- /dev/null
+++ b/transformer_lens/components/pos_embed.py
@@ -0,0 +1,69 @@
+"""Hooked Transformer POS Embed Component.
+
+This module contains all the component :class:`PosEmbed`.
+"""
+from typing import Dict, Optional, Union
+
+import einops
+import torch
+import torch.nn as nn
+from jaxtyping import Float, Int
+
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+from transformer_lens.utils import get_offset_position_ids
+
+
+# Positional Embeddings
class PosEmbed(nn.Module):
    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__()
        if isinstance(cfg, Dict):
            cfg = HookedTransformerConfig.from_dict(cfg)
        self.cfg = cfg
        # Learned absolute positional embedding table: one row per context position.
        self.W_pos = nn.Parameter(torch.empty(self.cfg.n_ctx, self.cfg.d_model, dtype=cfg.dtype))

    def forward(
        self,
        tokens: Int[torch.Tensor, "batch pos"],
        past_kv_pos_offset: int = 0,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        Forward pass for positional embeddings.

        Args:
            tokens (Int[torch.Tensor, "batch pos"]): Input tokens.
            past_kv_pos_offset (int, optional): The length of tokens in the past_kv_cache. Defaults to 0.
            attention_mask (Int[torch.Tensor, "batch pos"], optional): The attention mask for padded tokens.
                Defaults to None.

        Returns:
            Float[torch.Tensor, "batch pos d_model"]: Absolute position embeddings.
        """
        n_tokens = tokens.size(-1)

        if attention_mask is None:
            # Fast path: every batch row uses positions offset..offset+n_tokens-1.
            window = self.W_pos[past_kv_pos_offset : n_tokens + past_kv_pos_offset, :]  # [pos, d_model]
            batch_pos_embed = window.unsqueeze(0).expand(tokens.size(0), -1, -1)
        else:
            # Padded inputs need per-row position ids that skip pad tokens.
            # This branch is kept separate because it is slower than the one above.
            offset_position_ids = get_offset_position_ids(past_kv_pos_offset, attention_mask)
            pos_embed = self.W_pos[offset_position_ids]  # [batch, pos, d_model]

            # Zero out the position embeddings of pad tokens (an arbitrary choice).
            padding_mask = ~attention_mask.bool()  # [batch, tokens_length]
            offset_padding_mask = padding_mask[
                :, past_kv_pos_offset : n_tokens + past_kv_pos_offset
            ].unsqueeze(-1)  # [batch, pos, 1]
            batch_pos_embed = torch.where(offset_padding_mask, 0, pos_embed)

        return batch_pos_embed.clone()
diff --git a/transformer_lens/components/rms_norm.py b/transformer_lens/components/rms_norm.py
new file mode 100644
index 000000000..7ecdc66d5
--- /dev/null
+++ b/transformer_lens/components/rms_norm.py
@@ -0,0 +1,47 @@
+"""Hooked Transformer RMS Norm Component.
+
+This module contains all the component :class:`RMSNorm`.
+"""
+from typing import Dict, Optional, Union
+
+import torch
+import torch.nn as nn
+from jaxtyping import Float
+
+from transformer_lens.hook_points import HookPoint
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+
+
class RMSNorm(nn.Module):
    def __init__(self, cfg: Union[Dict, HookedTransformerConfig], length: Optional[int] = None):
        """
        RMSNorm - LayerNorm without the centering and bias (RMS = Root Mean Square).

        Args:
            cfg: Model config (dict or HookedTransformerConfig).
            length: Dimension to normalize over. Defaults to ``cfg.d_model``.
        """
        super().__init__()
        if isinstance(cfg, Dict):
            cfg = HookedTransformerConfig.from_dict(cfg)
        self.cfg = cfg
        self.eps = self.cfg.eps
        self.length = self.cfg.d_model if length is None else length

        self.w = nn.Parameter(torch.ones(self.length, dtype=cfg.dtype))

        # Hook on the normalisation scale factor.
        self.hook_scale = HookPoint()  # [batch, pos, 1]
        self.hook_normalized = HookPoint()  # [batch, pos, length]

    def forward(
        self, x: Float[torch.Tensor, "batch pos length"]
    ) -> Float[torch.Tensor, "batch pos length"]:
        # Normalize in float32 for low-precision inputs; cast back before scaling by w.
        if self.cfg.dtype not in [torch.float32, torch.float64]:
            x = x.to(torch.float32)
        scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(
            (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
        )
        normalized = self.hook_normalized(x / scale).to(self.cfg.dtype)  # [batch, pos, length]
        return normalized * self.w
diff --git a/transformer_lens/components/rms_norm_pre.py b/transformer_lens/components/rms_norm_pre.py
new file mode 100644
index 000000000..4df010136
--- /dev/null
+++ b/transformer_lens/components/rms_norm_pre.py
@@ -0,0 +1,37 @@
+"""Hooked Transformer RMS Norm Pre Component.
+
+This module contains all the component :class:`RMSNormPre`.
+"""
+from typing import Dict, Union
+
+import torch
+import torch.nn as nn
+from jaxtyping import Float
+
+from transformer_lens.hook_points import HookPoint
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+
+
class RMSNormPre(nn.Module):
    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        """RMSNormPre - LayerNormPre without the centering and bias (RMS = Root Mean Square)"""
        super().__init__()
        if isinstance(cfg, Dict):
            cfg = HookedTransformerConfig.from_dict(cfg)
        self.cfg = cfg
        self.eps = self.cfg.eps

        # Hook on the normalisation scale factor.
        self.hook_scale = HookPoint()  # [batch, pos]
        self.hook_normalized = HookPoint()  # [batch, pos, length]

    def forward(
        self, x: Float[torch.Tensor, "batch pos length"]
    ) -> Float[torch.Tensor, "batch pos length"]:
        # Normalize in float32 for low-precision inputs; cast back at the end.
        if self.cfg.dtype not in [torch.float32, torch.float64]:
            x = x.to(torch.float32)

        rms: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(
            (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
        )
        return self.hook_normalized(x / rms).to(self.cfg.dtype)  # [batch, pos, length]
diff --git a/transformer_lens/components/token_typed_embed.py b/transformer_lens/components/token_typed_embed.py
new file mode 100644
index 000000000..1cd5c74b4
--- /dev/null
+++ b/transformer_lens/components/token_typed_embed.py
@@ -0,0 +1,29 @@
+"""Hooked Transformer Token Typed Embed Component.
+
+This module contains all the component :class:`TokenTypeEmbed`.
+"""
+from typing import Dict, Union
+
+import torch
+import torch.nn as nn
+from jaxtyping import Int
+
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+
+
class TokenTypeEmbed(nn.Module):
    """
    Embedding for BERT-style token-type (segment) ids. Token-type ids are binary: for
    "[CLS] Sentence A [SEP] Sentence B [SEP]", tokens from sentence A get id 0 and
    tokens from sentence B get id 1, i.e. [0, 0, ..., 0, 1, ..., 1, 1]. If not
    provided, BERT assumes a single sequence input. Typically, shape is
    (batch_size, sequence_length).

    See the BERT paper for more information: https://arxiv.org/pdf/1810.04805.pdf
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__()
        if isinstance(cfg, Dict):
            cfg = HookedTransformerConfig.from_dict(cfg)
        self.cfg = cfg
        # Two learned embedding vectors, one per segment id (0 or 1).
        self.W_token_type = nn.Parameter(torch.empty(2, self.cfg.d_model, dtype=cfg.dtype))

    def forward(self, token_type_ids: Int[torch.Tensor, "batch pos"]):
        # Row-select the segment embedding for every position.
        return self.W_token_type[token_type_ids, :]
diff --git a/transformer_lens/components/transformer_block.py b/transformer_lens/components/transformer_block.py
new file mode 100644
index 000000000..8980f9e8c
--- /dev/null
+++ b/transformer_lens/components/transformer_block.py
@@ -0,0 +1,181 @@
+"""Hooked Transformer Transformer Block Component.
+
+This module contains all the component :class:`TransformerBlock`.
+"""
+import logging
+from typing import Dict, Optional, Union
+
+import torch
+import torch.nn as nn
+from jaxtyping import Float, Int
+
+from transformer_lens.components import (
+ MLP,
+ Attention,
+ GatedMLP,
+ GroupedQueryAttention,
+ LayerNorm,
+ LayerNormPre,
+ MoE,
+ RMSNorm,
+ RMSNormPre,
+)
+from transformer_lens.hook_points import HookPoint
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry
+from transformer_lens.utils import repeat_along_head_dimension
+
+
+# Transformer Block
class TransformerBlock(nn.Module):
    """A single transformer block: (norm -> attention) + (norm -> MLP), with residual
    connections and hook points on every intermediate activation."""

    # Declared up front so type checkers know these exist regardless of which
    # __init__ branch assigns them.
    ln1: nn.Module
    ln2: nn.Module
    mlp: nn.Module

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig], block_index):
        super().__init__()
        if isinstance(cfg, Dict):
            cfg = HookedTransformerConfig.from_dict(cfg)
        self.cfg = cfg
        # Choose normalization from cfg.normalization_type. Attn-only models have no
        # MLP, so ln2 is only created when an MLP exists.
        if self.cfg.normalization_type == "LN":
            self.ln1 = LayerNorm(cfg)
            if not self.cfg.attn_only:
                self.ln2 = LayerNorm(cfg)
        elif self.cfg.normalization_type == "LNPre":
            # We've folded in LayerNorm weights, so just need the center + scale parts
            self.ln1 = LayerNormPre(cfg)
            if not self.cfg.attn_only:
                self.ln2 = LayerNormPre(cfg)
        elif self.cfg.normalization_type == "RMS":
            self.ln1 = RMSNorm(cfg)
            if not self.cfg.attn_only:
                self.ln2 = RMSNorm(cfg)
        elif self.cfg.normalization_type == "RMSPre":
            self.ln1 = RMSNormPre(cfg)
            if not self.cfg.attn_only:
                self.ln2 = RMSNormPre(cfg)
        elif self.cfg.normalization_type is None:
            self.ln1 = nn.Identity()
            if not self.cfg.attn_only:
                self.ln2 = nn.Identity()
        else:
            logging.warning(f"Invalid normalization_type passed in {self.cfg.normalization_type}")

        # Models with n_key_value_heads set use grouped-query attention.
        attention = Attention if self.cfg.n_key_value_heads is None else GroupedQueryAttention
        if not self.cfg.use_local_attn:
            self.attn = attention(cfg, "global", block_index)
        else:
            # Local attention (e.g. GPT-Neo): the attention type alternates per layer.
            assert self.cfg.attn_types is not None
            attn_type = self.cfg.attn_types[block_index]
            self.attn = attention(cfg, attn_type, block_index)
        if not self.cfg.attn_only:
            # MLP variant: mixture-of-experts, gated (e.g. SwiGLU), or standard.
            if self.cfg.num_experts:
                self.mlp = MoE(cfg)
            elif self.cfg.gated_mlp:
                self.mlp = GatedMLP(cfg)
            else:
                self.mlp = MLP(cfg)

        self.hook_attn_in = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_q_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_k_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_v_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_mlp_in = HookPoint()  # [batch, pos, d_model]

        self.hook_attn_out = HookPoint()  # [batch, pos, d_model]
        self.hook_mlp_out = HookPoint()  # [batch, pos, d_model]

        self.hook_resid_pre = HookPoint()  # [batch, pos, d_model]
        if not self.cfg.attn_only and not self.cfg.parallel_attn_mlp:
            self.hook_resid_mid = HookPoint()  # [batch, pos, d_model]
        self.hook_resid_post = HookPoint()  # [batch, pos, d_model]

    def forward(
        self,
        resid_pre: Float[torch.Tensor, "batch pos d_model"],
        shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
        past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """A single Transformer block.

        Args:
            resid_pre (torch.Tensor): The residual stream - shape [batch, pos, d_model]
            shortformer_pos_embed (torch.Tensor, optional): Only used for positional_embeddings_type == "shortformer". The positional embeddings. See HookedTransformerConfig for details. Defaults to None.
            past_kv_cache_entry (HookedTransformerKeyValueCacheEntry, optional): A cache of previous keys and values, used only when generating text. Defaults to None.
            attention_mask (torch.Tensor, optional): The attention mask for padded tokens. Defaults to None.

        Returns:
            Float[torch.Tensor, "batch pos d_model"]: The residual stream after this block (resid_post).
        """
        resid_pre = self.hook_resid_pre(resid_pre)  # [batch, pos, d_model]

        if self.cfg.use_attn_in or self.cfg.use_split_qkv_input:
            # We're adding a head dimension
            if shortformer_pos_embed is not None:
                shortformer_pos_embed = repeat_along_head_dimension(
                    shortformer_pos_embed, n_heads=self.cfg.n_heads
                )
        else:
            attn_in = resid_pre

        if self.cfg.use_attn_in:
            attn_in = self.hook_attn_in(
                repeat_along_head_dimension(resid_pre, n_heads=self.cfg.n_heads)
            )

        if self.cfg.use_split_qkv_input:
            # With grouped-query attention, K and V inputs only get one copy per
            # key-value head rather than per query head.
            n_kv_heads = (
                self.cfg.n_key_value_heads
                if self.cfg.n_key_value_heads is not None
                else self.cfg.n_heads
            )
            query_input = self.hook_q_input(
                repeat_along_head_dimension(resid_pre, n_heads=self.cfg.n_heads)
            )
            key_input = self.hook_k_input(
                repeat_along_head_dimension(resid_pre, n_heads=n_kv_heads)
            )
            value_input = self.hook_v_input(
                repeat_along_head_dimension(resid_pre, n_heads=n_kv_heads)
            )
        else:
            query_input = attn_in
            key_input = attn_in
            value_input = attn_in

        attn_out = self.hook_attn_out(
            # hook the residual stream states that are used to calculate the
            # queries, keys and values, independently.
            # Then take the layer norm of these inputs, and pass these to the attention module.
            self.attn(
                query_input=self.ln1(query_input)
                + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
                key_input=self.ln1(key_input)
                + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
                value_input=self.ln1(value_input),
                past_kv_cache_entry=past_kv_cache_entry,
                attention_mask=attention_mask,
            )
        )  # [batch, pos, d_model]
        if not self.cfg.attn_only and not self.cfg.parallel_attn_mlp:
            # Standard sequential block: attn, then MLP, each with its own residual add.
            resid_mid = self.hook_resid_mid(resid_pre + attn_out)  # [batch, pos, d_model]
            mlp_in = (
                resid_mid if not self.cfg.use_hook_mlp_in else self.hook_mlp_in(resid_mid.clone())
            )
            normalized_resid_mid = self.ln2(mlp_in)
            mlp_out = self.hook_mlp_out(self.mlp(normalized_resid_mid))  # [batch, pos, d_model]
            resid_post = self.hook_resid_post(resid_mid + mlp_out)  # [batch, pos, d_model]
        elif self.cfg.parallel_attn_mlp:
            # Dumb thing done by GPT-J, both MLP and Attn read from resid_pre and write to resid_post, no resid_mid used.
            # In GPT-J, LN1 and LN2 are tied, in GPT-NeoX they aren't.
            normalized_resid_pre_2 = self.ln2(
                resid_pre if not self.cfg.use_hook_mlp_in else self.hook_mlp_in(resid_pre.clone())
            )
            mlp_out = self.hook_mlp_out(self.mlp(normalized_resid_pre_2))  # [batch, pos, d_model]
            resid_post = self.hook_resid_post(
                resid_pre + attn_out + mlp_out
            )  # [batch, pos, d_model]
        else:
            # Attention-only model: no MLP sublayer.
            resid_post = self.hook_resid_post(resid_pre + attn_out)  # [batch, pos, d_model]
        return resid_post
diff --git a/transformer_lens/components/unembed.py b/transformer_lens/components/unembed.py
new file mode 100644
index 000000000..d1de2ea2f
--- /dev/null
+++ b/transformer_lens/components/unembed.py
@@ -0,0 +1,39 @@
+"""Hooked Transformer Unembed Component.
+
+This module contains all the component :class:`Unembed`.
+"""
+from typing import Dict, Union
+
+import torch
+import torch.nn as nn
+from fancy_einsum import einsum
+from jaxtyping import Float
+
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+
+
+class Unembed(nn.Module):
+ def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
+ super().__init__()
+ if isinstance(cfg, Dict):
+ cfg = HookedTransformerConfig.from_dict(cfg)
+ self.cfg = cfg
+ # Note that there's a separate variable for d_vocab_out and d_vocab (the input vocab size). For language tasks these are always the same, but for algorithmic tasks we may want them to be different.
+ self.W_U: Float[torch.Tensor, "d_model d_vocab_out"] = nn.Parameter(
+ torch.empty(self.cfg.d_model, self.cfg.d_vocab_out, dtype=cfg.dtype)
+ )
+ self.b_U: Float[torch.Tensor, "d_vocab_out"] = nn.Parameter(
+ torch.zeros(self.cfg.d_vocab_out, dtype=cfg.dtype)
+ )
+
+ def forward(
+ self, residual: Float[torch.Tensor, "batch pos d_model"]
+ ) -> Float[torch.Tensor, "batch pos d_vocab_out"]:
+ return (
+ einsum(
+ "batch pos d_model, d_model vocab -> batch pos vocab",
+ residual,
+ self.W_U,
+ )
+ + self.b_U
+ )
diff --git a/transformer_lens/evals.py b/transformer_lens/evals.py
index 710560491..b77c727c5 100644
--- a/transformer_lens/evals.py
+++ b/transformer_lens/evals.py
@@ -40,9 +40,7 @@ def make_wiki_data_loader(tokenizer, batch_size=8):
wiki_data = load_dataset("wikitext", "wikitext-2-v1", split="train")
print(len(wiki_data))
dataset = utils.tokenize_and_concatenate(wiki_data, tokenizer)
- data_loader = DataLoader(
- dataset, batch_size=batch_size, shuffle=True, drop_last=True
- )
+ data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
return data_loader
@@ -55,9 +53,7 @@ def make_owt_data_loader(tokenizer, batch_size=8):
owt_data = load_dataset("stas/openwebtext-10k", split="train")
print(len(owt_data))
dataset = utils.tokenize_and_concatenate(owt_data, tokenizer)
- data_loader = DataLoader(
- dataset, batch_size=batch_size, shuffle=True, drop_last=True
- )
+ data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
return data_loader
@@ -71,9 +67,7 @@ def make_pile_data_loader(tokenizer, batch_size=8):
pile_data = load_dataset("NeelNanda/pile-10k", split="train")
print(len(pile_data))
dataset = utils.tokenize_and_concatenate(pile_data, tokenizer)
- data_loader = DataLoader(
- dataset, batch_size=batch_size, shuffle=True, drop_last=True
- )
+ data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
return data_loader
@@ -86,12 +80,8 @@ def make_code_data_loader(tokenizer, batch_size=8):
"""
code_data = load_dataset("codeparrot/codeparrot-valid-v2-near-dedup", split="train")
print(len(code_data))
- dataset = utils.tokenize_and_concatenate(
- code_data, tokenizer, column_name="content"
- )
- data_loader = DataLoader(
- dataset, batch_size=batch_size, shuffle=True, drop_last=True
- )
+ dataset = utils.tokenize_and_concatenate(code_data, tokenizer, column_name="content")
+ data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
return data_loader
@@ -146,9 +136,7 @@ def induction_loss(
repeated_tokens[:, 0] = tokenizer.bos_token_id
# Run the model, and extract the per token correct log prob
logits = model(repeated_tokens, return_type="logits")
- correct_log_probs = utils.lm_cross_entropy_loss(
- logits, repeated_tokens, per_token=True
- )
+ correct_log_probs = utils.lm_cross_entropy_loss(logits, repeated_tokens, per_token=True)
# Take the loss over the second half of the sequence
return correct_log_probs[:, subseq_len + 1 :].mean()
@@ -212,9 +200,7 @@ def __init__(
self.tokenizer = tokenizer
self.prepend_bos = prepend_bos
- self.templates = (
- templates if templates is not None else self.get_default_templates()
- )
+ self.templates = templates if templates is not None else self.get_default_templates()
self.names = names if names is not None else self.get_default_names()
self.nouns = nouns if nouns is not None else self.get_default_nouns()
@@ -256,9 +242,7 @@ def get_sample(self, symmetric=False) -> List[Dict[str, str]]:
if symmetric:
sample_2 = template.replace("[A]", names[1])
sample_2 = sample_2.replace("[B]", names[0])
- samples.append(
- {"text": sample_2, "IO": " " + names[1], "S": " " + names[0]}
- )
+ samples.append({"text": sample_2, "IO": " " + names[1], "S": " " + names[0]})
return samples
@@ -282,9 +266,7 @@ def get_default_nouns():
@torch.inference_mode()
-def ioi_eval(
- model, dataset=None, batch_size=8, num_samples=1000, tokenizer=None, symmetric=False
-):
+def ioi_eval(model, dataset=None, batch_size=8, num_samples=1000, tokenizer=None, symmetric=False):
"""Evaluate the Model on the Indirect Object Identification Task.
Args:
@@ -314,9 +296,7 @@ def collate(samples):
"prompt_length": [p.shape[0] for p in prompts],
}
- data_loader = DataLoader(
- dataset, batch_size=batch_size, shuffle=True, collate_fn=collate
- )
+ data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate)
total_correct = 0
total_logit_diff = 0
diff --git a/transformer_lens/head_detector.py b/transformer_lens/head_detector.py
index 41eb72da9..fbb50fae8 100644
--- a/transformer_lens/head_detector.py
+++ b/transformer_lens/head_detector.py
@@ -2,6 +2,7 @@
Utilities for detecting specific types of heads (e.g. previous token heads).
"""
+
import logging
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union, cast
@@ -10,7 +11,8 @@
import torch
from typing_extensions import Literal, get_args
-from transformer_lens import ActivationCache, HookedTransformer
+from transformer_lens.ActivationCache import ActivationCache
+from transformer_lens.HookedTransformer import HookedTransformer
from transformer_lens.utils import is_lower_triangular, is_square
HeadName = Literal["previous_token_head", "duplicate_token_head", "induction_head"]
@@ -24,9 +26,7 @@
f"detection_pattern must be a Tensor or one of head names: {HEAD_NAMES}; got %s"
)
-SEQ_LEN_ERR = (
- "The sequence must be non-empty and must fit within the model's context window."
-)
+SEQ_LEN_ERR = "The sequence must be non-empty and must fit within the model's context window."
DET_PAT_NOT_SQUARE_ERR = "The detection pattern must be a lower triangular matrix of shape (sequence_length, sequence_length); sequence_length=%d; got detection patern of shape %s"
@@ -111,9 +111,7 @@ def detect_head(
# Validate detection pattern if it's a string
if isinstance(detection_pattern, str):
- assert detection_pattern in HEAD_NAMES, (
- INVALID_HEAD_NAME_ERR % detection_pattern
- )
+ assert detection_pattern in HEAD_NAMES, INVALID_HEAD_NAME_ERR % detection_pattern
if isinstance(seq, list):
batch_scores = [detect_head(model, seq, detection_pattern) for seq in seq]
return torch.stack(batch_scores).mean(0)
@@ -123,9 +121,7 @@ def detect_head(
).to(cfg.device)
# if we're using "mul", detection_pattern should consist of zeros and ones
- if error_measure == "mul" and not set(detection_pattern.unique().tolist()).issubset(
- {0, 1}
- ):
+ if error_measure == "mul" and not set(detection_pattern.unique().tolist()).issubset({0, 1}):
logging.warning(
"Using detection pattern with values other than 0 or 1 with error_measure 'mul'"
)
@@ -140,9 +136,7 @@ def detect_head(
_, cache = model.run_with_cache(tokens, remove_batch_dim=True)
if heads is None:
- layer2heads = {
- layer_i: list(range(cfg.n_heads)) for layer_i in range(cfg.n_layers)
- }
+ layer2heads = {layer_i: list(range(cfg.n_heads)) for layer_i in range(cfg.n_layers)}
elif isinstance(heads, list):
layer2heads = defaultdict(list)
for layer, head in heads:
@@ -198,9 +192,7 @@ def get_duplicate_token_head_detection_pattern(
# If token_pattern[i][j] matches its transpose, then token j and token i are duplicates.
eq_mask = np.equal(token_pattern, token_pattern.T).astype(int)
- np.fill_diagonal(
- eq_mask, 0
- ) # Current token is always a duplicate of itself. Ignore that.
+ np.fill_diagonal(eq_mask, 0) # Current token is always a duplicate of itself. Ignore that.
detection_pattern = eq_mask.astype(int)
return torch.tril(torch.as_tensor(detection_pattern).float())
diff --git a/transformer_lens/hook_points.py b/transformer_lens/hook_points.py
index ce86fd77a..858c3f48b 100644
--- a/transformer_lens/hook_points.py
+++ b/transformer_lens/hook_points.py
@@ -2,14 +2,18 @@
Helpers to access activations in models.
"""
+
import logging
from contextlib import contextmanager
from dataclasses import dataclass
-from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
+from functools import partial
+from typing import Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union, cast
import torch.nn as nn
import torch.utils.hooks as hooks
+from transformer_lens.utils import Slice
+
@dataclass
class LensHandle:
@@ -51,7 +55,12 @@ def add_perma_hook(self, hook, dir="fwd") -> None:
self.add_hook(hook, dir=dir, is_permanent=True)
def add_hook(
- self, hook, dir="fwd", is_permanent=False, level=None, prepend=False
+ self,
+ hook: Callable,
+ dir: Literal["fwd", "bwd"] = "fwd",
+ is_permanent: bool = False,
+ level: Optional[int] = None,
+ prepend: bool = False,
) -> None:
"""
Hook format is fn(activation, hook_name)
@@ -59,47 +68,38 @@ def add_hook(
which are the same for a HookPoint)
If prepend is True, add this hook before all other hooks
"""
- if dir == "fwd":
-
- def full_hook(module, module_input, module_output):
- return hook(module_output, hook=self)
- full_hook.__name__ = (
- hook.__repr__()
- ) # annotate the `full_hook` with the string representation of the `hook` function
+ def full_hook(module, module_input, module_output):
+ if (
+ dir == "bwd"
+ ): # For a backwards hook, module_output is a tuple of (grad,) - I don't know why.
+ module_output = module_output[0]
+ return hook(module_output, hook=self)
- handle = self.register_forward_hook(full_hook)
- handle = LensHandle(handle, is_permanent, level)
-
- if prepend:
- # we could just pass this as an argument in PyTorch 2.0, but for now we manually do this...
- self._forward_hooks.move_to_end(handle.hook.id, last=False)
- self.fwd_hooks.insert(0, handle)
-
- else:
- self.fwd_hooks.append(handle)
+ full_hook.__name__ = (
+ hook.__repr__()
+ ) # annotate the `full_hook` with the string representation of the `hook` function
+ if dir == "fwd":
+ pt_handle = self.register_forward_hook(full_hook)
+ _internal_hooks = self._forward_hooks
+ visible_hooks = self.fwd_hooks
elif dir == "bwd":
- # For a backwards hook, module_output is a tuple of (grad,) - I don't know why.
-
- def full_hook(module, module_input, module_output):
- return hook(module_output[0], hook=self)
+ pt_handle = self.register_full_backward_hook(full_hook)
+ _internal_hooks = self._backward_hooks
+ visible_hooks = self.bwd_hooks
+ else:
+ raise ValueError(f"Invalid direction {dir}")
- full_hook.__name__ = (
- hook.__repr__()
- ) # annotate the `full_hook` with the string representation of the `hook` function
+ handle = LensHandle(pt_handle, is_permanent, level)
- handle = self.register_full_backward_hook(full_hook)
- handle = LensHandle(handle, is_permanent, level)
+ if prepend:
+ # we could just pass this as an argument in PyTorch 2.0, but for now we manually do this...
+ _internal_hooks.move_to_end(handle.hook.id, last=False) # type: ignore # TODO: this type error could signify a bug
+ visible_hooks.insert(0, handle)
- if prepend:
- # we could just pass this as an argument in PyTorch 2.0, but for now we manually do this...
- self._backward_hooks.move_to_end(handle.hook.id, last=False)
- self.bwd_hooks.insert(0, handle)
- else:
- self.bwd_hooks.append(handle)
else:
- raise ValueError(f"Invalid direction {dir}")
+ visible_hooks.append(handle)
def remove_hooks(self, dir="fwd", including_permanent=False, level=None) -> None:
def _remove_hooks(handles: List[LensHandle]) -> List[LensHandle]:
@@ -107,9 +107,7 @@ def _remove_hooks(handles: List[LensHandle]) -> List[LensHandle]:
for handle in handles:
if including_permanent:
handle.hook.remove()
- elif (not handle.is_permanent) and (
- level is None or handle.context_level == level
- ):
+ elif (not handle.is_permanent) and (level is None or handle.context_level == level):
handle.hook.remove()
else:
output_handles.append(handle)
@@ -133,6 +131,7 @@ def layer(self):
# Returns the layer index if the name has the form 'blocks.{layer}.{...}'
# Helper function that's mainly useful on HookedTransformer
# If it doesn't have this form, raises an error -
+ assert self.name is not None # keep mypy happy
split_name = self.name.split(".")
return int(split_name[1])
@@ -185,13 +184,9 @@ def setup(self):
def hook_points(self):
return self.hook_dict.values()
- def remove_all_hook_fns(
- self, direction="both", including_permanent=False, level=None
- ):
+ def remove_all_hook_fns(self, direction="both", including_permanent=False, level=None):
for hp in self.hook_points():
- hp.remove_hooks(
- direction, including_permanent=including_permanent, level=level
- )
+ hp.remove_hooks(direction, including_permanent=including_permanent, level=level)
def clear_contexts(self):
for hp in self.hook_points():
@@ -228,12 +223,16 @@ def check_and_add_hook(
is_permanent=is_permanent,
prepend=prepend,
)
- hook_point.add_hook(
- hook, dir=dir, is_permanent=is_permanent, level=level, prepend=prepend
- )
+ hook_point.add_hook(hook, dir=dir, is_permanent=is_permanent, level=level, prepend=prepend)
def check_hooks_to_add(
- self, hook_point, hook_point_name, hook, dir="fwd", is_permanent=False
+ self,
+ hook_point,
+ hook_point_name,
+ hook,
+ dir="fwd",
+ is_permanent=False,
+ prepend=False,
) -> None:
"""Override this function to add checks on which hooks should be added"""
pass
@@ -297,23 +296,19 @@ def hooks(
self.context_level += 1
for name, hook in fwd_hooks:
- if type(name) == str:
- self.mod_dict[name].add_hook(
- hook, dir="fwd", level=self.context_level
- )
+ if isinstance(name, str):
+ self.mod_dict[name].add_hook(hook, dir="fwd", level=self.context_level)
else:
# Otherwise, name is a Boolean function on names
for hook_name, hp in self.hook_dict.items():
if name(hook_name):
hp.add_hook(hook, dir="fwd", level=self.context_level)
for name, hook in bwd_hooks:
- if type(name) == str:
- self.mod_dict[name].add_hook(
- hook, dir="bwd", level=self.context_level
- )
+ if isinstance(name, str):
+ self.mod_dict[name].add_hook(hook, dir="bwd", level=self.context_level)
else:
# Otherwise, name is a Boolean function on names
- for hook_name, hp in self.hook_dict:
+ for hook_name, hp in self.hook_dict: # type: ignore
if name(hook_name):
hp.add_hook(hook, dir="bwd", level=self.context_level)
yield self
@@ -359,9 +354,7 @@ def run_with_hooks(
"WARNING: Hooks will be reset at the end of run_with_hooks. This removes the backward hooks before a backward pass can occur."
)
- with self.hooks(
- fwd_hooks, bwd_hooks, reset_hooks_end, clear_contexts
- ) as hooked_model:
+ with self.hooks(fwd_hooks, bwd_hooks, reset_hooks_end, clear_contexts) as hooked_model:
return hooked_model.forward(*model_args, **model_kwargs)
def add_caching_hooks(
@@ -396,25 +389,25 @@ def add_caching_hooks(
filter_list = names_filter
names_filter = lambda name: name in filter_list
- self.is_caching = True
+ # mypy can't seem to infer this
+ names_filter = cast(Callable[[str], bool], names_filter)
- def save_hook(tensor, hook):
- if remove_batch_dim:
- cache[hook.name] = tensor.detach().to(device)[0]
- else:
- cache[hook.name] = tensor.detach().to(device)
+ self.is_caching = True
- def save_hook_back(tensor, hook):
+ def save_hook(tensor, hook, is_backward):
+ hook_name = hook.name
+ if is_backward:
+ hook_name += "_grad"
if remove_batch_dim:
- cache[hook.name + "_grad"] = tensor.detach().to(device)[0]
+ cache[hook_name] = tensor.detach().to(device)[0]
else:
- cache[hook.name + "_grad"] = tensor.detach().to(device)
+ cache[hook_name] = tensor.detach().to(device)
for name, hp in self.hook_dict.items():
if names_filter(name):
- hp.add_hook(save_hook, "fwd")
+ hp.add_hook(partial(save_hook, is_backward=False), "fwd")
if incl_bwd:
- hp.add_hook(save_hook_back, "bwd")
+ hp.add_hook(partial(save_hook, is_backward=True), "bwd")
return cache
def run_with_cache(
@@ -426,6 +419,7 @@ def run_with_cache(
incl_bwd=False,
reset_hooks_end=True,
clear_contexts=False,
+ pos_slice=None,
**model_kwargs,
):
"""
@@ -448,14 +442,28 @@ def run_with_cache(
end of the run. Defaults to True.
clear_contexts (bool, optional): If True, clears hook contexts whenever hooks are reset.
Defaults to False.
+ pos_slice:
+ The slice to apply to the cache output. Defaults to None, do nothing.
**model_kwargs: Keyword arguments for the model.
Returns:
tuple: A tuple containing the model output and a Cache object.
"""
+
+ if not isinstance(pos_slice, Slice):
+ if isinstance(
+ pos_slice, int
+ ): # slicing with an int collapses the dimension so this stops the pos dimension from collapsing
+ pos_slice = [pos_slice]
+ pos_slice = Slice(pos_slice)
+
cache_dict, fwd, bwd = self.get_caching_hooks(
- names_filter, incl_bwd, device, remove_batch_dim=remove_batch_dim
+ names_filter,
+ incl_bwd,
+ device,
+ remove_batch_dim=remove_batch_dim,
+ pos_slice=pos_slice,
)
with self.hooks(
@@ -477,6 +485,7 @@ def get_caching_hooks(
device=None,
remove_batch_dim: bool = False,
cache: Optional[dict] = None,
+ pos_slice: Optional[Slice] = None,
) -> Tuple[dict, list, list]:
"""Creates hooks to cache activations. Note: It does not add the hooks to the model.
@@ -497,33 +506,53 @@ def get_caching_hooks(
if names_filter is None:
names_filter = lambda name: True
- elif type(names_filter) == str:
+ elif isinstance(names_filter, str):
filter_str = names_filter
names_filter = lambda name: name == filter_str
- elif type(names_filter) == list:
+ elif isinstance(names_filter, list):
filter_list = names_filter
names_filter = lambda name: name in filter_list
self.is_caching = True
- def save_hook(tensor, hook):
- if remove_batch_dim:
- cache[hook.name] = tensor.detach().to(device)[0]
- else:
- cache[hook.name] = tensor.detach().to(device)
+ # mypy can't seem to infer this
+ names_filter = cast(Callable[[str], bool], names_filter)
- def save_hook_back(tensor, hook):
+ def save_hook(tensor, hook, is_backward=False):
+ hook_name = hook.name
+ if is_backward:
+ hook_name += "_grad"
+ resid_stream = tensor.detach().to(device)
if remove_batch_dim:
- cache[hook.name + "_grad"] = tensor.detach().to(device)[0]
+ resid_stream = resid_stream[0]
+
+ # for attention heads the pos dimension is the third from last
+ if (
+ hook.name.endswith("hook_q")
+ or hook.name.endswith("hook_k")
+ or hook.name.endswith("hook_v")
+ or hook.name.endswith("hook_z")
+ or hook.name.endswith("hook_result")
+ ):
+ pos_dim = -3
else:
- cache[hook.name + "_grad"] = tensor.detach().to(device)
+ # for all other components the pos dimension is the second from last
+ # including the attn scores where the dest token is the second from last
+ pos_dim = -2
+
+ if (
+ tensor.dim() >= -pos_dim
+ ): # check if the residual stream has a pos dimension before trying to slice
+ assert pos_slice is not None # keep mypy happy
+ resid_stream = pos_slice.apply(resid_stream, dim=pos_dim)
+ cache[hook_name] = resid_stream
fwd_hooks = []
bwd_hooks = []
for name, hp in self.hook_dict.items():
if names_filter(name):
- fwd_hooks.append((name, save_hook))
+ fwd_hooks.append((name, partial(save_hook, is_backward=False)))
if incl_bwd:
- bwd_hooks.append((name, save_hook_back))
+ bwd_hooks.append((name, partial(save_hook, is_backward=True)))
return cache, fwd_hooks, bwd_hooks
diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py
index 123582f1a..0a8d132cc 100644
--- a/transformer_lens/loading_from_pretrained.py
+++ b/transformer_lens/loading_from_pretrained.py
@@ -2,10 +2,12 @@
This module contains functions for loading pretrained models from the Hugging Face Hub.
"""
+
import dataclasses
import logging
+import os
import re
-from typing import Dict, Optional
+from typing import Dict, Optional, Union, cast
import einops
import torch
@@ -111,11 +113,18 @@
"llama-13b-hf",
"llama-30b-hf",
"llama-65b-hf",
- "Llama-2-7b-hf",
- "Llama-2-7b-chat-hf",
- "Llama-2-13b-hf",
- "Llama-2-13b-chat-hf",
- # TODO Llama-2-70b-hf requires Grouped-Query Attention, see the paper https://arxiv.org/pdf/2307.09288.pdf
+ "meta-llama/Llama-2-7b-hf",
+ "meta-llama/Llama-2-7b-chat-hf",
+ "meta-llama/Llama-2-13b-hf",
+ "meta-llama/Llama-2-13b-chat-hf",
+ "meta-llama/Llama-2-70b-chat-hf",
+ "CodeLlama-7b-hf",
+ "CodeLlama-7b-Python-hf",
+ "CodeLlama-7b-Instruct-hf",
+ "meta-llama/Meta-Llama-3-8B",
+ "meta-llama/Meta-Llama-3-8B-Instruct",
+ "meta-llama/Meta-Llama-3-70B",
+ "meta-llama/Meta-Llama-3-70B-Instruct",
"Baidicoot/Othello-GPT-Transformer-Lens",
"bert-base-cased",
"roneneldan/TinyStories-1M",
@@ -136,8 +145,43 @@
"stabilityai/stablelm-base-alpha-7b",
"stabilityai/stablelm-tuned-alpha-3b",
"stabilityai/stablelm-tuned-alpha-7b",
+ "mistralai/Mistral-7B-v0.1",
+ "mistralai/Mistral-7B-Instruct-v0.1",
+ "mistralai/Mixtral-8x7B-v0.1",
+ "mistralai/Mixtral-8x7B-Instruct-v0.1",
"bigscience/bloom-560m",
+ "bigscience/bloom-1b1",
+ "bigscience/bloom-1b7",
+ "bigscience/bloom-3b",
+ "bigscience/bloom-7b1",
"bigcode/santacoder",
+ "Qwen/Qwen-1_8B",
+ "Qwen/Qwen-7B",
+ "Qwen/Qwen-14B",
+ "Qwen/Qwen-1_8B-Chat",
+ "Qwen/Qwen-7B-Chat",
+ "Qwen/Qwen-14B-Chat",
+ "Qwen/Qwen1.5-0.5B",
+ "Qwen/Qwen1.5-0.5B-Chat",
+ "Qwen/Qwen1.5-1.8B",
+ "Qwen/Qwen1.5-1.8B-Chat",
+ "Qwen/Qwen1.5-4B",
+ "Qwen/Qwen1.5-4B-Chat",
+ "Qwen/Qwen1.5-7B",
+ "Qwen/Qwen1.5-7B-Chat",
+ "Qwen/Qwen1.5-14B",
+ "Qwen/Qwen1.5-14B-Chat",
+ "microsoft/phi-1",
+ "microsoft/phi-1_5",
+ "microsoft/phi-2",
+ "google/gemma-2b",
+ "google/gemma-7b",
+ "google/gemma-2b-it",
+ "google/gemma-7b-it",
+ "01-ai/Yi-6B",
+ "01-ai/Yi-34B",
+ "01-ai/Yi-6B-Chat",
+ "01-ai/Yi-34B-Chat",
]
"""Official model names for models on HuggingFace."""
@@ -460,11 +504,26 @@
"llama-13b-hf": ["llama-13b"],
"llama-30b-hf": ["llama-30b"],
"llama-65b-hf": ["llama-65b"],
- "Llama-2-7b-hf": ["Llama-2-7b", "meta-llama/Llama-2-7b-hf"],
- "Llama-2-7b-chat-hf": ["Llama-2-7b-chat", "meta-llama/Llama-2-7b-chat-hf"],
- "Llama-2-13b-hf": ["Llama-2-13b", "meta-llama/Llama-2-13b-hf"],
- "Llama-2-13b-chat-hf": ["Llama-2-13b-chat", "meta-llama/Llama-2-13b-chat-hf"],
- # TODO Llama-2-70b-hf requires Grouped-Query Attention, see the paper https://arxiv.org/pdf/2307.09288.pdf
+ "meta-llama/Llama-2-7b-hf": ["Llama-2-7b", "meta-llama/Llama-2-7b-hf"],
+ "meta-llama/Llama-2-7b-chat-hf": [
+ "Llama-2-7b-chat",
+ "meta-llama/Llama-2-7b-chat-hf",
+ ],
+ "meta-llama/Llama-2-13b-hf": ["Llama-2-13b", "meta-llama/Llama-2-13b-hf"],
+ "meta-llama/Llama-2-13b-chat-hf": [
+ "Llama-2-13b-chat",
+ "meta-llama/Llama-2-13b-chat-hf",
+ ],
+ "meta-llama/Llama-2-70b-chat-hf": ["Llama-2-70b-chat", "meta-llama-2-70b-chat-hf"],
+ "CodeLlama-7b-hf": ["CodeLlamallama-2-7b", "codellama/CodeLlama-7b-hf"],
+ "CodeLlama-7b-Python-hf": [
+ "CodeLlama-7b-python",
+ "codellama/CodeLlama-7b-Python-hf",
+ ],
+ "CodeLlama-7b-Instruct-hf": [
+ "CodeLlama-7b-instruct",
+ "codellama/CodeLlama-7b-Instruct-hf",
+ ],
"Baidicoot/Othello-GPT-Transformer-Lens": ["othello-gpt"],
"roneneldan/TinyStories-1M": ["tiny-stories-1M"],
"roneneldan/TinyStories-3M": ["tiny-stories-3M"],
@@ -496,17 +555,68 @@
"stablelm-tuned-alpha-7b",
"stablelm-tuned-7b",
],
+ "mistralai/Mistral-7B-v0.1": ["mistral-7b"],
+ "mistralai/Mistral-7B-Instruct-v0.1": ["mistral-7b-instruct"],
+ "mistralai/Mixtral-8x7B-v0.1": ["mixtral", "mixtral-8x7b"],
+ "mistralai/Mixtral-8x7B-Instruct-v0.1": [
+ "mixtral-instruct",
+ "mixtral-8x7b-instruct",
+ ],
"bigscience/bloom-560m": ["bloom-560m"],
+ "bigscience/bloom-1b1": ["bloom-1b1"],
+ "bigscience/bloom-1b7": ["bloom-1b7"],
+ "bigscience/bloom-3b": ["bloom-3b"],
+ "bigscience/bloom-7b1": ["bloom-7b1"],
"bigcode/santacoder": ["santacoder"],
+ "Qwen/Qwen-1_8B": ["qwen-1.8b"],
+ "Qwen/Qwen-7B": ["qwen-7b"],
+ "Qwen/Qwen-14B": ["qwen-14b"],
+ "Qwen/Qwen-1_8B-Chat": ["qwen-1.8b-chat"],
+ "Qwen/Qwen-7B-Chat": ["qwen-7b-chat"],
+ "Qwen/Qwen-14B-Chat": ["qwen-14b-chat"],
+ "Qwen/Qwen1.5-0.5B": ["qwen1.5-0.5b"],
+ "Qwen/Qwen1.5-0.5B-Chat": ["qwen1.5-0.5b-chat"],
+ "Qwen/Qwen1.5-1.8B": ["qwen1.5-1.8b"],
+ "Qwen/Qwen1.5-1.8B-Chat": ["qwen1.5-1.8b-chat"],
+ "Qwen/Qwen1.5-4B": ["qwen1.5-4b"],
+ "Qwen/Qwen1.5-4B-Chat": ["qwen1.5-4b-chat"],
+ "Qwen/Qwen1.5-7B": ["qwen1.5-7b"],
+ "Qwen/Qwen1.5-7B-Chat": ["qwen1.5-7b-chat"],
+ "Qwen/Qwen1.5-14B": ["qwen1.5-14b"],
+ "Qwen/Qwen1.5-14B-Chat": ["qwen1.5-14b-chat"],
+ "microsoft/phi-1": ["phi-1"],
+ "microsoft/phi-1_5": ["phi-1_5"],
+ "microsoft/phi-2": ["phi-2"],
+ "google/gemma-2b": ["gemma-2b"],
+ "google/gemma-7b": ["gemma-7b"],
+ "google/gemma-2b-it": ["gemma-2b-it"],
+ "google/gemma-7b-it": ["gemma-7b-it"],
+ "01-ai/Yi-6B": ["yi-6b", "Yi-6B"],
+ "01-ai/Yi-34B": ["yi-34b", "Yi-34B"],
+ "01-ai/Yi-6B-Chat": ["yi-6b-chat", "Yi-6B-Chat"],
+ "01-ai/Yi-34B-Chat": ["yi-34b-chat", "Yi-34B-Chat"],
}
"""Model aliases for models on HuggingFace."""
+NON_HF_HOSTED_MODEL_NAMES = [
+ "llama-7b-hf",
+ "llama-13b-hf",
+ "llama-30b-hf",
+ "llama-65b-hf",
+]
+"""Official model names for models not hosted on HuggingFace."""
+
# Sets a default model alias, by convention the first one in the model alias table, else the official name if it has no aliases
DEFAULT_MODEL_ALIASES = [
- MODEL_ALIASES[name][0] if name in MODEL_ALIASES else name
- for name in OFFICIAL_MODEL_NAMES
+ MODEL_ALIASES[name][0] if name in MODEL_ALIASES else name for name in OFFICIAL_MODEL_NAMES
]
+NEED_REMOTE_CODE_MODELS = (
+ "bigcode/santacoder",
+ "Qwen/Qwen-",
+ "microsoft/phi-2",
+)
+
def make_model_alias_map():
"""
@@ -546,13 +656,21 @@ def convert_hf_model_config(model_name: str, **kwargs):
# In case the user passed in an alias
official_model_name = get_official_model_name(model_name)
# Load HuggingFace model config
- if "llama" not in official_model_name.lower():
- hf_config = AutoConfig.from_pretrained(official_model_name, **kwargs)
- architecture = hf_config.architectures[0]
- else:
+ if "llama" in official_model_name.lower():
architecture = "LlamaForCausalLM"
+ elif "gemma" in official_model_name.lower():
+ architecture = "GemmaForCausalLM"
+ else:
+ huggingface_token = os.environ.get("HF_TOKEN", None)
+ hf_config = AutoConfig.from_pretrained(
+ official_model_name,
+ token=huggingface_token,
+ **kwargs,
+ )
+ architecture = hf_config.architectures[0]
+
if official_model_name.startswith(
- ("llama-7b", "Llama-2-7b")
+ ("llama-7b", "meta-llama/Llama-2-7b")
): # same architecture for LLaMA and Llama-2
cfg_dict = {
"d_model": 4096,
@@ -566,12 +684,34 @@ def convert_hf_model_config(model_name: str, **kwargs):
"act_fn": "silu",
"normalization_type": "RMS",
"positional_embedding_type": "rotary",
+ "rotary_adjacent_pairs": False,
+ "rotary_dim": 4096 // 32,
+ "final_rms": True,
+ "gated_mlp": True,
+ }
+ elif official_model_name.startswith("CodeLlama-7b"): # same architecture CodeLlama and Llama-2
+ cfg_dict = {
+ "d_model": 4096,
+ "d_head": 4096 // 32,
+ "n_heads": 32,
+ "d_mlp": 11008,
+ "n_layers": 32,
+ "n_ctx": 4096,
+ "eps": 1e-5,
+ "d_vocab": 32016,
+ "act_fn": "silu",
+ "normalization_type": "RMS",
+ "positional_embedding_type": "rotary",
"rotary_dim": 4096 // 32,
"final_rms": True,
"gated_mlp": True,
+ "rotary_base": 1000000,
}
+ if "python" in official_model_name.lower():
+ # The vocab size of python version of CodeLlama-7b is 32000
+ cfg_dict["d_vocab"] = 32000
elif official_model_name.startswith(
- ("llama-13b", "Llama-2-13b")
+ ("llama-13b", "meta-llama/Llama-2-13b")
): # same architecture for LLaMA and Llama-2
cfg_dict = {
"d_model": 5120,
@@ -585,6 +725,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
"act_fn": "silu",
"normalization_type": "RMS",
"positional_embedding_type": "rotary",
+ "rotary_adjacent_pairs": False,
"rotary_dim": 5120 // 40,
"final_rms": True,
"gated_mlp": True,
@@ -602,6 +743,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
"act_fn": "silu",
"normalization_type": "RMS",
"positional_embedding_type": "rotary",
+ "rotary_adjacent_pairs": False,
"rotary_dim": 6656 // 52,
"final_rms": True,
"gated_mlp": True,
@@ -620,6 +762,64 @@ def convert_hf_model_config(model_name: str, **kwargs):
"normalization_type": "RMS",
"positional_embedding_type": "rotary",
"rotary_dim": 8192 // 64,
+ "rotary_adjacent_pairs": False,
+ "final_rms": True,
+ "gated_mlp": True,
+ }
+ elif "Llama-2-70b" in official_model_name:
+ cfg_dict = {
+ "d_model": 8192,
+ "d_head": 128,
+ "n_heads": 64,
+ "d_mlp": 28672,
+ "n_layers": 80,
+ "n_ctx": 4096,
+ "eps": 1e-5,
+ "d_vocab": 32000,
+ "act_fn": "silu",
+ "n_key_value_heads": 8,
+ "normalization_type": "RMS",
+ "positional_embedding_type": "rotary",
+ "rotary_adjacent_pairs": False,
+ "rotary_dim": 128,
+ "final_rms": True,
+ "gated_mlp": True,
+ }
+ elif "Meta-Llama-3-8B" in official_model_name:
+ cfg_dict = {
+ "d_model": 4096,
+ "d_head": 128,
+ "n_heads": 32,
+ "d_mlp": 14336,
+ "n_layers": 32,
+ "n_ctx": 8192,
+ "eps": 1e-5,
+ "d_vocab": 128256,
+ "act_fn": "silu",
+ "n_key_value_heads": 8,
+ "normalization_type": "RMS",
+ "positional_embedding_type": "rotary",
+ "rotary_adjacent_pairs": False,
+ "rotary_dim": 128,
+ "final_rms": True,
+ "gated_mlp": True,
+ }
+ elif "Meta-Llama-3-70B" in official_model_name:
+ cfg_dict = {
+ "d_model": 8192,
+ "d_head": 128,
+ "n_heads": 64,
+ "d_mlp": 28672,
+ "n_layers": 80,
+ "n_ctx": 8192,
+ "eps": 1e-5,
+ "d_vocab": 128256,
+ "act_fn": "silu",
+ "n_key_value_heads": 8,
+ "normalization_type": "RMS",
+ "positional_embedding_type": "rotary",
+ "rotary_adjacent_pairs": False,
+ "rotary_dim": 128,
"final_rms": True,
"gated_mlp": True,
}
@@ -690,6 +890,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
"parallel_attn_mlp": True,
"positional_embedding_type": "rotary",
"rotary_dim": hf_config.rotary_dim,
+ "rotary_adjacent_pairs": True,
"normalization_type": "LN",
}
elif architecture == "GPTNeoXForCausalLM":
@@ -708,6 +909,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
"scale_attn_by_inverse_layer_idx": False,
"parallel_attn_mlp": True,
"positional_embedding_type": "rotary",
+ "rotary_adjacent_pairs": False,
"normalization_type": "LN",
}
rotary_pct = hf_config.rotary_pct
@@ -725,6 +927,48 @@ def convert_hf_model_config(model_name: str, **kwargs):
"act_fn": "gelu",
"attention_dir": "bidirectional",
}
+ elif architecture == "MistralForCausalLM":
+ cfg_dict = {
+ "d_model": 4096,
+ "d_head": 4096 // 32,
+ "n_heads": 32,
+ "d_mlp": 14336,
+ "n_layers": 32,
+ "n_ctx": 2048, # Capped due to memory issues
+ "d_vocab": 32000,
+ "act_fn": "silu",
+ "normalization_type": "RMS",
+ "positional_embedding_type": "rotary",
+ "window_size": 4096,
+ "attn_types": ["local"] * 32,
+ "eps": 1e-05,
+ "n_key_value_heads": 8,
+ "gated_mlp": True,
+ "use_local_attn": True,
+ "rotary_dim": 4096 // 32,
+ }
+ elif architecture == "MixtralForCausalLM":
+ cfg_dict = {
+ "d_model": hf_config.hidden_size,
+ "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
+ "n_heads": hf_config.num_attention_heads,
+ "d_mlp": hf_config.intermediate_size,
+ "n_layers": hf_config.num_hidden_layers,
+ "n_ctx": 2048, # hf_config.max_position_embeddings, # Capped due to memory issues
+ "d_vocab": hf_config.vocab_size,
+ "act_fn": hf_config.hidden_act,
+ "normalization_type": "RMS",
+ "positional_embedding_type": "rotary",
+ "window_size": hf_config.sliding_window, # This is None, as no sliding window was used
+ "attn_types": ["global"] * 32,
+ "eps": hf_config.rms_norm_eps,
+ "n_key_value_heads": hf_config.num_key_value_heads,
+ "gated_mlp": True,
+ "use_local_attn": False,
+ "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
+ "num_experts": hf_config.num_local_experts,
+ "experts_per_token": hf_config.num_experts_per_tok,
+ }
elif architecture == "BloomForCausalLM":
cfg_dict = {
"d_model": hf_config.hidden_size,
@@ -740,7 +984,6 @@ def convert_hf_model_config(model_name: str, **kwargs):
"post_embedding_ln": True,
"positional_embedding_type": "alibi",
}
-
elif architecture == "GPT2LMHeadCustomModel":
# santacoder
cfg_dict = {
@@ -755,9 +998,149 @@ def convert_hf_model_config(model_name: str, **kwargs):
"act_fn": hf_config.activation_function,
"use_attn_scale": True,
"use_local_attn": False,
+ "trust_remote_code": "santacoder"
+ in official_model_name, # Only santacoder needs trust_remote_code
"scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx,
"normalization_type": "LN",
}
+ elif architecture == "LlamaForCausalLM":
+ cfg_dict = {
+ "d_model": hf_config.hidden_size,
+ "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
+ "n_heads": hf_config.num_attention_heads,
+ "d_mlp": hf_config.intermediate_size,
+ "n_layers": hf_config.num_hidden_layers,
+ "n_ctx": hf_config.max_position_embeddings,
+ "eps": hf_config.rms_norm_eps,
+ "d_vocab": hf_config.vocab_size,
+ "act_fn": hf_config.hidden_act,
+ "n_key_value_heads": (
+ hf_config.num_key_value_heads
+ if hf_config.num_key_value_heads != hf_config.num_attention_heads
+ else None
+ ),
+ # This is done because the current implementation of GQA will use Grouped-Query Attention if
+ # n_key_value_heads is not None, but hf_config.num_key_value_heads is sometimes specified as
+ # the same as hf_config.num_attention_heads, in which case GQA should not be used.
+ "normalization_type": "RMS",
+ "positional_embedding_type": "rotary",
+ "rotary_adjacent_pairs": False,
+ "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
+ "final_rms": True,
+ "gated_mlp": True,
+ }
+ elif architecture == "QWenLMHeadModel":
+ cfg_dict = {
+ "d_model": hf_config.hidden_size,
+ "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
+ "n_heads": hf_config.num_attention_heads,
+ "d_mlp": hf_config.intermediate_size // 2,
+ "n_layers": hf_config.num_hidden_layers,
+ "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big
+ "eps": hf_config.layer_norm_epsilon,
+ "d_vocab": hf_config.vocab_size,
+ "act_fn": "silu",
+ "use_attn_scale": hf_config.scale_attn_weights,
+ "initializer_range": hf_config.initializer_range,
+ "normalization_type": "RMS",
+ "positional_embedding_type": "rotary",
+ "rotary_dim": hf_config.kv_channels,
+ "rotary_adjacent_pairs": False,
+ "tokenizer_prepends_bos": True,
+ "trust_remote_code": True,
+ "final_rms": True,
+ "gated_mlp": True,
+ }
+ elif architecture == "Qwen2ForCausalLM":
+ # Note that Qwen1.5 models have architecture type Qwen2ForCausalLM.
+ cfg_dict = {
+ "d_model": hf_config.hidden_size,
+ "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
+ "n_heads": hf_config.num_attention_heads,
+ "d_mlp": hf_config.intermediate_size,
+ "n_layers": hf_config.num_hidden_layers,
+ "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big
+ "eps": hf_config.rms_norm_eps,
+ "d_vocab": hf_config.vocab_size,
+ "act_fn": hf_config.hidden_act,
+ "use_attn_scale": True,
+ "initializer_range": hf_config.initializer_range,
+ "normalization_type": "RMS",
+ "positional_embedding_type": "rotary",
+ "rotary_base": hf_config.rope_theta,
+ "rotary_adjacent_pairs": False,
+ "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
+ "tokenizer_prepends_bos": True,
+ "final_rms": True,
+ "gated_mlp": True,
+ }
+ elif architecture == "PhiForCausalLM":
+ # Architecture for microsoft/phi models
+ cfg_dict = {
+ "d_model": hf_config.hidden_size,
+ "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
+ "n_heads": hf_config.num_attention_heads,
+ "d_mlp": hf_config.intermediate_size,
+ "n_layers": hf_config.num_hidden_layers,
+ "n_ctx": hf_config.max_position_embeddings,
+ "eps": hf_config.layer_norm_eps,
+ "d_vocab": hf_config.vocab_size,
+ "act_fn": hf_config.hidden_act,
+ "initializer_range": hf_config.initializer_range,
+ "normalization_type": "LN",
+ "positional_embedding_type": "rotary",
+ "trust_remote_code": True,
+ "rotary_base": hf_config.rope_theta,
+ "use_attn_scale": True,
+ "parallel_attn_mlp": True,
+ }
+ partial_rotary_factor = hf_config.partial_rotary_factor
+ cfg_dict["rotary_dim"] = round(partial_rotary_factor * cfg_dict["d_head"])
+
+ elif official_model_name.startswith("google/gemma-2b"):
+ # Architecture for Gemma 2b and Gemma 2b Instruct models
+ cfg_dict = {
+ "d_model": 2048,
+ "d_head": 256,
+ "n_heads": 8,
+ "d_mlp": 16384,
+ "n_layers": 18,
+ "n_ctx": 8192,
+ "eps": 1e-06,
+ "d_vocab": 256000,
+ "act_fn": "gelu",
+ "initializer_range": 0.02,
+ "normalization_type": "RMS",
+ "rotary_base": 10000.0,
+ "rotary_dim": 256,
+ "positional_embedding_type": "rotary",
+ "use_attn_scale": True,
+ "n_key_value_heads": 1,
+ "gated_mlp": True,
+ "final_rms": True,
+ }
+ elif official_model_name.startswith("google/gemma-7b"):
+ # Architecture for Gemma 7b and Gemma 7b Instruct models
+ cfg_dict = {
+ "d_model": 3072,
+ "d_head": 256,
+ "n_heads": 16,
+ "d_mlp": 24576,
+ "n_layers": 28,
+ "n_ctx": 8192,
+ "eps": 1e-06,
+ "d_vocab": 256000,
+ "act_fn": "gelu",
+ "initializer_range": 0.02,
+ "normalization_type": "RMS",
+ "rotary_base": 10000.0,
+ "rotary_dim": 256,
+ "positional_embedding_type": "rotary",
+ "use_attn_scale": True,
+ "n_key_value_heads": 16,
+ "gated_mlp": True,
+ "final_rms": True,
+ }
else:
raise NotImplementedError(f"{architecture} is not currently supported.")
# All of these models use LayerNorm
@@ -775,9 +1158,7 @@ def convert_neel_model_config(official_model_name: str, **kwargs):
AutoConfig is not supported, because these models are in the HookedTransformer format, so we directly download and load the json.
"""
official_model_name = get_official_model_name(official_model_name)
- cfg_json: dict = utils.download_file_from_hf(
- official_model_name, "config.json", **kwargs
- )
+ cfg_json: dict = utils.download_file_from_hf(official_model_name, "config.json", **kwargs)
cfg_arch = cfg_json.get(
"architecture", "neel" if "_old" not in official_model_name else "neel-solu-old"
)
@@ -810,10 +1191,11 @@ def convert_neel_model_config(official_model_name: str, **kwargs):
def get_pretrained_model_config(
model_name: str,
+ hf_cfg: Optional[dict] = None,
checkpoint_index: Optional[int] = None,
checkpoint_value: Optional[int] = None,
fold_ln: bool = False,
- device: Optional[str] = None,
+ device: Optional[Union[str, torch.device]] = None,
n_devices: int = 1,
default_prepend_bos: bool = True,
dtype: torch.dtype = torch.float32,
@@ -829,6 +1211,8 @@ def get_pretrained_model_config(
model_name: The name of the model. This can be either the official
HuggingFace model name, or the name of a model trained by me
(NeelNanda).
+ hf_cfg (dict, optional): Config of a loaded pretrained HF model,
+ converted to a dictionary.
checkpoint_index (int, optional): If loading from a
checkpoint, the index of the checkpoint to load. Defaults to None.
checkpoint_value (int, optional): If loading from a checkpoint, the
@@ -861,6 +1245,13 @@ def get_pretrained_model_config(
):
cfg_dict = convert_neel_model_config(official_model_name, **kwargs)
else:
+ if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(
+ "trust_remote_code", False
+ ):
+ logging.warning(
+ f"Loading model {official_model_name} requires setting trust_remote_code=True"
+ )
+ kwargs["trust_remote_code"] = True
cfg_dict = convert_hf_model_config(official_model_name, **kwargs)
# Processing common to both model types
# Remove any prefix, saying the organization who made a model.
@@ -886,6 +1277,8 @@ def get_pretrained_model_config(
if fold_ln:
if cfg_dict["normalization_type"] in ["LN", "LNPre"]:
cfg_dict["normalization_type"] = "LNPre"
+ elif cfg_dict["normalization_type"] in ["RMS", "RMSPre"]:
+ cfg_dict["normalization_type"] = "RMSPre"
else:
logging.warning("Cannot fold in layer norm, normalization_type is not LN.")
@@ -911,6 +1304,8 @@ def get_pretrained_model_config(
cfg_dict["device"] = device
cfg_dict["n_devices"] = n_devices
cfg_dict["default_prepend_bos"] = default_prepend_bos
+ if hf_cfg is not None:
+ cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False)
cfg = HookedTransformerConfig.from_dict(cfg_dict)
return cfg
@@ -977,8 +1372,6 @@ def get_checkpoint_labels(model_name: str, **kwargs):
# %% Loading state dicts
-
-
def get_pretrained_state_dict(
official_model_name: str,
cfg: HookedTransformerConfig,
@@ -1001,11 +1394,11 @@ def get_pretrained_state_dict(
dtype = kwargs["torch_dtype"]
del kwargs["torch_dtype"]
official_model_name = get_official_model_name(official_model_name)
- if official_model_name == "bigcode/santacoder" and not kwargs.get(
+ if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(
"trust_remote_code", False
):
logging.warning(
- "Loading santacoder model requires setting trust_remote_code=True"
+ f"Loading model {official_model_name} state dict requires setting trust_remote_code=True"
)
kwargs["trust_remote_code"] = True
if (
@@ -1024,9 +1417,7 @@ def get_pretrained_state_dict(
)[0]
else:
file_name = list(filter(lambda x: x.endswith("final.pth"), repo_files))[0]
- state_dict = utils.download_file_from_hf(
- official_model_name, file_name, **kwargs
- )
+ state_dict = utils.download_file_from_hf(official_model_name, file_name, **kwargs)
# Convert to dtype
state_dict = {k: v.to(dtype) for k, v in state_dict.items()}
@@ -1038,11 +1429,13 @@ def get_pretrained_state_dict(
return state_dict
else:
if cfg.from_checkpoint:
+ huggingface_token = os.environ.get("HF_TOKEN", None)
if official_model_name.startswith("stanford-crfm"):
hf_model = AutoModelForCausalLM.from_pretrained(
official_model_name,
revision=f"checkpoint-{cfg.checkpoint_value}",
torch_dtype=dtype,
+ token=huggingface_token,
**kwargs,
)
elif official_model_name.startswith("EleutherAI/pythia"):
@@ -1050,22 +1443,28 @@ def get_pretrained_state_dict(
official_model_name,
revision=f"step{cfg.checkpoint_value}",
torch_dtype=dtype,
+ token=huggingface_token,
**kwargs,
)
else:
- raise ValueError(
- f"Checkpoints for model {official_model_name} are not supported"
- )
+ raise ValueError(f"Checkpoints for model {official_model_name} are not supported")
elif hf_model is None:
- if "llama" in official_model_name.lower():
- raise NotImplementedError("Must pass in hf_model for LLaMA models")
+ huggingface_token = os.environ.get("HF_TOKEN", None)
+ if official_model_name in NON_HF_HOSTED_MODEL_NAMES:
+ raise NotImplementedError("Model not hosted on HuggingFace, must pass in hf_model")
elif "bert" in official_model_name:
hf_model = BertForPreTraining.from_pretrained(
- official_model_name, torch_dtype=dtype, **kwargs
+ official_model_name,
+ torch_dtype=dtype,
+ token=huggingface_token,
+ **kwargs,
)
else:
hf_model = AutoModelForCausalLM.from_pretrained(
- official_model_name, torch_dtype=dtype, **kwargs
+ official_model_name,
+ torch_dtype=dtype,
+ token=huggingface_token,
+ **kwargs,
)
# Load model weights, and fold in layer norm weights
@@ -1087,10 +1486,22 @@ def get_pretrained_state_dict(
state_dict = convert_llama_weights(hf_model, cfg)
elif cfg.original_architecture == "BertForMaskedLM":
state_dict = convert_bert_weights(hf_model, cfg)
+ elif cfg.original_architecture == "MistralForCausalLM":
+ state_dict = convert_mistral_weights(hf_model, cfg)
+ elif cfg.original_architecture == "MixtralForCausalLM":
+ state_dict = convert_mixtral_weights(hf_model, cfg)
elif cfg.original_architecture == "BloomForCausalLM":
state_dict = convert_bloom_weights(hf_model, cfg)
elif cfg.original_architecture == "GPT2LMHeadCustomModel":
state_dict = convert_coder_weights(hf_model, cfg)
+ elif cfg.original_architecture == "QWenLMHeadModel":
+ state_dict = convert_qwen_weights(hf_model, cfg)
+ elif cfg.original_architecture == "Qwen2ForCausalLM":
+ state_dict = convert_qwen2_weights(hf_model, cfg)
+ elif cfg.original_architecture == "PhiForCausalLM":
+ state_dict = convert_phi_weights(hf_model, cfg)
+ elif cfg.original_architecture == "GemmaForCausalLM":
+ state_dict = convert_gemma_weights(hf_model, cfg)
else:
raise ValueError(
f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
@@ -1206,22 +1617,14 @@ def convert_neo_weights(neo, cfg: HookedTransformerConfig):
state_dict[f"blocks.{l}.attn.W_K"] = W_K
state_dict[f"blocks.{l}.attn.W_V"] = W_V
- state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(
- cfg.n_heads, cfg.d_head, dtype=cfg.dtype
- )
- state_dict[f"blocks.{l}.attn.b_K"] = torch.zeros(
- cfg.n_heads, cfg.d_head, dtype=cfg.dtype
- )
- state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros(
- cfg.n_heads, cfg.d_head, dtype=cfg.dtype
- )
+ state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
+ state_dict[f"blocks.{l}.attn.b_K"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
+ state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
W_O = neo.transformer.h[l].attn.attention.out_proj.weight
W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads)
state_dict[f"blocks.{l}.attn.W_O"] = W_O
- state_dict[f"blocks.{l}.attn.b_O"] = neo.transformer.h[
- l
- ].attn.attention.out_proj.bias
+ state_dict[f"blocks.{l}.attn.b_O"] = neo.transformer.h[l].attn.attention.out_proj.bias
state_dict[f"blocks.{l}.ln2.w"] = neo.transformer.h[l].ln_2.weight
state_dict[f"blocks.{l}.ln2.b"] = neo.transformer.h[l].ln_2.bias
@@ -1258,15 +1661,9 @@ def convert_gptj_weights(gptj, cfg: HookedTransformerConfig):
state_dict[f"blocks.{l}.attn.W_K"] = W_K
state_dict[f"blocks.{l}.attn.W_V"] = W_V
- state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(
- cfg.n_heads, cfg.d_head, dtype=cfg.dtype
- )
- state_dict[f"blocks.{l}.attn.b_K"] = torch.zeros(
- cfg.n_heads, cfg.d_head, dtype=cfg.dtype
- )
- state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros(
- cfg.n_heads, cfg.d_head, dtype=cfg.dtype
- )
+ state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
+ state_dict[f"blocks.{l}.attn.b_K"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
+ state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
W_O = gptj.transformer.h[l].attn.out_proj.weight
W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads)
@@ -1328,30 +1725,16 @@ def convert_neox_weights(neox, cfg: HookedTransformerConfig):
W_O = neox.gpt_neox.layers[l].attention.dense.weight
W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads)
state_dict[f"blocks.{l}.attn.W_O"] = W_O
- state_dict[f"blocks.{l}.attn.b_O"] = neox.gpt_neox.layers[
- l
- ].attention.dense.bias
+ state_dict[f"blocks.{l}.attn.b_O"] = neox.gpt_neox.layers[l].attention.dense.bias
- state_dict[f"blocks.{l}.ln2.w"] = neox.gpt_neox.layers[
- l
- ].post_attention_layernorm.weight
- state_dict[f"blocks.{l}.ln2.b"] = neox.gpt_neox.layers[
- l
- ].post_attention_layernorm.bias
+ state_dict[f"blocks.{l}.ln2.w"] = neox.gpt_neox.layers[l].post_attention_layernorm.weight
+ state_dict[f"blocks.{l}.ln2.b"] = neox.gpt_neox.layers[l].post_attention_layernorm.bias
- state_dict[f"blocks.{l}.mlp.W_in"] = neox.gpt_neox.layers[
- l
- ].mlp.dense_h_to_4h.weight.T
- state_dict[f"blocks.{l}.mlp.b_in"] = neox.gpt_neox.layers[
- l
- ].mlp.dense_h_to_4h.bias
+ state_dict[f"blocks.{l}.mlp.W_in"] = neox.gpt_neox.layers[l].mlp.dense_h_to_4h.weight.T
+ state_dict[f"blocks.{l}.mlp.b_in"] = neox.gpt_neox.layers[l].mlp.dense_h_to_4h.bias
- state_dict[f"blocks.{l}.mlp.W_out"] = neox.gpt_neox.layers[
- l
- ].mlp.dense_4h_to_h.weight.T
- state_dict[f"blocks.{l}.mlp.b_out"] = neox.gpt_neox.layers[
- l
- ].mlp.dense_4h_to_h.bias
+ state_dict[f"blocks.{l}.mlp.W_out"] = neox.gpt_neox.layers[l].mlp.dense_4h_to_h.weight.T
+ state_dict[f"blocks.{l}.mlp.b_out"] = neox.gpt_neox.layers[l].mlp.dense_4h_to_h.bias
state_dict["ln_final.w"] = neox.gpt_neox.final_layer_norm.weight
state_dict["ln_final.b"] = neox.gpt_neox.final_layer_norm.bias
@@ -1365,15 +1748,102 @@ def convert_llama_weights(llama, cfg: HookedTransformerConfig):
state_dict["embed.W_E"] = llama.model.embed_tokens.weight
+ # Some models with the Llama architecture use Grouped Query Attention, and so for these we need to modify
+ # the state dict keys for the K/V attention weight/biases, prepending "_" to the key names.
+ using_gqa = cfg.n_key_value_heads is not None
+ gqa_uscore = "_" if using_gqa else ""
+ # need a cast since MyPy isn't smart enough to realize that using_gqa implies n_key_value_heads is not None
+ n_kv_heads = cast(int, cfg.n_key_value_heads if using_gqa else cfg.n_heads)
+
# llama has no biases anywhere and deals with everything else roughly like
# GPTNeoX with different names
+ assert cfg.d_mlp is not None # keep mypy happy
+
for l in range(cfg.n_layers):
state_dict[f"blocks.{l}.ln1.w"] = llama.model.layers[l].input_layernorm.weight
W_Q = llama.model.layers[l].self_attn.q_proj.weight
W_K = llama.model.layers[l].self_attn.k_proj.weight
W_V = llama.model.layers[l].self_attn.v_proj.weight
+
+ # in case of quantization,
+ # parameters should stay as bitsandbytes.nn.modules.Params4bit
+ if not cfg.load_in_4bit:
+ W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads)
+ W_K = einops.rearrange(W_K, "(n h) m->n m h", n=n_kv_heads)
+ W_V = einops.rearrange(W_V, "(n h) m->n m h", n=n_kv_heads)
+
+ state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
+ state_dict[f"blocks.{l}.attn.{gqa_uscore}W_K"] = W_K
+ state_dict[f"blocks.{l}.attn.{gqa_uscore}W_V"] = W_V
+
+ state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(
+ cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device
+ )
+ state_dict[f"blocks.{l}.attn.{gqa_uscore}b_K"] = torch.zeros(
+ n_kv_heads,
+ cfg.d_head,
+ dtype=cfg.dtype,
+ device=cfg.device,
+ )
+ state_dict[f"blocks.{l}.attn.{gqa_uscore}b_V"] = torch.zeros(
+ n_kv_heads,
+ cfg.d_head,
+ dtype=cfg.dtype,
+ device=cfg.device,
+ )
+
+ W_O = llama.model.layers[l].self_attn.o_proj.weight
+
+ if not cfg.load_in_4bit:
+ W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
+
+ state_dict[f"blocks.{l}.attn.W_O"] = W_O.to(device=cfg.device)
+
+ state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(
+ cfg.d_model, dtype=cfg.dtype, device=cfg.device
+ )
+
+ state_dict[f"blocks.{l}.ln2.w"] = llama.model.layers[l].post_attention_layernorm.weight
+
+ # in case of quantization,
+ # parameters should stay as bitsandbytes.nn.modules.Params4bit
+ if not cfg.load_in_4bit:
+ state_dict[f"blocks.{l}.mlp.W_in"] = llama.model.layers[l].mlp.up_proj.weight.T
+ state_dict[f"blocks.{l}.mlp.W_gate"] = llama.model.layers[l].mlp.gate_proj.weight.T
+ state_dict[f"blocks.{l}.mlp.W_out"] = llama.model.layers[l].mlp.down_proj.weight.T
+ else:
+ state_dict[f"blocks.{l}.mlp.W_in"] = llama.model.layers[l].mlp.up_proj.weight
+ state_dict[f"blocks.{l}.mlp.W_gate"] = llama.model.layers[l].mlp.gate_proj.weight
+ state_dict[f"blocks.{l}.mlp.W_out"] = llama.model.layers[l].mlp.down_proj.weight
+
+ state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(
+ cfg.d_mlp, dtype=cfg.dtype, device=cfg.device
+ )
+ state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(
+ cfg.d_model, dtype=cfg.dtype, device=cfg.device
+ )
+
+ state_dict["ln_final.w"] = llama.model.norm.weight
+
+ state_dict["unembed.W_U"] = llama.lm_head.weight.T
+ state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype, device=cfg.device)
+
+ return state_dict
+
+
+def convert_qwen_weights(qwen, cfg: HookedTransformerConfig):
+ state_dict = {}
+ model = qwen.transformer
+ state_dict["embed.W_E"] = model.wte.weight
+
+ assert cfg.d_mlp is not None # keep mypy happy
+
+ for l in range(cfg.n_layers):
+ state_dict[f"blocks.{l}.ln1.w"] = model.h[l].ln_1.weight
+
+ W_Q, W_K, W_V = model.h[l].attn.c_attn.weight.split(split_size=cfg.d_model, dim=0)
W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads)
W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_heads)
W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_heads)
@@ -1381,40 +1851,242 @@ def convert_llama_weights(llama, cfg: HookedTransformerConfig):
state_dict[f"blocks.{l}.attn.W_K"] = W_K
state_dict[f"blocks.{l}.attn.W_V"] = W_V
- state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(
- cfg.n_heads, cfg.d_head, dtype=cfg.dtype
+ b_Q, b_K, b_V = model.h[l].attn.c_attn.bias.split(split_size=cfg.d_model, dim=0)
+ b_Q = einops.rearrange(
+ b_Q,
+ "(n_head d_head) -> n_head d_head",
+ n_head=cfg.n_heads,
)
- state_dict[f"blocks.{l}.attn.b_K"] = torch.zeros(
- cfg.n_heads, cfg.d_head, dtype=cfg.dtype
+ b_K = einops.rearrange(
+ b_K,
+ "(n_head d_head) -> n_head d_head",
+ n_head=cfg.n_heads,
)
- state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros(
- cfg.n_heads, cfg.d_head, dtype=cfg.dtype
+ b_V = einops.rearrange(
+ b_V,
+ "(n_head d_head) -> n_head d_head",
+ n_head=cfg.n_heads,
)
+ state_dict[f"blocks.{l}.attn.b_Q"] = b_Q
+ state_dict[f"blocks.{l}.attn.b_K"] = b_K
+ state_dict[f"blocks.{l}.attn.b_V"] = b_V
- W_O = llama.model.layers[l].self_attn.o_proj.weight
+ W_O = model.h[l].attn.c_proj.weight
W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
state_dict[f"blocks.{l}.attn.W_O"] = W_O
state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
- state_dict[f"blocks.{l}.ln2.w"] = llama.model.layers[
- l
- ].post_attention_layernorm.weight
+ state_dict[f"blocks.{l}.ln2.w"] = model.h[l].ln_2.weight
- state_dict[f"blocks.{l}.mlp.W_in"] = llama.model.layers[l].mlp.up_proj.weight.T
- state_dict[f"blocks.{l}.mlp.W_gate"] = llama.model.layers[
- l
- ].mlp.gate_proj.weight.T
+ state_dict[f"blocks.{l}.mlp.W_in"] = model.h[l].mlp.w1.weight.T
+ state_dict[f"blocks.{l}.mlp.W_gate"] = model.h[l].mlp.w2.weight.T
state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype)
- state_dict[f"blocks.{l}.mlp.W_out"] = llama.model.layers[
- l
- ].mlp.down_proj.weight.T
+ state_dict[f"blocks.{l}.mlp.W_out"] = model.h[l].mlp.c_proj.weight.T
state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
- state_dict["ln_final.w"] = llama.model.norm.weight
+ state_dict["ln_final.w"] = model.ln_f.weight
- state_dict["unembed.W_U"] = llama.lm_head.weight.T
+ state_dict["unembed.W_U"] = qwen.lm_head.weight.T
+ state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)
+
+ return state_dict
+
+
+def convert_qwen2_weights(qwen, cfg: HookedTransformerConfig):
+ # Note that this method is also applied for Qwen1.5 models, since they
+ # have architecture type Qwen2ForCausalLM.
+
+ state_dict = {}
+
+ state_dict["embed.W_E"] = qwen.model.embed_tokens.weight
+
+ assert cfg.d_mlp is not None # keep mypy happy
+
+ for l in range(cfg.n_layers):
+ state_dict[f"blocks.{l}.ln1.w"] = qwen.model.layers[l].input_layernorm.weight
+
+ W_Q = qwen.model.layers[l].self_attn.q_proj.weight
+ W_K = qwen.model.layers[l].self_attn.k_proj.weight
+ W_V = qwen.model.layers[l].self_attn.v_proj.weight
+ W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads)
+ W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_heads)
+ W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_heads)
+
+ state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
+ state_dict[f"blocks.{l}.attn.W_K"] = W_K
+ state_dict[f"blocks.{l}.attn.W_V"] = W_V
+
+ b_Q = qwen.model.layers[l].self_attn.q_proj.bias
+ b_Q = einops.rearrange(
+ b_Q,
+ "(n_head d_head) -> n_head d_head",
+ n_head=cfg.n_heads,
+ )
+
+ b_K = qwen.model.layers[l].self_attn.k_proj.bias
+ b_K = einops.rearrange(
+ b_K,
+ "(n_head d_head) -> n_head d_head",
+ n_head=cfg.n_heads,
+ )
+
+ b_V = qwen.model.layers[l].self_attn.v_proj.bias
+ b_V = einops.rearrange(
+ b_V,
+ "(n_head d_head) -> n_head d_head",
+ n_head=cfg.n_heads,
+ )
+
+ state_dict[f"blocks.{l}.attn.b_Q"] = b_Q
+ state_dict[f"blocks.{l}.attn.b_K"] = b_K
+ state_dict[f"blocks.{l}.attn.b_V"] = b_V
+
+ W_O = qwen.model.layers[l].self_attn.o_proj.weight
+ W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
+ state_dict[f"blocks.{l}.attn.W_O"] = W_O
+
+ state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
+
+ state_dict[f"blocks.{l}.ln2.w"] = qwen.model.layers[l].post_attention_layernorm.weight
+
+ state_dict[f"blocks.{l}.mlp.W_in"] = qwen.model.layers[l].mlp.up_proj.weight.T
+ state_dict[f"blocks.{l}.mlp.W_gate"] = qwen.model.layers[l].mlp.gate_proj.weight.T
+ state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype)
+
+ state_dict[f"blocks.{l}.mlp.W_out"] = qwen.model.layers[l].mlp.down_proj.weight.T
+ state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
+
+ state_dict["ln_final.w"] = qwen.model.norm.weight
+
+ state_dict["unembed.W_U"] = qwen.lm_head.weight.T
+ state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)
+
+ return state_dict
+
+
+def convert_mistral_weights(mistral, cfg: HookedTransformerConfig):
+ state_dict = {}
+
+ state_dict["embed.W_E"] = mistral.model.embed_tokens.weight
+
+ assert cfg.n_key_value_heads is not None # keep mypy happy
+ assert cfg.d_mlp is not None # keep mypy happy
+
+ # Mistral has no biases anywhere
+ for l in range(cfg.n_layers):
+ state_dict[f"blocks.{l}.ln1.w"] = mistral.model.layers[l].input_layernorm.weight
+
+ W_Q = mistral.model.layers[l].self_attn.q_proj.weight
+ W_K = mistral.model.layers[l].self_attn.k_proj.weight
+ W_V = mistral.model.layers[l].self_attn.v_proj.weight
+ W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads)
+ W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_key_value_heads)
+ W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_key_value_heads)
+ state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
+ state_dict[f"blocks.{l}.attn._W_K"] = W_K
+ state_dict[f"blocks.{l}.attn._W_V"] = W_V
+
+ state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
+ state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros(
+ cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype
+ )
+ state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros(
+ cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype
+ )
+
+ W_O = mistral.model.layers[l].self_attn.o_proj.weight
+ W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
+ state_dict[f"blocks.{l}.attn.W_O"] = W_O
+
+ state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
+
+ state_dict[f"blocks.{l}.ln2.w"] = mistral.model.layers[l].post_attention_layernorm.weight
+
+ state_dict[f"blocks.{l}.mlp.W_in"] = mistral.model.layers[l].mlp.up_proj.weight.T
+ state_dict[f"blocks.{l}.mlp.W_gate"] = mistral.model.layers[l].mlp.gate_proj.weight.T
+ state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype)
+
+ state_dict[f"blocks.{l}.mlp.W_out"] = mistral.model.layers[l].mlp.down_proj.weight.T
+ state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
+
+ state_dict["ln_final.w"] = mistral.model.norm.weight
+
+ state_dict["unembed.W_U"] = mistral.lm_head.weight.T
+ state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)
+
+ return state_dict
+
+
+def convert_mixtral_weights(mixtral, cfg: HookedTransformerConfig):
+ # The same as Mistral, but with the MLP replaced with MoE
+ # As with Mistral, Mixtral has no biases
+
+ state_dict = {}
+
+ assert cfg.n_key_value_heads is not None # keep mypy happy
+ assert cfg.d_mlp is not None
+ assert cfg.num_experts is not None
+
+ state_dict["embed.W_E"] = mixtral.model.embed_tokens.weight
+
+ for l in range(cfg.n_layers):
+ state_dict[f"blocks.{l}.ln1.w"] = mixtral.model.layers[l].input_layernorm.weight
+
+ W_Q = mixtral.model.layers[l].self_attn.q_proj.weight
+ W_K = mixtral.model.layers[l].self_attn.k_proj.weight
+ W_V = mixtral.model.layers[l].self_attn.v_proj.weight
+ W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads)
+ W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_key_value_heads)
+ W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_key_value_heads)
+ state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
+ state_dict[f"blocks.{l}.attn._W_K"] = W_K
+ state_dict[f"blocks.{l}.attn._W_V"] = W_V
+
+ state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
+ state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros(
+ cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype
+ )
+ state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros(
+ cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype
+ )
+
+ W_O = mixtral.model.layers[l].self_attn.o_proj.weight
+ W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
+ state_dict[f"blocks.{l}.attn.W_O"] = W_O
+
+ state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
+
+ state_dict[f"blocks.{l}.ln2.w"] = mixtral.model.layers[l].post_attention_layernorm.weight
+
+ state_dict[f"blocks.{l}.mlp.W_gate"] = mixtral.model.layers[
+ l
+ ].block_sparse_moe.gate.weight.T
+
+ # The mapping here from wn to W_{in/out/gate} is a bit confusing:
+ # w1 -> W_gate
+ # w2 -> W_out
+ # w3 -> W_in
+ # See https://github.com/mistralai/mistral-src/blob/main/mistral/model.py#L128 for reference
+ for e in range(cfg.num_experts):
+ state_dict[f"blocks.{l}.mlp.experts.{e}.W_in"] = (
+ mixtral.model.layers[l].block_sparse_moe.experts[e].w3.weight.T
+ )
+ state_dict[f"blocks.{l}.mlp.experts.{e}.W_gate"] = (
+ mixtral.model.layers[l].block_sparse_moe.experts[e].w1.weight.T
+ )
+ state_dict[f"blocks.{l}.mlp.experts.{e}.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype)
+ state_dict[f"blocks.{l}.mlp.experts.{e}.W_out"] = (
+ mixtral.model.layers[l].block_sparse_moe.experts[e].w2.weight.T
+ )
+ state_dict[f"blocks.{l}.mlp.experts.{e}.b_out"] = torch.zeros(
+ cfg.d_model, dtype=cfg.dtype
+ )
+
+ state_dict["ln_final.w"] = mixtral.model.norm.weight.data
+
+ state_dict["unembed.W_U"] = mixtral.lm_head.weight.T
state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)
return state_dict
@@ -1427,12 +2099,8 @@ def convert_opt_weights(opt, cfg: HookedTransformerConfig):
state_dict["pos_embed.W_pos"] = opt.model.decoder.embed_positions.weight[2:, :]
for l in range(cfg.n_layers):
- state_dict[f"blocks.{l}.ln1.w"] = opt.model.decoder.layers[
- l
- ].self_attn_layer_norm.weight
- state_dict[f"blocks.{l}.ln1.b"] = opt.model.decoder.layers[
- l
- ].self_attn_layer_norm.bias
+ state_dict[f"blocks.{l}.ln1.w"] = opt.model.decoder.layers[l].self_attn_layer_norm.weight
+ state_dict[f"blocks.{l}.ln1.b"] = opt.model.decoder.layers[l].self_attn_layer_norm.bias
W_Q = opt.model.decoder.layers[l].self_attn.q_proj.weight
W_K = opt.model.decoder.layers[l].self_attn.k_proj.weight
@@ -1487,16 +2155,10 @@ def convert_opt_weights(opt, cfg: HookedTransformerConfig):
index=cfg.n_heads,
)
state_dict[f"blocks.{l}.attn.W_O"] = W_O
- state_dict[f"blocks.{l}.attn.b_O"] = opt.model.decoder.layers[
- l
- ].self_attn.out_proj.bias
+ state_dict[f"blocks.{l}.attn.b_O"] = opt.model.decoder.layers[l].self_attn.out_proj.bias
- state_dict[f"blocks.{l}.ln2.w"] = opt.model.decoder.layers[
- l
- ].final_layer_norm.weight
- state_dict[f"blocks.{l}.ln2.b"] = opt.model.decoder.layers[
- l
- ].final_layer_norm.bias
+ state_dict[f"blocks.{l}.ln2.w"] = opt.model.decoder.layers[l].final_layer_norm.weight
+ state_dict[f"blocks.{l}.ln2.b"] = opt.model.decoder.layers[l].final_layer_norm.bias
state_dict[f"blocks.{l}.mlp.W_in"] = opt.model.decoder.layers[l].fc1.weight.T
state_dict[f"blocks.{l}.mlp.W_out"] = opt.model.decoder.layers[l].fc2.weight.T
@@ -1586,9 +2248,7 @@ def convert_mingpt_weights(old_state_dict, cfg: HookedTransformerConfig):
W_O = old_state_dict[f"blocks.{l}.attn.proj.weight"]
W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads)
state_dict[f"blocks.{l}.attn.W_O"] = W_O
- state_dict[f"blocks.{l}.attn.b_O"] = old_state_dict[
- f"blocks.{l}.attn.proj.bias"
- ]
+ state_dict[f"blocks.{l}.attn.b_O"] = old_state_dict[f"blocks.{l}.attn.proj.bias"]
state_dict[f"blocks.{l}.ln2.w"] = old_state_dict[f"blocks.{l}.ln2.weight"]
state_dict[f"blocks.{l}.ln2.b"] = old_state_dict[f"blocks.{l}.ln2.bias"]
@@ -1609,6 +2269,90 @@ def convert_mingpt_weights(old_state_dict, cfg: HookedTransformerConfig):
return state_dict
+def convert_nanogpt_weights(old_state_dict, cfg: HookedTransformerConfig):
+ """For https://github.com/karpathy/nanoGPT
+ There are two complications with converting nanogpt models:
+ The first is that some state dicts have an unwanted prefix on keys that needs to be removed.
+ The second is that the models can be saved with or without bias. By default, there
+ is no bias. This function can handle both cases."""
+ # Nanogpt models saved after torch.compile() have this unwanted prefix
+ # This is a simple way to remove it
+ unwanted_prefix = "_orig_mod."
+ for k, v in list(old_state_dict.items()):
+ if k.startswith(unwanted_prefix):
+ old_state_dict[k[len(unwanted_prefix) :]] = old_state_dict.pop(k)
+
+ new_state_dict = {}
+ new_state_dict["pos_embed.W_pos"] = old_state_dict["transformer.wpe.weight"]
+ new_state_dict["embed.W_E"] = old_state_dict["transformer.wte.weight"]
+
+ new_state_dict["ln_final.w"] = old_state_dict["transformer.ln_f.weight"]
+ new_state_dict["ln_final.b"] = torch.zeros_like(old_state_dict["transformer.ln_f.weight"])
+ new_state_dict["unembed.W_U"] = old_state_dict["lm_head.weight"].T
+
+ bias = False
+ if "transformer.ln_f.bias" in old_state_dict:
+ bias = True
+ new_state_dict["ln_final.b"] = old_state_dict["transformer.ln_f.bias"]
+
+ for layer in range(cfg.n_layers):
+ layer_key = f"transformer.h.{layer}"
+
+ new_state_dict[f"blocks.{layer}.ln1.w"] = old_state_dict[f"{layer_key}.ln_1.weight"]
+ # A bias of zeros is required for folding layer norm
+ new_state_dict[f"blocks.{layer}.ln1.b"] = torch.zeros_like(
+ old_state_dict[f"{layer_key}.ln_1.weight"]
+ )
+ new_state_dict[f"blocks.{layer}.ln2.w"] = old_state_dict[f"{layer_key}.ln_2.weight"]
+ new_state_dict[f"blocks.{layer}.ln2.b"] = torch.zeros_like(
+ old_state_dict[f"{layer_key}.ln_2.weight"]
+ )
+
+ W = old_state_dict[f"{layer_key}.attn.c_attn.weight"]
+ W_Q, W_K, W_V = torch.tensor_split(W, 3, dim=0)
+ W_Q = einops.rearrange(W_Q, "(i h) m->i m h", i=cfg.n_heads)
+ W_K = einops.rearrange(W_K, "(i h) m->i m h", i=cfg.n_heads)
+ W_V = einops.rearrange(W_V, "(i h) m->i m h", i=cfg.n_heads)
+ new_state_dict[f"blocks.{layer}.attn.W_Q"] = W_Q
+ new_state_dict[f"blocks.{layer}.attn.W_K"] = W_K
+ new_state_dict[f"blocks.{layer}.attn.W_V"] = W_V
+
+ W_O = old_state_dict[f"{layer_key}.attn.c_proj.weight"]
+ W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads)
+ new_state_dict[f"blocks.{layer}.attn.W_O"] = W_O
+
+ new_state_dict[f"blocks.{layer}.mlp.W_in"] = old_state_dict[
+ f"{layer_key}.mlp.c_fc.weight"
+ ].T
+ new_state_dict[f"blocks.{layer}.mlp.W_out"] = old_state_dict[
+ f"{layer_key}.mlp.c_proj.weight"
+ ].T
+
+ if bias:
+ new_state_dict[f"blocks.{layer}.ln1.b"] = old_state_dict[f"{layer_key}.ln_1.bias"]
+ new_state_dict[f"blocks.{layer}.ln2.b"] = old_state_dict[f"{layer_key}.ln_2.bias"]
+ new_state_dict[f"blocks.{layer}.mlp.b_in"] = old_state_dict[
+ f"{layer_key}.mlp.c_fc.bias"
+ ]
+ new_state_dict[f"blocks.{layer}.mlp.b_out"] = old_state_dict[
+ f"{layer_key}.mlp.c_proj.bias"
+ ]
+
+ B = old_state_dict[f"{layer_key}.attn.c_attn.bias"]
+ B_Q, B_K, B_V = torch.tensor_split(B, 3, dim=0)
+ B_Q = einops.rearrange(B_Q, "(i h)->i h", i=cfg.n_heads)
+ B_K = einops.rearrange(B_K, "(i h)->i h", i=cfg.n_heads)
+ B_V = einops.rearrange(B_V, "(i h)->i h", i=cfg.n_heads)
+ new_state_dict[f"blocks.{layer}.attn.b_Q"] = B_Q
+ new_state_dict[f"blocks.{layer}.attn.b_K"] = B_K
+ new_state_dict[f"blocks.{layer}.attn.b_V"] = B_V
+ new_state_dict[f"blocks.{layer}.attn.b_O"] = old_state_dict[
+ f"{layer_key}.attn.c_proj.bias"
+ ]
+
+ return new_state_dict
+
+
def convert_bert_weights(bert, cfg: HookedTransformerConfig):
embeddings = bert.bert.embeddings
state_dict = {
@@ -1684,10 +2428,8 @@ def convert_bloom_weights(bloom, cfg: HookedTransformerConfig):
state_dict[f"blocks.{l}.ln1.w"] = bloom.transformer.h[l].input_layernorm.weight
state_dict[f"blocks.{l}.ln1.b"] = bloom.transformer.h[l].input_layernorm.bias
- # Bloom attn weight is stored as a fused matrx. BloomAttn: Linear(in=1024, out=3072)
- # The .weight returned matrix will be in shape (3072, 1024)
W = bloom.transformer.h[l].self_attention.query_key_value.weight
- # First transpose -> (1024, 3072), then split into (d_model, n_heads, 3, d_head)
+
W_split = W.T.reshape(cfg.d_model, cfg.n_heads, 3, cfg.d_head)
W_Q, W_K, W_V = W_split[..., 0, :], W_split[..., 1, :], W_split[..., 2, :]
@@ -1706,33 +2448,21 @@ def convert_bloom_weights(bloom, cfg: HookedTransformerConfig):
state_dict[f"blocks.{l}.attn.b_V"] = qkv_bias[:, 2, :]
W_O = bloom.transformer.h[l].self_attention.dense.weight.T # [1024, 1024]
- W_O = einops.rearrange(
- W_O, "(n h) m->n h m", n=cfg.n_heads
- ) # [n_heads, d_head, d_model]
+ W_O = einops.rearrange(W_O, "(n h) m->n h m", n=cfg.n_heads) # [n_heads, d_head, d_model]
state_dict[f"blocks.{l}.attn.W_O"] = W_O
- state_dict[f"blocks.{l}.attn.b_O"] = bloom.transformer.h[
- l
- ].self_attention.dense.bias
+ state_dict[f"blocks.{l}.attn.b_O"] = bloom.transformer.h[l].self_attention.dense.bias
- state_dict[f"blocks.{l}.ln2.w"] = bloom.transformer.h[
- l
- ].post_attention_layernorm.weight
- state_dict[f"blocks.{l}.ln2.b"] = bloom.transformer.h[
- l
- ].post_attention_layernorm.bias
+ state_dict[f"blocks.{l}.ln2.w"] = bloom.transformer.h[l].post_attention_layernorm.weight
+ state_dict[f"blocks.{l}.ln2.b"] = bloom.transformer.h[l].post_attention_layernorm.bias
W_in = bloom.transformer.h[l].mlp.dense_h_to_4h.weight.T
state_dict[f"blocks.{l}.mlp.W_in"] = W_in
- state_dict[f"blocks.{l}.mlp.b_in"] = bloom.transformer.h[
- l
- ].mlp.dense_h_to_4h.bias
+ state_dict[f"blocks.{l}.mlp.b_in"] = bloom.transformer.h[l].mlp.dense_h_to_4h.bias
W_out = bloom.transformer.h[l].mlp.dense_4h_to_h.weight.T
state_dict[f"blocks.{l}.mlp.W_out"] = W_out
- state_dict[f"blocks.{l}.mlp.b_out"] = bloom.transformer.h[
- l
- ].mlp.dense_4h_to_h.bias
- state_dict["unembed.W_U"] = bloom.lm_head.weight.T # transpose to match shape
+ state_dict[f"blocks.{l}.mlp.b_out"] = bloom.transformer.h[l].mlp.dense_4h_to_h.bias
+ state_dict["unembed.W_U"] = bloom.lm_head.weight.T
state_dict["ln_final.w"] = bloom.transformer.ln_f.weight
state_dict["ln_final.b"] = bloom.transformer.ln_f.bias
@@ -1798,6 +2528,134 @@ def convert_coder_weights(model, cfg: HookedTransformerConfig):
return state_dict
+def convert_phi_weights(phi, cfg: HookedTransformerConfig):
+ state_dict = {}
+
+ state_dict["embed.W_E"] = phi.model.embed_tokens.weight
+
+ for l in range(cfg.n_layers):
+ state_dict[f"blocks.{l}.ln1.w"] = phi.model.layers[l].input_layernorm.weight
+ state_dict[f"blocks.{l}.ln1.b"] = phi.model.layers[l].input_layernorm.bias
+
+ W_Q = phi.model.layers[l].self_attn.q_proj.weight
+ W_K = phi.model.layers[l].self_attn.k_proj.weight
+ W_V = phi.model.layers[l].self_attn.v_proj.weight
+ W_Q = einops.rearrange(
+ W_Q, "(n_head d_head) d_model -> n_head d_model d_head", n_head=cfg.n_heads
+ )
+ W_K = einops.rearrange(
+ W_K, "(n_head d_head) d_model -> n_head d_model d_head", n_head=cfg.n_heads
+ )
+ W_V = einops.rearrange(
+ W_V, "(n_head d_head) d_model -> n_head d_model d_head", n_head=cfg.n_heads
+ )
+ state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
+ state_dict[f"blocks.{l}.attn.W_K"] = W_K
+ state_dict[f"blocks.{l}.attn.W_V"] = W_V
+
+ b_Q = phi.model.layers[l].self_attn.q_proj.bias
+ b_K = phi.model.layers[l].self_attn.k_proj.bias
+ b_V = phi.model.layers[l].self_attn.v_proj.bias
+ b_Q = einops.rearrange(b_Q, "(n_head d_head) -> n_head d_head", n_head=cfg.n_heads)
+ b_K = einops.rearrange(b_K, "(n_head d_head) -> n_head d_head", n_head=cfg.n_heads)
+ b_V = einops.rearrange(b_V, "(n_head d_head) -> n_head d_head", n_head=cfg.n_heads)
+ state_dict[f"blocks.{l}.attn.b_Q"] = b_Q
+ state_dict[f"blocks.{l}.attn.b_K"] = b_K
+ state_dict[f"blocks.{l}.attn.b_V"] = b_V
+
+ W_O = phi.model.layers[l].self_attn.dense.weight
+ W_O = einops.rearrange(
+ W_O, "d_model (n_head d_head) -> n_head d_head d_model", n_head=cfg.n_heads
+ )
+
+ state_dict[f"blocks.{l}.attn.W_O"] = W_O
+ state_dict[f"blocks.{l}.attn.b_O"] = phi.model.layers[l].self_attn.dense.bias
+
+ # Layer Norm 1 and 2 are tied.
+ state_dict[f"blocks.{l}.ln2.w"] = state_dict[f"blocks.{l}.ln1.w"]
+ state_dict[f"blocks.{l}.ln2.b"] = state_dict[f"blocks.{l}.ln1.b"]
+
+ state_dict[f"blocks.{l}.mlp.W_in"] = phi.model.layers[l].mlp.fc1.weight.T
+ state_dict[f"blocks.{l}.mlp.b_in"] = phi.model.layers[l].mlp.fc1.bias
+ state_dict[f"blocks.{l}.mlp.W_out"] = phi.model.layers[l].mlp.fc2.weight.T
+ state_dict[f"blocks.{l}.mlp.b_out"] = phi.model.layers[l].mlp.fc2.bias
+
+ state_dict["ln_final.w"] = phi.model.final_layernorm.weight
+ state_dict["ln_final.b"] = phi.model.final_layernorm.bias
+
+ state_dict["unembed.W_U"] = phi.lm_head.weight.T
+ state_dict["unembed.b_U"] = phi.lm_head.bias
+
+ return state_dict
+
+
+def convert_gemma_weights(gemma, cfg: HookedTransformerConfig):
+ state_dict = {}
+
+ assert cfg.n_key_value_heads is not None # mypy
+ assert cfg.d_mlp is not None # mypy
+
+ # Gemma Models scale embeddings by multiplying by sqrt(d_model)
+ state_dict["embed.W_E"] = gemma.model.embed_tokens.weight * (cfg.d_model**0.5)
+
+ # Gemma has no biases anywhere
+ for l in range(cfg.n_layers):
+ # GemmaRMSNorm adds 1 to weights before multiplying by input
+ state_dict[f"blocks.{l}.ln1.w"] = gemma.model.layers[
+ l
+ ].input_layernorm.weight + torch.ones_like(
+ gemma.model.layers[l].input_layernorm.weight, dtype=cfg.dtype
+ )
+
+ W_Q = gemma.model.layers[l].self_attn.q_proj.weight
+ W_K = gemma.model.layers[l].self_attn.k_proj.weight
+ W_V = gemma.model.layers[l].self_attn.v_proj.weight
+ W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads)
+ W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_key_value_heads)
+ W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_key_value_heads)
+ state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
+ state_dict[f"blocks.{l}.attn._W_K"] = W_K
+ state_dict[f"blocks.{l}.attn._W_V"] = W_V
+
+ state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
+ state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros(
+ cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype
+ )
+ state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros(
+ cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype
+ )
+
+ W_O = gemma.model.layers[l].self_attn.o_proj.weight
+ W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
+ state_dict[f"blocks.{l}.attn.W_O"] = W_O
+
+ state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
+
+ # GemmaRMSNorm adds 1 to weights before multiplying by input
+ state_dict[f"blocks.{l}.ln2.w"] = gemma.model.layers[
+ l
+ ].post_attention_layernorm.weight + torch.ones_like(
+ gemma.model.norm.weight, dtype=cfg.dtype
+ )
+
+ state_dict[f"blocks.{l}.mlp.W_in"] = gemma.model.layers[l].mlp.up_proj.weight.T
+ state_dict[f"blocks.{l}.mlp.W_gate"] = gemma.model.layers[l].mlp.gate_proj.weight.T
+ state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype)
+
+ state_dict[f"blocks.{l}.mlp.W_out"] = gemma.model.layers[l].mlp.down_proj.weight.T
+ state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
+
+ # GemmaRMSNorm adds 1 to weights before multiplying by input
+ state_dict["ln_final.w"] = gemma.model.norm.weight + torch.ones_like(
+ gemma.model.norm.weight, dtype=cfg.dtype
+ )
+
+ state_dict["unembed.W_U"] = gemma.lm_head.weight.T
+ state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)
+
+ return state_dict
+
+
@dataclasses.dataclass
class Config:
d_model: int = 768
@@ -1817,9 +2675,7 @@ def get_basic_config(model_name: str, **kwargs) -> Config:
return Config(
**{
k: v
- for k, v in get_pretrained_model_config(model_name, **kwargs)
- .to_dict()
- .items()
+ for k, v in get_pretrained_model_config(model_name, **kwargs).to_dict().items()
if k
in [
"d_model",
diff --git a/transformer_lens/past_key_value_caching.py b/transformer_lens/past_key_value_caching.py
index f80ea2e4d..2f904b927 100644
--- a/transformer_lens/past_key_value_caching.py
+++ b/transformer_lens/past_key_value_caching.py
@@ -27,12 +27,13 @@ def init_cache_entry(
device: Union[torch.device, str, None],
batch_size: int = 1,
):
+ n_heads = cfg.n_key_value_heads if cfg.n_key_value_heads is not None else cfg.n_heads
return cls(
past_keys=torch.empty(
- (batch_size, 0, cfg.n_heads, cfg.d_head), device=device, dtype=cfg.dtype
+ (batch_size, 0, n_heads, cfg.d_head), device=device, dtype=cfg.dtype
),
past_values=torch.empty(
- (batch_size, 0, cfg.n_heads, cfg.d_head), device=device, dtype=cfg.dtype
+ (batch_size, 0, n_heads, cfg.d_head), device=device, dtype=cfg.dtype
),
)
@@ -103,12 +104,9 @@ def unfreeze(self):
for entry in self.entries:
entry.frozen = False
- def append_attention_mask(
- self, attention_mask: Int[torch.Tensor, "batch new_tokens"]
- ):
- updated_attention_mask = torch.cat(
- [self.previous_attention_mask, attention_mask], dim=-1
- )
+ def append_attention_mask(self, attention_mask: Int[torch.Tensor, "batch new_tokens"]):
+ attention_mask = attention_mask.to(self.previous_attention_mask.device)
+ updated_attention_mask = torch.cat([self.previous_attention_mask, attention_mask], dim=-1)
if not self.frozen:
self.previous_attention_mask = updated_attention_mask
return updated_attention_mask
diff --git a/transformer_lens/patching.py b/transformer_lens/patching.py
index b97a95191..aff08dae0 100644
--- a/transformer_lens/patching.py
+++ b/transformer_lens/patching.py
@@ -51,7 +51,7 @@
import itertools
from functools import partial
-from typing import Callable, Optional, Sequence, Tuple, Union
+from typing import Callable, Optional, Sequence, Tuple, Union, overload
import einops
import pandas as pd
@@ -61,7 +61,8 @@
from typing_extensions import Literal
import transformer_lens.utils as utils
-from transformer_lens import ActivationCache, HookedTransformer
+from transformer_lens.ActivationCache import ActivationCache
+from transformer_lens.HookedTransformer import HookedTransformer
# %%
Logits = torch.Tensor
@@ -78,11 +79,7 @@ def make_df_from_ranges(
"""
Takes in a list of column names and max ranges for each column, and returns a dataframe with the cartesian product of the range for each column (ie iterating through all combinations from zero to column_max_range - 1, in order, incrementing the final column first)
"""
- rows = list(
- itertools.product(
- *[range(axis_max_range) for axis_max_range in column_max_ranges]
- )
- )
+ rows = list(itertools.product(*[range(axis_max_range) for axis_max_range in column_max_ranges]))
df = pd.DataFrame(rows, columns=column_names)
return df
@@ -92,13 +89,45 @@ def make_df_from_ranges(
PatchedActivation = torch.Tensor
+@overload
+def generic_activation_patch(
+ model: HookedTransformer,
+ corrupted_tokens: Int[torch.Tensor, "batch pos"],
+ clean_cache: ActivationCache,
+ patching_metric: Callable[[Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""]],
+ patch_setter: Callable[
+ [CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation
+ ],
+ activation_name: str,
+ index_axis_names: Optional[Sequence[AxisNames]] = None,
+ index_df: Optional[pd.DataFrame] = None,
+ return_index_df: Literal[False] = False,
+) -> torch.Tensor:
+ ...
+
+
+@overload
def generic_activation_patch(
model: HookedTransformer,
corrupted_tokens: Int[torch.Tensor, "batch pos"],
clean_cache: ActivationCache,
- patching_metric: Callable[
- [Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""]
+ patching_metric: Callable[[Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""]],
+ patch_setter: Callable[
+ [CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation
],
+ activation_name: str,
+ index_axis_names: Optional[Sequence[AxisNames]],
+ index_df: Optional[pd.DataFrame],
+ return_index_df: Literal[True],
+) -> Tuple[torch.Tensor, pd.DataFrame]:
+ ...
+
+
+def generic_activation_patch(
+ model: HookedTransformer,
+ corrupted_tokens: Int[torch.Tensor, "batch pos"],
+ clean_cache: ActivationCache,
+ patching_metric: Callable[[Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""]],
patch_setter: Callable[
[CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation
],
@@ -148,9 +177,7 @@ def generic_activation_patch(
max_axis_range["head"] = max_axis_range["head_index"]
# Get the max range for each axis we iterate over
- index_axis_max_range = [
- max_axis_range[axis_name] for axis_name in index_axis_names
- ]
+ index_axis_max_range = [max_axis_range[axis_name] for axis_name in index_axis_names]
# Get the dataframe where each row is a tuple of indices
index_df = make_df_from_ranges(index_axis_max_range, index_axis_names)
@@ -167,9 +194,7 @@ def generic_activation_patch(
if flattened_output:
patched_metric_output = torch.zeros(len(index_df), device=model.cfg.device)
else:
- patched_metric_output = torch.zeros(
- index_axis_max_range, device=model.cfg.device
- )
+ patched_metric_output = torch.zeros(index_axis_max_range, device=model.cfg.device)
# A generic patching hook - for each index, it applies the patch_setter appropriately to patch the activation
def patching_hook(corrupted_activation, hook, index, clean_activation):
@@ -282,9 +307,7 @@ def layer_head_pos_pattern_patch_setter(
"""
assert len(index) == 3
layer, head_index, dest_pos = index
- corrupted_activation[:, head_index, dest_pos, :] = clean_activation[
- :, head_index, dest_pos, :
- ]
+ corrupted_activation[:, head_index, dest_pos, :] = clean_activation[:, head_index, dest_pos, :]
return corrupted_activation
@@ -643,11 +666,9 @@ def get_act_patch_attn_head_all_pos_every(
Returns:
patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [5, n_layers, n_heads]
"""
- act_patch_results = []
+ act_patch_results: list[torch.Tensor] = []
act_patch_results.append(
- get_act_patch_attn_head_out_all_pos(
- model, corrupted_tokens, clean_cache, metric
- )
+ get_act_patch_attn_head_out_all_pos(model, corrupted_tokens, clean_cache, metric)
)
act_patch_results.append(
get_act_patch_attn_head_q_all_pos(model, corrupted_tokens, clean_cache, metric)
@@ -659,9 +680,7 @@ def get_act_patch_attn_head_all_pos_every(
get_act_patch_attn_head_v_all_pos(model, corrupted_tokens, clean_cache, metric)
)
act_patch_results.append(
- get_act_patch_attn_head_pattern_all_pos(
- model, corrupted_tokens, clean_cache, metric
- )
+ get_act_patch_attn_head_pattern_all_pos(model, corrupted_tokens, clean_cache, metric)
)
return torch.stack(act_patch_results, dim=0)
@@ -698,9 +717,7 @@ def get_act_patch_attn_head_by_pos_every(
pattern_results = get_act_patch_attn_head_pattern_by_pos(
model, corrupted_tokens, clean_cache, metric
)
- act_patch_results.append(
- einops.rearrange(pattern_results, "batch head pos -> batch pos head")
- )
+ act_patch_results.append(einops.rearrange(pattern_results, "batch head pos -> batch pos head"))
return torch.stack(act_patch_results, dim=0)
@@ -719,13 +736,7 @@ def get_act_patch_block_every(
patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [3, n_layers, pos]
"""
act_patch_results = []
- act_patch_results.append(
- get_act_patch_resid_pre(model, corrupted_tokens, clean_cache, metric)
- )
- act_patch_results.append(
- get_act_patch_attn_out(model, corrupted_tokens, clean_cache, metric)
- )
- act_patch_results.append(
- get_act_patch_mlp_out(model, corrupted_tokens, clean_cache, metric)
- )
+ act_patch_results.append(get_act_patch_resid_pre(model, corrupted_tokens, clean_cache, metric))
+ act_patch_results.append(get_act_patch_attn_out(model, corrupted_tokens, clean_cache, metric))
+ act_patch_results.append(get_act_patch_mlp_out(model, corrupted_tokens, clean_cache, metric))
return torch.stack(act_patch_results, dim=0)
diff --git a/transformer_lens/train.py b/transformer_lens/train.py
index 379b34ecc..450b58348 100644
--- a/transformer_lens/train.py
+++ b/transformer_lens/train.py
@@ -3,16 +3,19 @@
Utilities for training :class:`transformer_lens.HookedTransformer` models on autoregressive language
modeling tasks.
"""
+
from dataclasses import dataclass
from typing import Optional
import torch
import torch.optim as optim
import wandb
+from torch.optim import Optimizer
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
-from transformer_lens import HookedTransformer, utils
+from transformer_lens import utils
+from transformer_lens.HookedTransformer import HookedTransformer
@dataclass
@@ -81,6 +84,7 @@ def train(
if config.device is None:
config.device = utils.get_device()
+ optimizer: Optimizer
if config.optimizer_name in ["Adam", "AdamW"]:
# Weight decay in Adam is implemented badly, so use AdamW instead (see PyTorch AdamW docs)
if config.weight_decay is not None:
@@ -98,9 +102,7 @@ def train(
optimizer = optim.SGD(
model.parameters(),
lr=config.lr,
- weight_decay=config.weight_decay
- if config.weight_decay is not None
- else 0.0,
+ weight_decay=(config.weight_decay if config.weight_decay is not None else 0.0),
momentum=config.momentum,
)
else:
@@ -134,9 +136,7 @@ def train(
samples += tokens.shape[0]
if config.wandb:
- wandb.log(
- {"train_loss": loss.item(), "samples": samples, "epoch": epoch}
- )
+ wandb.log({"train_loss": loss.item(), "samples": samples, "epoch": epoch})
if config.print_every is not None and step % config.print_every == 0:
print(f"Epoch {epoch} Samples {samples} Step {step} Loss {loss.item()}")
diff --git a/transformer_lens/utilities/devices.py b/transformer_lens/utilities/devices.py
index 27906ee55..c8e5b78b7 100644
--- a/transformer_lens/utilities/devices.py
+++ b/transformer_lens/utilities/devices.py
@@ -38,14 +38,14 @@ def get_device_for_block_index(
if device is None:
device = cfg.device
device = torch.device(device)
+ if device.type == "cpu":
+ return device
device_index = (device.index or 0) + (index // layers_per_device)
return torch.device(device.type, device_index)
def move_to_and_update_config(
- model: Union[
- "transformer_lens.HookedTransformer", "transformer_lens.HookedEncoder"
- ],
+ model: Union["transformer_lens.HookedTransformer", "transformer_lens.HookedEncoder"],
device_or_dtype: Union[torch.device, str, torch.dtype],
print_details=True,
):
diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py
index 5306e502f..fe9f8d128 100644
--- a/transformer_lens/utils.py
+++ b/transformer_lens/utils.py
@@ -2,18 +2,21 @@
This module contains varied utility functions used throughout the library.
"""
+
from __future__ import annotations
import inspect
import json
+import os
import re
import shutil
from copy import deepcopy
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
import einops
import numpy as np
import torch
+import torch.nn as nn
import torch.nn.functional as F
import transformers
from datasets.arrow_dataset import Dataset
@@ -23,21 +26,15 @@
from rich import print as rprint
from transformers import AutoTokenizer
-from transformer_lens import FactoredMatrix
+from transformer_lens.FactoredMatrix import FactoredMatrix
CACHE_DIR = transformers.TRANSFORMERS_CACHE
USE_DEFAULT_VALUE = None
-def select_compatible_kwargs(
- kwargs_dict: Dict[str, Any], callable: Callable
-) -> Dict[str, Any]:
+def select_compatible_kwargs(kwargs_dict: Dict[str, Any], callable: Callable) -> Dict[str, Any]:
"""Return a dict with the elements kwargs_dict that are parameters of callable"""
- return {
- k: v
- for k, v in kwargs_dict.items()
- if k in inspect.getfullargspec(callable).args
- }
+ return {k: v for k, v in kwargs_dict.items() if k in inspect.getfullargspec(callable).args}
def download_file_from_hf(
@@ -87,9 +84,7 @@ def clear_huggingface_cache():
def print_gpu_mem(step_name=""):
- print(
- f"{step_name} ~ {np.round(torch.cuda.memory_allocated()/2e30, 2)} GiB allocated on GPU."
- )
+ print(f"{step_name} ~ {np.round(torch.cuda.memory_allocated()/2e30, 2)} GiB allocated on GPU.")
def get_corner(tensor, n=3):
@@ -133,9 +128,7 @@ def lm_cross_entropy_loss(
# Use torch.gather to find the log probs of the correct tokens
# Offsets needed because we're predicting the NEXT token (this means the final logit is meaningless)
# None and [..., 0] needed because the tensor used in gather must have the same rank.
- predicted_log_probs = log_probs[..., :-1, :].gather(
- dim=-1, index=tokens[..., 1:, None]
- )[..., 0]
+ predicted_log_probs = log_probs[..., :-1, :].gather(dim=-1, index=tokens[..., 1:, None])[..., 0]
if per_token:
return -predicted_log_probs
else:
@@ -166,28 +159,17 @@ def gelu_new(
return (
0.5
* input
- * (
- 1.0
- + torch.tanh(
- np.sqrt(2.0 / np.pi) * (input + 0.044715 * torch.pow(input, 3.0))
- )
- )
+ * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
)
def gelu_fast(
input: Float[torch.Tensor, "batch pos d_mlp"]
) -> Float[torch.Tensor, "batch pos d_mlp"]:
- return (
- 0.5
- * input
- * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
- )
+ return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
-def solu(
- input: Float[torch.Tensor, "batch pos d_mlp"]
-) -> Float[torch.Tensor, "batch pos d_mlp"]:
+def solu(input: Float[torch.Tensor, "batch pos d_mlp"]) -> Float[torch.Tensor, "batch pos d_mlp"]:
"""
SoLU activation function as described by
https://transformer-circuits.pub/2022/solu/index.html.
@@ -197,6 +179,81 @@ def solu(
return input * F.softmax(input, dim=-1)
+def calc_fan_in_and_fan_out(tensor):
+ """
+ Calculate the fan in and fan out of a tensor. We define it ourselves because Torch uses a
+ different convention for weights (e.g. for an MLP they use d_out x d_in, and we use d_in x
+ d_out, for attention they do (n_head d_head) x d_model, we do n_head x d_model x d_head).
+ """
+ shape = tensor.shape
+
+ if len(shape) == 0:
+ raise ValueError("Fan in and fan out can not be computed for scalars.")
+ elif len(shape) == 1:
+ fan_in = 1
+ fan_out = shape[0]
+ elif len(shape) == 2: # Linear transform
+ fan_in = shape[0]
+ fan_out = shape[1]
+ elif len(shape) == 3: # Attention head weight, has shape n_head x d_model x d_head
+ fan_in = shape[1]
+ fan_out = shape[0] * shape[2]
+ else:
+ raise ValueError(f"Fan in and fan out can not be computed for shape {shape} tensors.")
+
+ return fan_in, fan_out
+
+
+def init_xavier_uniform_(param, gain=1.0):
+ """
+ Initializes the input tensor using the Xavier initialization method.
+ """
+ fan_in, fan_out = calc_fan_in_and_fan_out(param)
+ max = gain * np.sqrt(6.0 / (fan_in + fan_out))
+ return nn.init.uniform_(param, -max, max)
+
+
+def init_xavier_normal_(param, gain=1.0):
+ """
+ Initializes the input tensor using the Xavier initialization method.
+ """
+ fan_in, fan_out = calc_fan_in_and_fan_out(param)
+ std = gain * np.sqrt(2.0 / (fan_in + fan_out))
+ return nn.init.normal_(param, mean=0.0, std=std)
+
+
+def init_kaiming_uniform_(param, a=0, nonlinearity="relu", gain=1.0, mode="fan_in"):
+ """
+ Initializes the input tensor using the Kaiming initialization method.
+
+ Starting from a std 1 uniform distribution, we scale the weights by c / sqrt(fan_in), where c =
+ sqrt(2) if the params were immediately preceded by a relu and 1 for everything else.
+
+ As with torch, `a` is a hyperparameter for `nonlinearity`, if it takes one.
+ """
+ fan_in, fan_out = calc_fan_in_and_fan_out(param)
+ fan = fan_in if mode == "fan_in" else fan_out
+ gain *= nn.init.calculate_gain(nonlinearity, a)
+ max = gain * np.sqrt(3.0 / fan)
+ return nn.init.uniform_(param, -max, max)
+
+
+def init_kaiming_normal_(param, a=0, nonlinearity="relu", gain=1.0, mode="fan_in"):
+ """
+ Initializes the input tensor using the Kaiming initialization method.
+
+ Starting from a std 1 normal distribution, we scale the weights by c / sqrt(fan_in), where c =
+ sqrt(2) if the params were immediately preceded by a relu and 1 for everything else.
+
+ As with torch, `a` is a hyperparameter for `nonlinearity`, if it takes one.
+ """
+ fan_in, fan_out = calc_fan_in_and_fan_out(param)
+ fan = fan_in if mode == "fan_in" else fan_out
+ gain *= nn.init.calculate_gain(nonlinearity, a)
+ std = gain * np.sqrt(1.0 / fan)
+ return nn.init.normal_(param, mean=0.0, std=std)
+
+
def keep_single_column(dataset: Dataset, col_name: str):
"""
Acts on a HuggingFace dataset to delete all columns apart from a single column name - useful when we want to tokenize and mix together different strings
@@ -250,14 +307,9 @@ def tokenize_function(examples: Dict[str, List[str]]) -> Dict[str, np.ndarray]:
# Divide into 20 chunks of ~ equal length
num_chunks = 20
chunk_length = (len(full_text) - 1) // num_chunks + 1
- chunks = [
- full_text[i * chunk_length : (i + 1) * chunk_length]
- for i in range(num_chunks)
- ]
+ chunks = [full_text[i * chunk_length : (i + 1) * chunk_length] for i in range(num_chunks)]
# Tokenize the chunks in parallel. Uses NumPy because HuggingFace map doesn't want tensors returned
- tokens = tokenizer(chunks, return_tensors="np", padding=True)[
- "input_ids"
- ].flatten()
+ tokens = tokenizer(chunks, return_tensors="np", padding=True)["input_ids"].flatten()
# Drop padding tokens
tokens = tokens[tokens != tokenizer.pad_token_id]
num_tokens = len(tokens)
@@ -282,16 +334,6 @@ def tokenize_function(examples: Dict[str, List[str]]) -> Dict[str, np.ndarray]:
return tokenized_dataset
-"""
-Test ^
-
-data = Dataset.from_dict({"text":[str(i) for i in range(1000)]})
-tokenizer = AutoTokenizer.from_pretrained("NeelNanda/gpt-neox-tokenizer-digits")
-print(data)
-tokenize_and_concatenate(data, tokenizer, streaming=False, column_name="text")
-"""
-
-
def sample_logits(
final_logits: Float[torch.Tensor, "batch d_vocab"],
top_k: Optional[int] = None,
@@ -322,9 +364,7 @@ def sample_logits(
final_logits = final_logits / temperature
if freq_penalty > 0:
- assert (
- tokens is not None
- ), "Must provide input_tokens if applying a frequency penalty"
+ assert tokens is not None, "Must provide input_tokens if applying a frequency penalty"
for batch_index in range(final_logits.shape[0]):
# torch.bincount returns a tensor of length d_vocab, with the number of occurences of each token in the tokens.
final_logits[batch_index] = final_logits[
@@ -343,9 +383,7 @@ def sample_logits(
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# We round up - we want prob >= top_p not top_p
- sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
- ..., :-1
- ].clone()
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(
-1, sorted_indices, sorted_indices_to_remove
@@ -357,7 +395,7 @@ def sample_logits(
# Type alias
-SliceInput: Type = Optional[
+SliceInput = Optional[
Union[
int,
Tuple[int,],
@@ -404,6 +442,8 @@ class Slice:
elif input_slice is a Tensor, same as list - Tensor is assumed to be a 1D list of indices.
"""
+ slice: Union[int, slice, np.ndarray]
+
def __init__(
self,
input_slice: SliceInput = None,
@@ -417,14 +457,13 @@ def __init__(
Raises:
ValueError: If the input_slice is not one of the above types.
"""
- if type(input_slice) == tuple:
- input_slice: slice = slice(*input_slice)
- self.slice = input_slice
+ if isinstance(input_slice, tuple):
+ self.slice = slice(*input_slice)
self.mode = "slice"
- elif type(input_slice) == int:
+ elif isinstance(input_slice, int):
self.slice = input_slice
self.mode = "int"
- elif type(input_slice) == slice:
+ elif isinstance(input_slice, slice):
self.slice = input_slice
self.mode = "slice"
elif type(input_slice) in [list, torch.Tensor, np.ndarray]:
@@ -453,7 +492,7 @@ def apply(
"""
ndim = tensor.ndim
slices = [slice(None)] * ndim
- slices[dim] = self.slice
+ slices[dim] = self.slice # type: ignore
return tensor[tuple(slices)]
def indices(
@@ -522,16 +561,12 @@ def get_act_name(
get_act_name('scale4ln1')=='blocks.4.ln1.hook_scale'
get_act_name('pre5')=='blocks.5.mlp.hook_pre'
"""
- if (
- ("." in name or name.startswith("hook_"))
- and layer is None
- and layer_type is None
- ):
+ if ("." in name or name.startswith("hook_")) and layer is None and layer_type is None:
# If this was called on a full name, just return it
return name
match = re.match(r"([a-z]+)(\d+)([a-z]?.*)", name)
if match is not None:
- name, layer, layer_type = match.groups(0)
+ name, layer, layer_type = match.groups(0) # type: ignore
layer_type_alias = {
"a": "attn",
@@ -587,9 +622,7 @@ def get_act_name(
return full_act_name
-def remove_batch_dim(
- tensor: Float[torch.Tensor, "1 ..."]
-) -> Float[torch.Tensor, "..."]:
+def remove_batch_dim(tensor: Float[torch.Tensor, "1 ..."]) -> Float[torch.Tensor, "..."]:
"""
Removes the first dimension of a tensor if it is size 1, otherwise returns the tensor unchanged
"""
@@ -599,16 +632,14 @@ def remove_batch_dim(
return tensor
-# Note: Docstring won't be tested with PyTest (it's ignored), as it thinks this is a regular unit
-# test (because it's name is prefixed `test_`).
def test_prompt(
prompt: str,
answer: str,
model, # Can't give type hint due to circular imports
- prepend_space_to_answer: Optional[bool] = True,
- print_details: Optional[bool] = True,
+ prepend_space_to_answer: bool = True,
+ print_details: bool = True,
prepend_bos: Optional[bool] = USE_DEFAULT_VALUE,
- top_k: Optional[int] = 10,
+ top_k: int = 10,
) -> None:
"""Test if the Model Can Give the Correct Answer to a Prompt.
@@ -737,11 +768,11 @@ def composition_scores(
left.rdim == right.ldim
), f"Composition scores require left.rdim==right.ldim, shapes were left: {left.shape}, right:{right.shape}"
- right = right.collapse_r()
- left = left.collapse_l()
- r_norms = right.norm(dim=[-2, -1])
- l_norms = left.norm(dim=[-2, -1])
- comp_norms = (left @ right).norm(dim=[-2, -1])
+ new_right = right.collapse_r()
+ new_left = left.collapse_l()
+ r_norms = new_right.norm(dim=[-2, -1])
+ l_norms = new_left.norm(dim=[-2, -1])
+ comp_norms = (new_left @ new_right).norm(dim=[-2, -1])
return comp_norms / r_norms / l_norms
@@ -791,9 +822,7 @@ def is_lower_triangular(x: torch.Tensor) -> bool:
return x.equal(x.tril())
-def check_structure(
- t1: torch.Tensor, t2: torch.Tensor, *, verbose: bool = False
-) -> None:
+def check_structure(t1: torch.Tensor, t2: torch.Tensor, *, verbose: bool = False) -> None:
"""Validate that the two square tensors have the same structure, i.e.,
that the directionality of comparisons points in the same directions both
row-wise and column-wise.
@@ -890,9 +919,7 @@ def get_cumsum_along_dim(tensor, dim, reverse=False):
return cumsum
-def get_attention_mask(
- tokenizer, tokens: torch.Tensor, prepend_bos: bool
-) -> torch.Tensor:
+def get_attention_mask(tokenizer, tokens: torch.Tensor, prepend_bos: bool) -> torch.Tensor:
"""
Computes the attention mask for the tokenized input.
NOTE: Only the leftmost leading pads (when `padding_side == left`)
@@ -931,6 +958,23 @@ def get_attention_mask(
return attention_mask
+def repeat_along_head_dimension(
+ tensor: Float[torch.Tensor, "batch pos d_model"],
+ n_heads: int,
+ clone_tensor=True,
+ # `einops.repeat` uses a view in torch, so we generally clone the tensor to avoid using shared storage for each head entry
+):
+ repeated_tensor = einops.repeat(
+ tensor,
+ "batch pos d_model -> batch pos n_heads d_model",
+ n_heads=n_heads,
+ )
+ if clone_tensor:
+ return repeated_tensor.clone()
+ else:
+ return repeated_tensor
+
+
def get_nested_attr(obj, attr_str):
"""
Retrieves a nested attribute from an object based on a dot-separated string.
@@ -1004,8 +1048,7 @@ def __init__(self, model, **overrides):
"padding_side": {
"default_location": "model.tokenizer.padding_side",
"valid_values": [USE_DEFAULT_VALUE, "left", "right"],
- "skip_overriding": model.tokenizer
- is None, # Do not override if tokenizer is None
+ "skip_overriding": model.tokenizer is None, # Do not override if tokenizer is None
"default_value_to_restore": None, # Will be set later
},
}
@@ -1029,7 +1072,7 @@ def __enter__(self):
# Ensure the override is a valid value
valid_values = info["valid_values"]
assert (
- override in valid_values
+ override in valid_values # type: ignore
), f"{property} must be one of {valid_values}, but got {override}."
# Fetch current default and store it to restore later
@@ -1038,9 +1081,7 @@ def __enter__(self):
info["default_value_to_restore"] = deepcopy(default_value)
# Override the default value
- locally_overriden_value = override_or_use_default_value(
- default_value, override
- )
+ locally_overriden_value = override_or_use_default_value(default_value, override)
set_nested_attr(self, default_location, locally_overriden_value)
def __exit__(self, exc_type, exc_val, exc_tb):
@@ -1080,8 +1121,12 @@ def get_tokenizer_with_bos(tokenizer):
if add_bos_token:
tokenizer_with_bos = tokenizer
else:
+ huggingface_token = os.environ.get("HF_TOKEN", None)
tokenizer_with_bos = AutoTokenizer.from_pretrained(
- pretrained_model_name_or_path, add_bos_token=True, **init_kwargs
+ pretrained_model_name_or_path,
+ add_bos_token=True,
+ token=huggingface_token,
+ **init_kwargs,
)
return tokenizer_with_bos
@@ -1126,14 +1171,20 @@ def get_tokens_with_bos_removed(tokenizer, tokens):
if tokenizer.bos_token_id == tokenizer.pad_token_id:
is_not_pad_token = tokens.ne(tokenizer.pad_token_id)
- is_leading_pad = (
- get_cumsum_along_dim(is_not_pad_token, -1, reverse=False) == 0
- )
+ is_leading_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=False) == 0
real_bos_positions = is_leading_pad.sum(-1) - 1
else:
real_bos_positions = (tokens == tokenizer.bos_token_id).int().argmax(-1)
- tokens = tokens.scatter(
- dim=1, index=real_bos_positions.unsqueeze(-1), value=-100
- )
+ tokens = tokens.scatter(dim=1, index=real_bos_positions.unsqueeze(-1), value=-100)
return tokens[tokens != -100].view(*bos_removed_shape)
+
+
+try:
+ import pytest
+
+ # Note: Docstring won't be tested with PyTest (it's ignored), as it thinks this is a regular unit
+ # test (because its name is prefixed `test_`).
+ pytest.mark.skip(test_prompt)
+except ModuleNotFoundError:
+ pass # disregard if pytest not in env